Skip to content

Commit 570bad1

Browse files
committed
type checker tests too
1 parent 94cd39a commit 570bad1

File tree

5 files changed

+119
-20
lines changed

5 files changed

+119
-20
lines changed

src/finchlite/codegen/c.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,17 @@
66
import tempfile
77
from abc import ABC, abstractmethod
88
from collections import namedtuple
9-
from collections.abc import Callable
9+
from collections.abc import Callable, Hashable
1010
from functools import lru_cache
1111
from pathlib import Path
1212
from types import NoneType
1313
from typing import Any
1414

1515
import numpy as np
1616

17-
from finchlite.finch_assembly.map import FType, MapFType
18-
1917
from .. import finch_assembly as asm
2018
from ..algebra import query_property, register_property
21-
from ..finch_assembly import AssemblyStructFType, BufferFType, TupleFType
19+
from ..finch_assembly import AssemblyStructFType, BufferFType, MapFType, TupleFType
2220
from ..symbolic import Context, Namespace, ScopedDict, fisinstance, ftype
2321
from ..util import config
2422
from ..util.cache import file_cache
@@ -612,22 +610,22 @@ def __init__(
612610
self.fptr = fptr
613611
self.types = types
614612
self.slots = slots
615-
self.datastructures: dict[FType, Any] = {}
613+
self.datastructures: dict[Hashable, Any] = {}
616614

617615
def add_header(self, header):
618616
if header not in self._headerset:
619617
self.headers.append(header)
620618
self._headerset.add(header)
621619

622-
def add_datastructure(self, ftype: FType, handler: "Callable[[CContext], Any]"):
620+
def add_datastructure(self, key: Hashable, handler: "Callable[[CContext], Any]"):
623621
"""
624622
Code to add a datastructure declaration.
625623
This is the minimum required to prevent redundancy.
626624
"""
627-
if ftype in self.datastructures:
625+
if key in self.datastructures:
628626
return
629627
# at least mark something is there.
630-
self.datastructures[ftype] = None
628+
self.datastructures[key] = None
631629
handler(self)
632630

633631
def emit_global(self):

src/finchlite/codegen/hashtable.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ def c_unpack(self, ctx: "CContext", var_n: str, val: AssemblyExpression):
390390
data = ctx.freshen(var_n, "data")
391391
# Add all the stupid header stuff from above.
392392
ctx.add_datastructure(
393-
self,
393+
("CHashTableFType", self.key_len, self.value_len),
394394
lambda ctx: CHashTable.gen_code(
395395
ctx, self.key_type, self.value_type, inline=True
396396
),
@@ -445,9 +445,6 @@ def __init__(self, key_len, value_len, map: "dict[tuple,tuple] | None" = None):
445445
self.key_len = key_len
446446
self.value_len = value_len
447447

448-
self._key_type = _int_tuple_ftype(key_len)
449-
self._value_type = _int_tuple_ftype(value_len)
450-
451448
self._numba_key_type = numba.types.UniTuple(numba.types.int64, key_len)
452449
self._numba_value_type = numba.types.UniTuple(numba.types.int64, value_len)
453450

src/finchlite/finch_assembly/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
)
77
from .dataflow import AssemblyCopyPropagation, assembly_copy_propagation
88
from .interpreter import AssemblyInterpreter, AssemblyInterpreterKernel
9+
from .map import MapFType
910
from .nodes import (
1011
AssemblyNode,
1112
Assert,
@@ -69,6 +70,7 @@
6970
"Literal",
7071
"Load",
7172
"LoadMap",
73+
"MapFType",
7274
"Module",
7375
"NamedTupleFType",
7476
"Print",

src/finchlite/finch_assembly/map.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,24 @@ def __init__(
1616
): ...
1717

1818
@property
19-
def element_type(self):
19+
@abstractmethod
20+
def ftype(self) -> "MapFType": ...
21+
22+
@property
23+
def value_type(self):
2024
"""
21-
Return the type of elements stored in the hash table.
22-
This is typically the same as the dtype used to create the map.
25+
Return type of values stored in the hash table
26+
(probably some TupleFType)
2327
"""
24-
return self.ftype.element_type()
28+
return self.ftype.value_type
2529

2630
@property
27-
def length_type(self):
31+
def key_type(self):
2832
"""
29-
Return the type of indices used to access elements in the hash table.
30-
This is typically an integer type.
33+
Return type of keys stored in the hash table
34+
(probably some TupleFType)
3135
"""
32-
return self.ftype.length_type()
36+
return self.ftype.key_type
3337

3438
@abstractmethod
3539
def load(self, idx: tuple):

tests/test_assembly_type_checker.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77

88
import finchlite.finch_assembly as asm
99
from finchlite.codegen import NumpyBuffer
10+
from finchlite.codegen.hashtable import CHashTable, NumbaHashTable
1011
from finchlite.finch_assembly import assembly_check_types
12+
from finchlite.finch_assembly.struct import TupleFType
1113
from finchlite.symbolic import FType, ftype
1214

1315

@@ -684,3 +686,99 @@ def test_simple_struct():
684686
)
685687

686688
assembly_check_types(mod)
689+
690+
691+
@pytest.mark.parametrize(
692+
["tabletype"],
693+
[
694+
(CHashTable,),
695+
(NumbaHashTable,),
696+
],
697+
)
698+
def test_hashtable(tabletype):
699+
table = tabletype(2, 3)
700+
701+
table_v = asm.Variable("a", ftype(table))
702+
table_slt = asm.Slot("a_", ftype(table))
703+
704+
key_type = table.ftype.key_type
705+
val_type = table.ftype.value_type
706+
key_v = asm.Variable("key", key_type)
707+
val_v = asm.Variable("val", val_type)
708+
709+
mod = asm.Module(
710+
(
711+
asm.Function(
712+
asm.Variable(
713+
"setidx", TupleFType.from_tuple(tuple(int for _ in range(3)))
714+
),
715+
(table_v, key_v, val_v),
716+
asm.Block(
717+
(
718+
asm.Unpack(table_slt, table_v),
719+
asm.StoreMap(
720+
table_slt,
721+
key_v,
722+
val_v,
723+
),
724+
asm.Repack(table_slt),
725+
asm.Return(asm.LoadMap(table_slt, key_v)),
726+
)
727+
),
728+
),
729+
asm.Function(
730+
asm.Variable("exists", bool),
731+
(table_v, key_v),
732+
asm.Block(
733+
(
734+
asm.Unpack(table_slt, table_v),
735+
asm.Return(asm.ExistsMap(table_slt, key_v)),
736+
)
737+
),
738+
),
739+
)
740+
)
741+
assembly_check_types(mod)
742+
743+
744+
@pytest.mark.parametrize(
745+
["tabletype"],
746+
[
747+
(CHashTable,),
748+
(NumbaHashTable,),
749+
],
750+
)
751+
def test_hashtable_fail(tabletype):
752+
table = tabletype(2, 3)
753+
754+
table_v = asm.Variable("a", ftype(table))
755+
table_slt = asm.Slot("a_", ftype(table))
756+
757+
key_type = table.ftype.key_type
758+
val_type = table.ftype.value_type
759+
key_v = asm.Variable("key", key_type)
760+
val_v = asm.Variable("val", val_type)
761+
mod = asm.Module(
762+
(
763+
asm.Function(
764+
asm.Variable(
765+
"setidx", TupleFType.from_tuple(tuple(int for _ in range(2)))
766+
),
767+
(table_v, key_v, val_v),
768+
asm.Block(
769+
(
770+
asm.Unpack(table_slt, table_v),
771+
asm.StoreMap(
772+
table_slt,
773+
key_v,
774+
val_v,
775+
),
776+
asm.Repack(table_slt),
777+
asm.Return(asm.LoadMap(table_slt, key_v)),
778+
)
779+
),
780+
),
781+
)
782+
)
783+
with pytest.raises(asm.AssemblyTypeError):
784+
assembly_check_types(mod)

0 commit comments

Comments
 (0)