diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 475290ac..ad60eb37 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,6 +19,8 @@ jobs: steps: - name: Checkout Repo uses: actions/checkout@v4 + with: + submodules: recursive - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python }} @@ -36,7 +38,7 @@ jobs: poetry install --extras test - name: Run tests run: | - poetry run pytest --junit-xml=test-${{ matrix.os }}-Python-${{ matrix.python }}.xml + poetry run pytest -s --junit-xml=test-${{ matrix.os }}-Python-${{ matrix.python }}.xml - uses: codecov/codecov-action@v5 on: diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index dcf2d3f0..8d95ebca 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -7,6 +7,8 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 + with: + submodules: recursive - name: Set up Python uses: actions/setup-python@v2 with: diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..18159e86 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "src/finchlite/codegen/stc"] + path = src/finchlite/codegen/stc + url = https://github.com/stclib/stc diff --git a/pyproject.toml b/pyproject.toml index 6f597405..f090f75a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,4 +58,4 @@ section-order = [ [tool.mypy] ignore_missing_imports = true -exclude = ["tests/reference"] +exclude = ["tests/reference", "src/finchlite/codegen/stc"] diff --git a/src/finchlite/codegen/c.py b/src/finchlite/codegen/c.py index f9b72b6d..264e3b12 100644 --- a/src/finchlite/codegen/c.py +++ b/src/finchlite/codegen/c.py @@ -6,6 +6,7 @@ import tempfile from abc import ABC, abstractmethod from collections import namedtuple +from collections.abc import Callable, Hashable from functools import lru_cache from pathlib import Path from types import NoneType @@ -15,7 +16,14 @@ from .. import finch_assembly as asm from ..algebra import query_property, register_property -from ..finch_assembly import AssemblyStructFType, BufferFType, TupleFType +from ..finch_assembly import ( + AssemblyStructFType, + BufferFType, + DictFType, + ImmutableStructFType, + MutableStructFType, + TupleFType, +) from ..symbolic import Context, Namespace, ScopedDict, fisinstance, ftype from ..util import config from ..util.cache import file_cache @@ -140,7 +148,7 @@ def construct_from_c(fmt, c_obj): return fmt.construct_from_c(c_obj) try: return query_property(fmt, "construct_from_c", "__attr__", c_obj) - except NotImplementedError: + except AttributeError: return fmt(c_obj) @@ -204,12 +212,14 @@ def construct_from_c(fmt, c_obj): register_property(t, "construct_from_c", "__attr__", lambda fmt, c_value: c_value) register_property(t, "numba_type", "__attr__", lambda t: t) + register_property( np.generic, "serialize_to_c", "__attr__", - lambda fmt, obj: np.ctypeslib.as_ctypes(obj), + lambda fmt, obj: np.ctypeslib.as_ctypes(np.array(obj)), ) + # pass by value -> no op register_property( np.generic, @@ -254,11 +264,9 @@ def __call__(self, *args): self.argtypes, args, serial_args, strict=False ): deserialize_from_c(type_, arg, serial_arg) - if hasattr(self.ret_type, "construct_from_c"): - return construct_from_c(res.ftype, res) if self.ret_type is type(None): return None - return self.ret_type(res) + return construct_from_c(self.ret_type, res) class CModule: @@ -315,7 +323,6 @@ def __call__(self, prgm): for func in prgm.funcs: match func: case asm.Function(asm.Variable(func_name, return_t), args, _): - return_t = c_type(return_t) arg_ts = [arg.result_format for arg in args] kern = CKernel(getattr(lib, func_name), return_t, arg_ts) kernels[func_name] = kern @@ -569,6 +576,9 @@ def __call__(self, prgm: asm.AssemblyNode): class CContext(Context): """ A class to represent a C environment. + + The context has functionality to track which datastructure definitions need + to get declared via the stc library. """ def __init__( @@ -597,12 +607,24 @@ def __init__( self.fptr = fptr self.types = types self.slots = slots + self.datastructures: dict[Hashable, Any] = {} def add_header(self, header): if header not in self._headerset: self.headers.append(header) self._headerset.add(header) + def add_datastructure(self, key: Hashable, handler: "Callable[[CContext], Any]"): + """ + Code to add a datastructure declaration. + This is the minimum required to prevent redundancy. + """ + if key in self.datastructures: + return + # at least mark something is there. + self.datastructures[key] = None + handler(self) + def emit_global(self): """ Emit the headers for the C code. @@ -792,6 +814,15 @@ def __call__(self, prgm: asm.AssemblyNode): case asm.Length(buf): buf = self.resolve(buf) return buf.result_format.c_length(self, buf) + case asm.LoadDict(map, idx): + map = self.resolve(map) + return map.result_format.c_loaddict(self, map, idx) + case asm.ExistsDict(map, idx): + map = self.resolve(map) + return map.result_format.c_existsdict(self, map, idx) + case asm.StoreDict(map, idx, val): + map = self.resolve(map) + return map.result_format.c_storedict(self, map, idx, val) case asm.Block(bodies): ctx_2 = self.block() for body in bodies: @@ -936,6 +967,34 @@ def construct_from_c(self, res): """ +class CDictFType(DictFType, CArgumentFType, ABC): + """ + Abstract base class for the ftype of dictionaries. The ftype defines how + the data in a Map is organized and accessed. + """ + + @abstractmethod + def c_existsdict(self, ctx, map, idx): + """ + Return C code which checks whether a given key exists in a map. + """ + ... + + @abstractmethod + def c_loaddict(self, ctx, map, idx): + """ + Return C code which gets a value corresponding to a certain key. + """ + ... + + @abstractmethod + def c_storedict(self, ctx, buffer, idx, value): + """ + Return C code which stores a certain value given a certain integer tuple key. + """ + ... + + class CBufferFType(BufferFType, CArgumentFType, ABC): """ Abstract base class for the ftype of datastructures. The ftype defines how @@ -1035,36 +1094,55 @@ def struct_c_type(fmt: AssemblyStructFType): return new_struct +""" +Note: When serializing any struct to C, it will get serialized to a struct with +no indirection. + +When you pass a struct into a kernel that expects a struct pointer, ctypes can +intelligently infer whether we are working with a pointer arg type (pass by +reference) or a non-pointer type (in which case it will immediately apply +indirection) +""" + register_property( - AssemblyStructFType, + MutableStructFType, "c_type", "__attr__", lambda fmt: ctypes.POINTER(struct_c_type(fmt)), ) +register_property( + ImmutableStructFType, "c_type", "__attr__", lambda fmt: struct_c_type(fmt) +) -def struct_c_getattr(fmt: AssemblyStructFType, ctx, obj, attr): - return f"{obj}->{attr}" +register_property( + MutableStructFType, + "c_getattr", + "__attr__", + lambda fmt, ctx, obj, attr: f"{obj}->{attr}", +) register_property( - AssemblyStructFType, + ImmutableStructFType, "c_getattr", "__attr__", - struct_c_getattr, + lambda fmt, ctx, obj, attr: f"{obj}.{attr}", ) -def struct_c_setattr(fmt: AssemblyStructFType, ctx, obj, attr, val): +def struct_mutable_setattr(fmt: AssemblyStructFType, ctx, obj, attr, val): ctx.emit(f"{ctx.feed}{obj}->{attr} = {val};") - return +# the equivalent for immutable is f"{ctx.feed}{obj}.{attr} = {val};" +# but we will not include that because it's bad. + register_property( - AssemblyStructFType, + MutableStructFType, "c_setattr", "__attr__", - struct_c_setattr, + struct_mutable_setattr, ) @@ -1092,18 +1170,23 @@ def serialize_tuple_to_c(fmt, obj): "__attr__", serialize_tuple_to_c, ) + + +def tuple_construct_from_c(fmt: TupleFType, c_struct): + args = [getattr(c_struct, name) for name in fmt.struct_fieldnames] + return tuple(args) + + register_property( TupleFType, "construct_from_c", "__attr__", - lambda fmt, obj, c_tuple: tuple(c_tuple), + tuple_construct_from_c, ) register_property( TupleFType, "c_type", "__attr__", - lambda fmt: ctypes.POINTER( - struct_c_type(asm.NamedTupleFType("CTuple", fmt.struct_fields)) - ), + lambda fmt: struct_c_type(asm.NamedTupleFType("CTuple", fmt.struct_fields)), ) diff --git a/src/finchlite/codegen/hashtable.py b/src/finchlite/codegen/hashtable.py new file mode 100644 index 00000000..beda3726 --- /dev/null +++ b/src/finchlite/codegen/hashtable.py @@ -0,0 +1,620 @@ +import ctypes +from dataclasses import dataclass +from pathlib import Path +from textwrap import dedent +from typing import Any, NamedTuple, TypedDict + +import numba + +from finchlite.codegen.c import ( + CContext, + CDictFType, + CStackFType, + c_type, + construct_from_c, + load_shared_lib, + serialize_to_c, +) +from finchlite.codegen.numba_backend import ( + NumbaContext, + NumbaDictFType, + NumbaStackFType, + numba_type, +) +from finchlite.finch_assembly.map import Dict +from finchlite.finch_assembly.nodes import AssemblyExpression, Stack +from finchlite.finch_assembly.struct import AssemblyStructFType, TupleFType + +stcpath = Path(__file__).parent / "stc" / "include" +hashmap_h = stcpath / "stc" / "hashmap.h" + + +class NumbaDictFields(NamedTuple): + """ + This is a field that extracts out the map from the obj variable. Its + purpose is so that we can extract out map from obj in unpack, do + computations on the map variable, and re-insert that into obj in repack. + """ + + map: str + obj: str + + +class CDictFields(NamedTuple): + """ + TODO: for the C backend, we will pulling in a completely different library + to do the actual hash function implementation. Should we even try to + convert back? + """ + + map: str + obj: str + + +def _is_integer_tuple(tup, size): + if not isinstance(tup, tuple) or len(tup) != size: + return False + return all(isinstance(elt, int) for elt in tup) + + +def _int_tuple_ftype(size: int): + return TupleFType.from_tuple(tuple(int for _ in range(size))) + + +def _tuplify(ftype: AssemblyStructFType, obj): + assert isinstance(ftype, AssemblyStructFType) + return tuple([ftype.struct_getattr(obj, attr) for attr in ftype.struct_fieldnames]) + + +class CHashTableStruct(ctypes.Structure): + _fields_ = [ + ("map", ctypes.c_void_p), + ("obj", ctypes.py_object), + ] + + +class CHashMethods(TypedDict): + init: str + exists: str + load: str + store: str + cleanup: str + + +@dataclass +class CHashTableLibrary: + library: ctypes.CDLL + methods: CHashMethods + hmap_t: str + + +# implement the hash table datastructures +class CHashTable(Dict): + """ + CHashTable class that basically connects up to an STC library. + """ + + libraries: dict[ + tuple[AssemblyStructFType, AssemblyStructFType], CHashTableLibrary + ] = {} + + @classmethod + def gen_code( + cls, + ctx: "CContext", + key_type: "AssemblyStructFType", + value_type: "AssemblyStructFType", + inline: bool = False, + ) -> tuple[CHashMethods, str]: + # dereference both key and value types; as given, they are both pointers. + keytype_c = ctx.ctype_name(c_type(key_type)) + valuetype_c = ctx.ctype_name(c_type(value_type)) + hmap_t = ctx.freshen("hmap") + + ctx.add_header("#include ") + + # these headers should just be added to the headers list. + # deduplication is catastrophic here. + ctx.headers.append(f"#define T {hmap_t}, {keytype_c}, {valuetype_c}") + ctx.headers.append("#define i_eq c_memcmp_eq") + ctx.headers.append(f'#include "{hashmap_h}"') + + methods: CHashMethods = { + "init": ctx.freshen("finch_hmap_init"), + "exists": ctx.freshen("finch_hmap_exists"), + "load": ctx.freshen("finch_hmap_load"), + "store": ctx.freshen("finch_hmap_store"), + "cleanup": ctx.freshen("finch_hmap_cleanup"), + } + # register these methods in the datastructures. + ctx.datastructures[CHashTableFType(key_type, value_type)] = methods + inline_s = "static inline " if inline else "" + + # basically for the load functions, you need to provide a variable that + # can be copied. + # Yeah, so which API's should we use for load and store? + lib_code = dedent( + f""" + {inline_s}void* + {methods["init"]}() {{ + void* ptr = malloc(sizeof({hmap_t})); + memset(ptr, 0, sizeof({hmap_t})); + return ptr; + }} + + {inline_s}bool + {methods["exists"]}( + {hmap_t} *map, {keytype_c} key + ) {{ + return {hmap_t}_contains(map, key); + }} + + {inline_s}{valuetype_c} + {methods["load"]}( + {hmap_t} *map, {keytype_c} key + ) {{ + const {valuetype_c}* internal_val = {hmap_t}_at(map, key); + return *internal_val; + }} + + {inline_s}void + {methods["store"]}( + {hmap_t} *map, {keytype_c} key, {valuetype_c} value + ) {{ + {hmap_t}_insert_or_assign(map, key, value); + }} + + {inline_s}void + {methods["cleanup"]}( + void* ptr + ) {{ + {hmap_t}* hptr = ptr; + {hmap_t}_drop(hptr); + free(hptr); + }} + """ + ) + ctx.add_header(lib_code) + + return methods, hmap_t + + @classmethod + def compile( + cls, key_type: AssemblyStructFType, value_type: AssemblyStructFType + ) -> CHashTableLibrary: + """ + compile a library to use for the c hash table. + """ + if (key_type, value_type) in cls.libraries: + return cls.libraries[(key_type, value_type)] + ctx = CContext() + methods, hmap_t = cls.gen_code(ctx, key_type, value_type) + code = ctx.emit_global() + lib = load_shared_lib(code) + + # get keystruct and value types + KeyStruct = c_type(key_type) + ValueStruct = c_type(value_type) + + init_func = getattr(lib, methods["init"]) + init_func.argtypes = [] + init_func.restype = ctypes.c_void_p + + # Exists: Takes (map*, key) -> returns bool + exists_func = getattr(lib, methods["exists"]) + exists_func.argtypes = [ctypes.c_void_p, KeyStruct] + exists_func.restype = ctypes.c_bool + + # Load: Takes (map*, key) -> returns value + load_func = getattr(lib, methods["load"]) + load_func.argtypes = [ + ctypes.c_void_p, + KeyStruct, + ] + load_func.restype = ValueStruct + + # Store: Takes (map*, key, val) -> returns void + store_func = getattr(lib, methods["store"]) + store_func.argtypes = [ + ctypes.c_void_p, + KeyStruct, + ValueStruct, + ] + store_func.restype = None + + # Cleanup: Takes (map*) -> returns void + cleanup_func = getattr(lib, methods["cleanup"]) + cleanup_func.argtypes = [ctypes.c_void_p] + cleanup_func.restype = None + + cls.libraries[(key_type, value_type)] = CHashTableLibrary(lib, methods, hmap_t) + return cls.libraries[(key_type, value_type)] + + def __init__(self, key_type, value_type, map: "dict | None" = None): + """ + Constructor for the C Hash Table + """ + self.lib = self.__class__.compile(key_type, value_type) + + # these are blank fields we need when serializing or smth + self._struct: Any = None + self._self_obj: Any = None + + self._key_type = key_type + self._value_type = value_type + + if map is None: + map = {} + self.dct = getattr(self.lib.library, self.lib.methods["init"])() + for key, value in map.items(): + # if some error happens, the serialization will handle it. + self.store(key, value) + + def __del__(self): + getattr(self.lib.library, self.lib.methods["cleanup"])(self.dct) + + def exists(self, idx) -> bool: + c_key = serialize_to_c(self.ftype.key_type, idx) + c_value = getattr(self.lib.library, self.lib.methods["exists"])(self.dct, c_key) + return bool(c_value) + + def load(self, idx): + c_key = serialize_to_c(self.ftype.key_type, idx) + c_value = getattr(self.lib.library, self.lib.methods["load"])(self.dct, c_key) + return construct_from_c(self.ftype.value_type, c_value) + + def store(self, idx, val): + c_key = serialize_to_c(self.ftype.key_type, idx) + c_value = serialize_to_c(self.ftype.value_type, val) + getattr(self.lib.library, self.lib.methods["store"])(self.dct, c_key, c_value) + + def __str__(self): + return f"c_hashtable({self.dct})" + + @property + def ftype(self): + return CHashTableFType(self._key_type, self._value_type) + + +class CHashTableFType(CDictFType, CStackFType): + """ + An implementation of Hash Tables using the stc library. + """ + + def __init__(self, key_type: AssemblyStructFType, value_type: AssemblyStructFType): + # these should both be immutable structs/POD types. + # we will enforce this once the immutable struct PR is merged. + self._key_type = key_type + self._value_type = value_type + + def __eq__(self, other): + if not isinstance(other, CHashTableFType): + return False + return self.key_type == other.key_type and self.value_type == other.value_type + + def __call__(self): + return CHashTable(self.key_type, self.value_type, {}) + + def __str__(self): + return f"chashtable_t({self.key_type}, {self.value_type})" + + def __repr__(self): + return f"CHashTableFType({self.key_type}, {self.value_type})" + + @property + def key_type(self): + """ + Returns the type of elements used as the keys of the hash table. + (some integer tuple) + """ + return self._key_type + + @property + def value_type(self): + """ + Returns the type of elements used as the value of the hash table. + (some integer tuple) + """ + return self._value_type + + def __hash__(self): + """ + This method needs to be here because you are going to be using this + type as a key in dictionaries. + """ + return hash(("CHashTableFType", self.key_type, self.value_type)) + + """ + Methods for the C Backend + This requires an external library (stc) to work. + """ + + def c_type(self): + return ctypes.POINTER(CHashTableStruct) + + def c_existsdict(self, ctx: "CContext", map: "Stack", idx: "AssemblyExpression"): + assert isinstance(map.obj, CDictFields) + methods: CHashMethods = ctx.datastructures[self] + return f"{ctx.feed}{methods['exists']}({map.obj.map}, {ctx(idx)})" + + def c_storedict( + self, + ctx: "CContext", + map: "Stack", + idx: "AssemblyExpression", + value: "AssemblyExpression", + ): + assert isinstance(map.obj, CDictFields) + methods: CHashMethods = ctx.datastructures[self] + ctx.exec( + f"{ctx.feed}{methods['store']}({map.obj.map}, {ctx(idx)}, {ctx(value)});" + ) + + def c_loaddict(self, ctx: "CContext", map: "Stack", idx: "AssemblyExpression"): + """ + Get an expression where we can get the value corresponding to a key. + """ + assert isinstance(map.obj, CDictFields) + methods: CHashMethods = ctx.datastructures[self] + + return f"{methods['load']}({map.obj.map}, {ctx(idx)})" + + def c_unpack(self, ctx: "CContext", var_n: str, val: AssemblyExpression): + """ + Unpack the map into C context. + """ + assert val.result_format == self + data = ctx.freshen(var_n, "data") + # Add all the stupid header stuff from above. + ctx.add_datastructure( + ("CHashTableFType", self.key_type, self.value_type), + lambda ctx: CHashTable.gen_code( + ctx, self.key_type, self.value_type, inline=True + ), + ) + + ctx.exec(f"{ctx.feed}void* {data} = {ctx(val)}->map;") + return CDictFields(data, var_n) + + def c_repack(self, ctx: "CContext", lhs: str, obj: "CDictFields"): + """ + Repack the map out of C context. + """ + ctx.exec(f"{ctx.feed}{lhs}->map = {obj.map};") + + def serialize_to_c(self, obj: CHashTable): + """ + Serialize the Hash Map to a CHashMap structure. + This datatype will then immediately get turned into a struct. + """ + assert isinstance(obj, CHashTable) + dct = ctypes.c_void_p(obj.dct) + struct = CHashTableStruct(dct, obj) + # We NEED this for stupid ownership reasons. + obj._self_obj = ctypes.py_object(obj) + obj._struct = struct + return ctypes.pointer(struct) + + def deserialize_from_c(self, obj: CHashTable, res): + """ + Update our hash table based on how the C call modified the CHashMapStruct. + """ + assert isinstance(res, ctypes.POINTER(CHashTableStruct)) + assert isinstance(res.contents.obj, CHashTable) + + obj.dct = res.contents.map + + def construct_from_c(self, c_dct): + """ + Construct a CHashTable from a C-compatible structure. + + c_map is a pointer to a CHashMapStruct + """ + raise NotImplementedError + + +class NumbaHashTable(Dict): + """ + A Hash Table that maps Z^{in_len} to Z^{out_len} + """ + + def __init__(self, key_len, value_len, dct: "dict[tuple,tuple] | None" = None): + self.key_len = key_len + self.value_len = value_len + + self._numba_key_type = numba.types.UniTuple(numba.types.int64, key_len) + self._numba_value_type = numba.types.UniTuple(numba.types.int64, value_len) + + if dct is None: + dct = {} + self.dct = numba.typed.Dict.empty( + key_type=self._numba_key_type, value_type=self._numba_value_type + ) + for key, value in dct.items(): + if not _is_integer_tuple(key, key_len): + raise TypeError( + f"Supplied key {key} is not a tuple of {key_len} integers" + ) + if not _is_integer_tuple(value, value_len): + raise TypeError( + f"Supplied value {key} is not a tuple of {value_len} integers" + ) + self.dct[key] = value + + @property + def ftype(self): + """ + Returns the finch type of this hash table. + """ + return NumbaHashTableFType(self.key_len, self.value_len) + + def exists(self, idx) -> bool: + """ + Exists function of the numba hash table. + It will accept an object with TupleFType and return a bool. + """ + idx = _tuplify(self.ftype.key_type, idx) + assert _is_integer_tuple(idx, self.key_len) + return idx in self.dct + + def load(self, idx): + idx = _tuplify(self.ftype.key_type, idx) + assert _is_integer_tuple(idx, self.key_len) + result = self.dct[idx] + return self.ftype.value_type(*result) + + def store(self, idx, val): + idx = _tuplify(self.ftype.key_type, idx) + val = _tuplify(self.ftype.value_type, val) + assert _is_integer_tuple(idx, self.key_len) + assert _is_integer_tuple(val, self.value_len) + self.dct[idx] = val + + def __str__(self): + return f"numba_hashtable({self.dct})" + + +class NumbaHashTableFType(NumbaDictFType, NumbaStackFType): + """ + An implementation of Hash Tables using the stc library. + """ + + def __init__(self, key_len: int, value_len: int): + self.key_len = key_len + self.value_len = value_len + self._key_type = _int_tuple_ftype(key_len) + self._value_type = _int_tuple_ftype(value_len) + self._numba_key_type = numba.types.UniTuple(numba.types.int64, key_len) + self._numba_value_type = numba.types.UniTuple(numba.types.int64, value_len) + + def __eq__(self, other): + if not isinstance(other, NumbaHashTableFType): + return False + return self.key_len == other.key_len and self.value_len == other.value_len + + def __call__(self): + return NumbaHashTable(self.key_len, self.value_len, {}) + + def __str__(self): + return f"numba_hashtable_t({self.key_len}, {self.value_len})" + + def __repr__(self): + return f"HashTableFType({self.key_len}, {self.value_len})" + + @property + def key_type(self): + """ + Returns the type of elements used as the keys of the hash table. + (some integer tuple) + """ + return self._key_type + + @property + def value_type(self): + """ + Returns the type of elements used as the value of the hash table. + (some integer tuple) + """ + return self._value_type + + def __hash__(self): + """ + This method needs to be here because you are going to be using this + type as a key in dictionaries. + """ + return hash(("NumbaHashTableFType", self.key_len, self.value_len)) + + """ + Methods for the Numba Backend + """ + + def numba_jitclass_type(self) -> numba.types.Type: + return numba.types.ListType( + numba.types.DictType(self._numba_key_type, self._numba_value_type) + ) + + def numba_type(self): + return list + + def numba_existsdict( + self, ctx: "NumbaContext", map: "Stack", idx: "AssemblyExpression" + ): + assert isinstance(map.obj, NumbaDictFields) + tuple_fields = ",".join( + f"{ctx(idx)}.{field}" for field in self.key_type.struct_fieldnames + ) + return f"tuple(({tuple_fields})) in {map.obj.map}" + + def numba_loaddict( + self, ctx: "NumbaContext", map: "Stack", idx: "AssemblyExpression" + ): + assert isinstance(map.obj, NumbaDictFields) + tuple_fields = ",".join( + f"{ctx(idx)}.{field}" for field in self.key_type.struct_fieldnames + ) + value_v = ctx.freshen("value") + ctx.exec(f"{ctx.feed}{value_v} = {map.obj.map}[tuple(({tuple_fields}))]") + return f"{ctx.full_name(numba_type(self.value_type))}(*{value_v})" + + def numba_storedict( + self, + ctx: "NumbaContext", + map: "Stack", + idx: "AssemblyExpression", + value: "AssemblyExpression", + ): + assert isinstance(map.obj, NumbaDictFields) + idx_fields = ",".join( + f"{ctx(idx)}.{field}" for field in self.key_type.struct_fieldnames + ) + val_fields = ",".join( + f"{ctx(value)}.{field}" for field in self.value_type.struct_fieldnames + ) + ctx.exec( + f"{ctx.feed}{map.obj.map}[tuple(({idx_fields}))] = tuple(({val_fields}))" + ) + + def numba_unpack( + self, ctx: "NumbaContext", var_n: str, val: "AssemblyExpression" + ) -> NumbaDictFields: + """ + Unpack the map into numba context. + """ + # the val field will always be asm.Variable(var_n, var_t) + map = ctx.freshen(var_n, "map") + ctx.exec(f"{ctx.feed}{map} = {ctx(val)}[0]") + + return NumbaDictFields(map, var_n) + + def numba_repack(self, ctx: "NumbaContext", lhs: str, obj: "NumbaDictFields"): + """ + Repack the map from Numba context. + """ + # obj is the fields corresponding to the self.slots[lhs] + ctx.exec(f"{ctx.feed}{lhs}[0] = {obj.map}") + + def serialize_to_numba(self, obj: "NumbaHashTable"): + """ + Serialize the hashmap to a Numba-compatible object. + + We will supply the input and output length + """ + return numba.typed.List([obj.dct]) + + def deserialize_from_numba(self, obj: "NumbaHashTable", numba_map: "list[dict]"): + obj.dct = numba_map[0] + + def construct_from_numba(self, numba_map): + """ + Construct a numba map from a Numba-compatible object. + """ + return NumbaHashTable(self.key_len, self.value_len, numba_map[0]) + + +if __name__ == "__main__": + table = CHashTable(2, 3, {(1, 2): (1, 4, 3)}) + print(c_type(table.ftype.key_type)) + table.store((2, 3), (3, 2, 3)) + print(table.exists((2, 3))) + print(table.load((2, 3))) + print(table.exists((2, 1))) diff --git a/src/finchlite/codegen/numba_backend.py b/src/finchlite/codegen/numba_backend.py index 9ff9f8e9..3da2c6e5 100644 --- a/src/finchlite/codegen/numba_backend.py +++ b/src/finchlite/codegen/numba_backend.py @@ -6,7 +6,9 @@ import numpy as np -import numba # type: ignore[import-untyped] +import numba + +from finchlite.finch_assembly.map import DictFType # type: ignore[import-untyped] from .. import finch_assembly as asm from ..algebra import query_property, register_property @@ -34,7 +36,10 @@ def numba_type(t): """ if hasattr(t, "numba_type"): return t.numba_type() - return query_property(t, "numba_type", "__attr__") + try: + return query_property(t, "numba_type", "__attr__") + except AttributeError: + return t def numba_jitclass_type(t): @@ -123,6 +128,13 @@ def assembly_struct_numba_jitclass_type(ftype_) -> numba.types.Type: lambda t: numba.from_dtype(t), ) +register_property( + int, + "numba_jitclass_type", + "__attr__", + lambda t: numba.int32, +) + register_property( AssemblyStructFType, "numba_jitclass_type", @@ -253,6 +265,34 @@ def construct_from_numba(fmt, numba_obj): ) +class NumbaDictFType(DictFType, NumbaArgumentFType, ABC): + """ + Abstract base class for the ftype of datastructures. The ftype defines how + the data in a Map is organized and accessed. + """ + + @abstractmethod + def numba_existsdict(self, ctx: "NumbaContext", map, idx): + """ + Return numba code which checks whether a given key exists in a map. + """ + ... + + @abstractmethod + def numba_loaddict(self, ctx, buffer, idx): + """ + Return numba code which gets a value corresponding to a certain key. + """ + ... + + @abstractmethod + def numba_storedict(self, ctx, buffer, idx, value): + """ + Return C code which stores a certain value given a certain integer tuple key. + """ + ... + + class NumbaBufferFType(BufferFType, NumbaArgumentFType, ABC): @abstractmethod def numba_length(self, ctx: "NumbaContext", buffer): @@ -548,6 +588,15 @@ def __call__(self, prgm: asm.AssemblyNode): case asm.Length(buf): buf = self.resolve(buf) return buf.result_format.numba_length(self, buf) + case asm.LoadDict(dct, idx): + dct = self.resolve(dct) + return dct.result_format.numba_loaddict(self, dct, idx) + case asm.ExistsDict(dct, idx): + dct = self.resolve(dct) + return dct.result_format.numba_existsdict(self, dct, idx) + case asm.StoreDict(dct, idx, val): + dct = self.resolve(dct) + return dct.result_format.numba_storedict(self, dct, idx, val) case asm.Block(bodies): ctx_2 = self.block() for body in bodies: @@ -751,6 +800,23 @@ def serialize_tuple_to_numba(fmt, obj): serialize_tuple_to_numba, ) +# trivial ser/deser +for t in (int, bool, float): + register_property( + t, + "construct_from_numba", + "__attr__", + lambda fmt, obj: obj, + ) + + register_property( + t, + "serialize_to_numba", + "__attr__", + lambda fmt, obj: obj, + ) + + register_property( operator.add, "numba_literal", diff --git a/src/finchlite/codegen/numpy_buffer.py b/src/finchlite/codegen/numpy_buffer.py index 0e3f308b..c567e595 100644 --- a/src/finchlite/codegen/numpy_buffer.py +++ b/src/finchlite/codegen/numpy_buffer.py @@ -5,9 +5,11 @@ import numba +from finchlite.finch_assembly.nodes import AssemblyExpression, Stack + from ..finch_assembly import Buffer from ..util import qual_str -from .c import CBufferFType, CStackFType, c_type +from .c import CBufferFType, CContext, CStackFType, c_type from .numba_backend import NumbaBufferFType @@ -123,16 +125,26 @@ def __call__(self, len: int = 0, dtype: type | None = None): def c_type(self): return ctypes.POINTER(CNumpyBuffer) - def c_length(self, ctx, buf): + def c_length(self, ctx: "CContext", buf: "Stack"): + assert isinstance(buf.obj, CBufferFields) return buf.obj.length - def c_data(self, ctx, buf): + def c_data(self, ctx: "CContext", buf: "Stack"): + assert isinstance(buf.obj, CBufferFields) return buf.obj.data - def c_load(self, ctx, buf, idx): + def c_load(self, ctx: "CContext", buf: "Stack", idx: "AssemblyExpression"): + assert isinstance(buf.obj, CBufferFields) return f"({buf.obj.data})[{ctx(idx)}]" - def c_store(self, ctx, buf, idx, value): + def c_store( + self, + ctx: "CContext", + buf: "Stack", + idx: "AssemblyExpression", + value: "AssemblyExpression", + ): + assert isinstance(buf.obj, CBufferFields) ctx.exec(f"{ctx.feed}({buf.obj.data})[{ctx(idx)}] = {ctx(value)};") def c_resize(self, ctx, buf, new_len): diff --git a/src/finchlite/codegen/stc b/src/finchlite/codegen/stc new file mode 160000 index 00000000..19db5051 --- /dev/null +++ b/src/finchlite/codegen/stc @@ -0,0 +1 @@ +Subproject commit 19db5051157da5525b3b4cfced4e33810985bd94 diff --git a/src/finchlite/finch_assembly/__init__.py b/src/finchlite/finch_assembly/__init__.py index a2417d22..1e27c346 100644 --- a/src/finchlite/finch_assembly/__init__.py +++ b/src/finchlite/finch_assembly/__init__.py @@ -1,11 +1,8 @@ from .buffer import Buffer, BufferFType, element_type, length_type -from .cfg_builder import ( - AssemblyCFGBuilder, - assembly_build_cfg, - assembly_number_uses, -) +from .cfg_builder import AssemblyCFGBuilder, assembly_build_cfg, assembly_number_uses from .dataflow import AssemblyCopyPropagation, assembly_copy_propagation from .interpreter import AssemblyInterpreter, AssemblyInterpreterKernel +from .map import Dict, DictFType from .nodes import ( AssemblyNode, Assert, @@ -14,6 +11,7 @@ Break, BufferLoop, Call, + ExistsDict, ForLoop, Function, GetAttr, @@ -22,6 +20,7 @@ Length, Literal, Load, + LoadDict, Module, Print, Repack, @@ -31,12 +30,19 @@ Slot, Stack, Store, + StoreDict, TaggedVariable, Unpack, Variable, WhileLoop, ) -from .struct import AssemblyStructFType, NamedTupleFType, TupleFType +from .struct import ( + AssemblyStructFType, + ImmutableStructFType, + MutableStructFType, + NamedTupleFType, + TupleFType, +) from .type_checker import AssemblyTypeChecker, AssemblyTypeError, assembly_check_types __all__ = [ @@ -56,15 +62,21 @@ "BufferFType", "BufferLoop", "Call", + "Dict", + "DictFType", + "ExistsDict", "ForLoop", "Function", "GetAttr", "If", "IfElse", + "ImmutableStructFType", "Length", "Literal", "Load", + "LoadDict", "Module", + "MutableStructFType", "NamedTupleFType", "Print", "Repack", @@ -74,6 +86,7 @@ "Slot", "Stack", "Store", + "StoreDict", "TaggedVariable", "TupleFType", "Unpack", diff --git a/src/finchlite/finch_assembly/dev_doc.md b/src/finchlite/finch_assembly/dev_doc.md index 66ee299f..8327cb9d 100644 --- a/src/finchlite/finch_assembly/dev_doc.md +++ b/src/finchlite/finch_assembly/dev_doc.md @@ -11,7 +11,7 @@ The following is a rough grammar for FinchAssembly, written in terms of the curr EXPR := LITERAL | VARIABLE | SLOT | STACK | GETATTR | CALL | LOAD | LENGTH STMT := UNPACK | REPACK | ASSIGN | SETATTR | STORE | RESIZE | FORLOOP | BUFFERLOOP | WHILELOOP | IF | IFELSE | FUNCTION | RETURN | BREAK - | BLOCK | MODULE + | BLOCK | MODULE | LOADMAP | STOREMAP | EXISTSMAP NODE := EXPR | STMT LITERAL := Literal(val=VALUE) @@ -27,6 +27,9 @@ LOAD := Load(buffer=SLOT | STACK, index=EXPR) STORE := Store(buffer=SLOT | STACK, index=EXPR, value=EXPR) RESIZE := Resize(buffer=SLOT | STACK, new_size=EXPR) LENGTH := Length(buffer=SLOT | STACK) +LOADMAP := LoadMap(map=SLOT | STACK, index=EXPR) +STOREMAP := StoreMap(map=SLOT | STACK, index=EXPR, value=EXPR) +EXISTSMAP := ExistMap(map=SLOT | STACK, index=EXPR) STACK := Stack(obj=ANY, type=TYPE) FORLOOP := ForLoop(var=VARIABLE, start=EXPR, end=EXPR, body=NODE) BUFFERLOOP := BufferLoop(buffer=EXPR, var=VARIABLE, body=NODE) diff --git a/src/finchlite/finch_assembly/interpreter.py b/src/finchlite/finch_assembly/interpreter.py index 0f515502..8dd08104 100644 --- a/src/finchlite/finch_assembly/interpreter.py +++ b/src/finchlite/finch_assembly/interpreter.py @@ -212,6 +212,16 @@ def __call__(self, prgm: asm.AssemblyNode): buf_e = self(buf) idx_e = self(idx) return buf_e.load(idx_e) + case asm.LoadDict(dct, idx): + assert isinstance(dct, asm.Slot) + map_e = self(dct) + idx_e = self(idx) + return map_e.load(idx_e) + case asm.ExistsDict(dct, idx): + assert isinstance(dct, asm.Slot) + map_e = self(dct) + idx_e = self(idx) + return map_e.exists(idx_e) case asm.Store(buf, idx, val): assert isinstance(buf, asm.Slot) buf_e = self(buf) @@ -219,6 +229,12 @@ def __call__(self, prgm: asm.AssemblyNode): val_e = self(val) buf_e.store(idx_e, val_e) return None + case asm.StoreDict(dct, idx, val): + assert isinstance(dct, asm.Slot) + map_e = self(dct) + idx_e = self(idx) + val_e = self(val) + return map_e.store(idx_e, val_e) case asm.Resize(buf, len_): assert isinstance(buf, asm.Slot) buf_e = self(buf) @@ -309,7 +325,7 @@ def my_func(*args_e): ctx_2(body) if ctx_2.function_state.should_halt: ret_e = ctx_2.function_state.return_value - if not check_isinstance(ret_e, ret_t): + if not fisinstance(ret_e, ret_t): raise TypeError( f"Return value {ret_e} is not of type {ret_t} " f"for function '{func_n}'." diff --git a/src/finchlite/finch_assembly/map.py b/src/finchlite/finch_assembly/map.py new file mode 100644 index 00000000..be233a0e --- /dev/null +++ b/src/finchlite/finch_assembly/map.py @@ -0,0 +1,88 @@ +from abc import ABC, abstractmethod + +from ..symbolic import FType, FTyped + + +class Dict(FTyped, ABC): + """ + Abstract base class for a map data structure. + Hash tables should be such that their bucket size can be resized, with Tree + maps turning that into a no-op. + """ + + @abstractmethod + def __init__( + self, key_len: int, value_len: int, map: "dict[tuple,tuple] | None" + ): ... + + @property + @abstractmethod + def ftype(self) -> "DictFType": ... + + @property + def value_type(self): + """ + Return type of values stored in the hash table + (probably some TupleFType) + """ + return self.ftype.value_type + + @property + def key_type(self): + """ + Return type of keys stored in the hash table + (probably some TupleFType) + """ + return self.ftype.key_type + + @abstractmethod + def load(self, idx: tuple): + """ + Method to access some element in the map. Will panic if the key doesn't exist. + """ + ... + + @abstractmethod + def exists(self, idx: tuple) -> bool: + """ + Method to check if the element exists in the map. + """ + ... + + @abstractmethod + def store(self, idx: tuple, val): + """ + Method to store elements in the map. Ideally it should just create new + elements. + """ + ... + + +class DictFType(FType): + """ + Abstract base class for an ftype corresponding to a map. + """ + + @abstractmethod + def __call__(self, *args, **kwargs): + """ + Create an instance of an object in this ftype with the given arguments. + """ + ... + + @property + @abstractmethod + def value_type(self): + """ + Return the type of elements stored in the map. + This is typically the same as the dtype used to create the map. + """ + ... + + @property + @abstractmethod + def key_type(self): + """ + Returns the type used for the length of the map. + """ + ... diff --git a/src/finchlite/finch_assembly/nodes.py b/src/finchlite/finch_assembly/nodes.py index bd1503bd..b380c228 100644 --- a/src/finchlite/finch_assembly/nodes.py +++ b/src/finchlite/finch_assembly/nodes.py @@ -366,6 +366,68 @@ def children(self): return [self.buffer, self.index, self.value] +@dataclass(eq=True, frozen=True) +class ExistsDict(AssemblyExpression, AssemblyTree): + """ + Represents checking whether an integer tuple key is in a map. + + Attributes: + map: The map to load from. + index: The key to check for existence. + """ + + map: Slot | Stack + index: AssemblyExpression + + @property + def children(self): + return [self.map, self.index] + + def result_format(self): + return bool + + +@dataclass(eq=True, frozen=True) +class LoadDict(AssemblyExpression, AssemblyTree): + """ + Represents loading a value from a map given an integer tuple key. + + Attributes: + map: The map to load from. + index: The key value + """ + + dct: Slot | Stack + index: AssemblyExpression + + @property + def children(self): + return [self.dct, self.index] + + def result_format(self): + return self.dct.result_format.value_type + + +@dataclass(eq=True, frozen=True) +class StoreDict(AssemblyTree, AssemblyStatement): + """ + Represents storing a value into a buffer given an integer tuple key. + + Attributes: + map: The map to load from. + index1: The first integer in the pair + index2: The second integer in the pair + """ + + map: Slot | Stack + index: AssemblyExpression + value: AssemblyExpression + + @property + def children(self): + return [self.map, self.index, self.value] + + @dataclass(eq=True, frozen=True) class Resize(AssemblyTree, AssemblyStatement): """ @@ -707,11 +769,18 @@ def __call__(self, prgm: AssemblyNode): return None case Load(buf, idx): return f"load({self(buf)}, {self(idx)})" + case LoadDict(map, idx): + return f"loadmap({self(map)}, {self(idx)})" + case ExistsDict(map, idx): + return f"existsmap({self(map)}, {self(idx)})" case Slot(name, type_): return f"slot({name}, {qual_str(type_)})" case Store(buf, idx, val): self.exec(f"{feed}store({self(buf)}, {self(idx)}, {self(val)})") return None + case StoreDict(map, idx, val): + self.exec(f"{feed}storemap({self(map)}, {self(idx)}, {self(val)})") + return None case Resize(buf, size): self.exec(f"{feed}resize({self(buf)}, {self(size)})") return None diff --git a/src/finchlite/finch_assembly/struct.py b/src/finchlite/finch_assembly/struct.py index b855f87e..f2bb138e 100644 --- a/src/finchlite/finch_assembly/struct.py +++ b/src/finchlite/finch_assembly/struct.py @@ -45,7 +45,24 @@ def struct_attrtype(self, attr: str) -> Any: return dict(self.struct_fields)[attr] -class NamedTupleFType(AssemblyStructFType): +class ImmutableStructFType(AssemblyStructFType): + @property + def is_mutable(self) -> bool: + return False + + +class MutableStructFType(AssemblyStructFType): + """ + Class for a mutable assembly struct type. + It is currently not used anywhere, but maybe it will be useful in the future? + """ + + @property + def is_mutable(self) -> bool: + return True + + +class NamedTupleFType(ImmutableStructFType): def __init__(self, struct_name, struct_fields): self._struct_name = struct_name self._struct_fields = struct_fields @@ -79,7 +96,7 @@ def __call__(self, *args): return namedtuple(self.struct_name, self.struct_fieldnames)(args) -class TupleFType(AssemblyStructFType): +class TupleFType(ImmutableStructFType): def __init__(self, struct_name, struct_formats): self._struct_name = struct_name self._struct_formats = struct_formats diff --git a/src/finchlite/finch_assembly/type_checker.py b/src/finchlite/finch_assembly/type_checker.py index 6cec1f43..aab93982 100644 --- a/src/finchlite/finch_assembly/type_checker.py +++ b/src/finchlite/finch_assembly/type_checker.py @@ -6,6 +6,7 @@ from ..symbolic import FType, ScopedDict, ftype from . import nodes as asm from .buffer import BufferFType +from .map import DictFType from .struct import AssemblyStructFType @@ -81,6 +82,12 @@ def check_in_ctxt(self, var_n, var_t): f"The variable '{var_n}' is not defined in the current context." ) from KeyError + def check_dict(self, dct): + map_type = self.check_expr(dct) + if isinstance(map_type, DictFType): + return map_type + raise AssemblyTypeError(f"Expected map, got {map_type}.") + def check_buffer(self, buffer): buffer_type = self.check_expr(buffer) if isinstance(buffer_type, BufferFType): @@ -134,6 +141,16 @@ def check_expr(self, expr: asm.AssemblyExpression): case asm.Length(buffer): buffer_type = self.check_buffer(buffer) return buffer_type.length_type + case asm.ExistsDict(dct, index): + map_type = self.check_dict(dct) + index_type = self.check_expr(index) + check_type_match(map_type.key_type, index_type) + return bool + case asm.LoadDict(dct, index): + map_type = self.check_dict(dct) + index_type = self.check_expr(index) + check_type_match(map_type.key_type, index_type) + return map_type.value_type case _: raise ValueError(f"Ill-formed AssemblyExpression: {type(expr)}.") @@ -181,6 +198,13 @@ def check_stmt(self, stmt: asm.AssemblyStatement): value_type = self.check_expr(value) check_type_match(buffer_type.element_type, value_type) return None + case asm.StoreDict(map, index, value): + map_type = self.check_dict(map) + index_type = self.check_expr(index) + value_type = self.check_expr(value) + check_type_match(map_type.key_type, index_type) + check_type_match(map_type.value_type, value_type) + return None case asm.Resize(buffer, new_size): buffer_type = self.check_buffer(buffer) new_size_type = self.check_expr(new_size) diff --git a/tests/scripts/safebufferaccess.py b/tests/scripts/safebufferaccess.py index af133cce..9e279e03 100755 --- a/tests/scripts/safebufferaccess.py +++ b/tests/scripts/safebufferaccess.py @@ -9,7 +9,6 @@ """ import argparse -import ctypes import numpy as np @@ -25,22 +24,22 @@ subparser = parser.add_subparsers(required=True, dest="subparser_name") load = subparser.add_parser("load", help="attempt to load some element") -load.add_argument("index", type=int, help="the index to load") +load.add_argument("index", type=np.intp, help="the index to load") store = subparser.add_parser("store", help="attempt to store into some element") -store.add_argument("index", type=int, help="the index to load") -store.add_argument("value", type=int, help="the value to store") +store.add_argument("index", type=np.intp, help="the index to load") +store.add_argument("value", type=np.int64, help="the value to store") args = parser.parse_args() -a = np.array(range(args.size), dtype=ctypes.c_int64) +a = np.array(range(args.size), dtype=np.int64) ab = NumpyBuffer(a) ab_safe = SafeBuffer(ab) ab_v = asm.Variable("a", ab_safe.ftype) ab_slt = asm.Slot("a_", ab_safe.ftype) -idx = asm.Variable("idx", ctypes.c_size_t) -val = asm.Variable("val", ctypes.c_int64) +idx = asm.Variable("idx", np.intp) +val = asm.Variable("val", np.int64) res_var = asm.Variable("val", ab_safe.ftype.element_type) res_var2 = asm.Variable("val2", ab_safe.ftype.element_type) @@ -64,6 +63,7 @@ res_var2, asm.Load(ab_slt, idx), ), + asm.Repack(ab_slt), asm.Return(res_var), ) ), @@ -79,7 +79,8 @@ idx, val, ), - asm.Return(asm.Literal(ctypes.c_int64(0))), + asm.Repack(ab_slt), + asm.Return(asm.Literal(0)), ) ), ), @@ -91,8 +92,8 @@ match args.subparser_name: case "load": - print(access(ab_safe, ctypes.c_size_t(args.index)).value) + print(access(ab_safe, args.index)) case "store": - change(ab_safe, ctypes.c_size_t(args.index), ctypes.c_int64(args.value)) + change(ab_safe, args.index, args.value) arr = [str(ab_safe.load(i)) for i in range(args.size)] print(f"[{' '.join(arr)}]") diff --git a/tests/test_assembly_type_checker.py b/tests/test_assembly_type_checker.py index 7ecf0105..13fea56a 100644 --- a/tests/test_assembly_type_checker.py +++ b/tests/test_assembly_type_checker.py @@ -7,7 +7,9 @@ import finchlite.finch_assembly as asm from finchlite.codegen import NumpyBuffer +from finchlite.codegen.hashtable import CHashTable, NumbaHashTable from finchlite.finch_assembly import assembly_check_types +from finchlite.finch_assembly.struct import TupleFType from finchlite.symbolic import FType, ftype @@ -684,3 +686,109 @@ def test_simple_struct(): ) assembly_check_types(mod) + + +@pytest.mark.parametrize( + ["constructor"], + [ + ( + lambda: CHashTable( + asm.TupleFType.from_tuple((int, int)), + asm.TupleFType.from_tuple((int, int, int)), + ), + ), + (lambda: NumbaHashTable(2, 3),), + ], +) +def test_hashtable(constructor): + table = constructor() + + table_v = asm.Variable("a", ftype(table)) + table_slt = asm.Slot("a_", ftype(table)) + + key_type = table.ftype.key_type + val_type = table.ftype.value_type + key_v = asm.Variable("key", key_type) + val_v = asm.Variable("val", val_type) + + mod = asm.Module( + ( + asm.Function( + asm.Variable( + "setidx", TupleFType.from_tuple(tuple(int for _ in range(3))) + ), + (table_v, key_v, val_v), + asm.Block( + ( + asm.Unpack(table_slt, table_v), + asm.StoreDict( + table_slt, + key_v, + val_v, + ), + asm.Repack(table_slt), + asm.Return(asm.LoadDict(table_slt, key_v)), + ) + ), + ), + asm.Function( + asm.Variable("exists", bool), + (table_v, key_v), + asm.Block( + ( + asm.Unpack(table_slt, table_v), + asm.Return(asm.ExistsDict(table_slt, key_v)), + ) + ), + ), + ) + ) + assembly_check_types(mod) + + +@pytest.mark.parametrize( + ["constructor"], + [ + ( + lambda: CHashTable( + asm.TupleFType.from_tuple((int, int)), + asm.TupleFType.from_tuple((int, int, int)), + ), + ), + (lambda: NumbaHashTable(2, 3),), + ], +) +def test_hashtable_fail(constructor): + table = constructor() + + table_v = asm.Variable("a", ftype(table)) + table_slt = asm.Slot("a_", ftype(table)) + + key_type = table.ftype.key_type + val_type = table.ftype.value_type + key_v = asm.Variable("key", key_type) + val_v = asm.Variable("val", val_type) + mod = asm.Module( + ( + asm.Function( + asm.Variable( + "setidx", TupleFType.from_tuple(tuple(int for _ in range(2))) + ), + (table_v, key_v, val_v), + asm.Block( + ( + asm.Unpack(table_slt, table_v), + asm.StoreDict( + table_slt, + key_v, + val_v, + ), + asm.Repack(table_slt), + asm.Return(asm.LoadDict(table_slt, key_v)), + ) + ), + ), + ) + ) + with pytest.raises(asm.AssemblyTypeError): + assembly_check_types(mod) diff --git a/tests/test_codegen.py b/tests/test_codegen.py index 202c406e..10ea2f84 100644 --- a/tests/test_codegen.py +++ b/tests/test_codegen.py @@ -24,6 +24,7 @@ SafeBuffer, ) from finchlite.codegen.c import construct_from_c, deserialize_from_c, serialize_to_c +from finchlite.codegen.hashtable import CHashTable, NumbaHashTable from finchlite.codegen.malloc_buffer import MallocBuffer from finchlite.codegen.numba_backend import ( construct_from_numba, @@ -260,7 +261,7 @@ def test_malloc_resize(new_size): ) ) mod = CCompiler()(prgm) - assert mod.length(ab).value == new_size + assert mod.length(ab) == new_size assert ab.length() == new_size for i in range(new_size): assert ab.load(i) == 0 if i >= len(a) else a[i] @@ -925,3 +926,144 @@ def test_e2e_numba(): assert_equal(result, a @ b) finchlite.set_default_scheduler(ctx=ctx) + + +@pytest.mark.parametrize( + ["compiler", "constructor"], + [ + ( + CCompiler(), + lambda: CHashTable( + asm.TupleFType.from_tuple((int, int)), + asm.TupleFType.from_tuple((int, int, int)), + ), + ), + ( + asm.AssemblyInterpreter(), + lambda: CHashTable( + asm.TupleFType.from_tuple((int, int)), + asm.TupleFType.from_tuple((int, int, int)), + ), + ), + (NumbaCompiler(), lambda: NumbaHashTable(2, 3)), + (asm.AssemblyInterpreter(), lambda: NumbaHashTable(2, 3)), + ], +) +def test_hashtable(compiler, constructor): + table = constructor() + + table_v = asm.Variable("a", ftype(table)) + table_slt = asm.Slot("a_", ftype(table)) + + key_type = table.ftype.key_type + val_type = table.ftype.value_type + key_v = asm.Variable("key", key_type) + val_v = asm.Variable("val", val_type) + + module = asm.Module( + ( + asm.Function( + asm.Variable("setidx", val_type), + (table_v, key_v, val_v), + asm.Block( + ( + asm.Unpack(table_slt, table_v), + asm.StoreDict( + table_slt, + key_v, + val_v, + ), + asm.Repack(table_slt), + asm.Return(asm.LoadDict(table_slt, key_v)), + ) + ), + ), + asm.Function( + asm.Variable("exists", bool), + (table_v, key_v), + asm.Block( + ( + asm.Unpack(table_slt, table_v), + asm.Return(asm.ExistsDict(table_slt, key_v)), + ) + ), + ), + ) + ) + compiled = compiler(module) + assert compiled.setidx(table, key_type(1, 2), val_type(2, 3, 4)) == val_type( + 2, 3, 4 + ) + assert compiled.setidx(table, key_type(1, 4), val_type(3, 4, 1)) == val_type( + 3, 4, 1 + ) + assert compiled.exists(table, key_type(1, 2)) + + assert not compiled.exists(table, key_type(1, 3)) + assert not compiled.exists(table, val_type(2, 3)) + + +@pytest.mark.parametrize( + ["compiler"], + [ + (CCompiler(),), + (asm.AssemblyInterpreter(),), + ], +) +def test_multiple_c_hashtable(compiler): + """ + This test exists because in the case of C, we might need to dump multiple + hash table definitions into the context. + + So I am not gonna touch heterogeneous structs right now because the hasher + hashes the padding bytes too (even though they are worse than useless) + """ + + def _int_tupletype(arity): + return asm.TupleFType.from_tuple(tuple(int for _ in range(arity))) + + def func(table: CHashTable, num: int): + key_type = table.ftype.key_type + val_type = table.ftype.value_type + key_v = asm.Variable("key", key_type) + val_v = asm.Variable("val", val_type) + table_v = asm.Variable("a", ftype(table)) + table_slt = asm.Slot("a_", ftype(table)) + return asm.Function( + asm.Variable(f"setidx_{num}", val_type), + (table_v, key_v, val_v), + asm.Block( + ( + asm.Unpack(table_slt, table_v), + asm.StoreDict( + table_slt, + key_v, + val_v, + ), + asm.Repack(table_slt), + asm.Return(asm.LoadDict(table_slt, key_v)), + ) + ), + ) + + table1 = CHashTable(_int_tupletype(2), _int_tupletype(3)) + table2 = CHashTable(_int_tupletype(1), _int_tupletype(4)) + + mod = compiler( + asm.Module( + ( + func(table1, 1), + func(table2, 2), + ) + ) + ) + + # what's important here is that you can call setidx_1 on table1 and + # setidx_2 on table2. + assert mod.setidx_1( + table1, table1.key_type(1, 2), table1.value_type(2, 3, 4) + ) == table1.value_type(2, 3, 4) + + assert mod.setidx_2( + table2, table2.key_type(1), table2.value_type(2, 3, 4, 5) + ) == table2.value_type(2, 3, 4, 5)