diff --git a/src/finchlite/codegen/c.py b/src/finchlite/codegen/c.py index f9b72b6d..953536ef 100644 --- a/src/finchlite/codegen/c.py +++ b/src/finchlite/codegen/c.py @@ -15,7 +15,13 @@ from .. import finch_assembly as asm from ..algebra import query_property, register_property -from ..finch_assembly import AssemblyStructFType, BufferFType, TupleFType +from ..finch_assembly import ( + AssemblyStructFType, + BufferFType, + ImmutableStructFType, + MutableStructFType, + TupleFType, +) from ..symbolic import Context, Namespace, ScopedDict, fisinstance, ftype from ..util import config from ..util.cache import file_cache @@ -140,7 +146,7 @@ def construct_from_c(fmt, c_obj): return fmt.construct_from_c(c_obj) try: return query_property(fmt, "construct_from_c", "__attr__", c_obj) - except NotImplementedError: + except AttributeError: return fmt(c_obj) @@ -204,12 +210,14 @@ def construct_from_c(fmt, c_obj): register_property(t, "construct_from_c", "__attr__", lambda fmt, c_value: c_value) register_property(t, "numba_type", "__attr__", lambda t: t) + register_property( np.generic, "serialize_to_c", "__attr__", - lambda fmt, obj: np.ctypeslib.as_ctypes(obj), + lambda fmt, obj: np.ctypeslib.as_ctypes(np.array(obj)), ) + # pass by value -> no op register_property( np.generic, @@ -254,11 +262,9 @@ def __call__(self, *args): self.argtypes, args, serial_args, strict=False ): deserialize_from_c(type_, arg, serial_arg) - if hasattr(self.ret_type, "construct_from_c"): - return construct_from_c(res.ftype, res) if self.ret_type is type(None): return None - return self.ret_type(res) + return construct_from_c(self.ret_type, res) class CModule: @@ -315,7 +321,7 @@ def __call__(self, prgm): for func in prgm.funcs: match func: case asm.Function(asm.Variable(func_name, return_t), args, _): - return_t = c_type(return_t) + # return_t = c_type(return_t) arg_ts = [arg.result_format for arg in args] kern = CKernel(getattr(lib, func_name), return_t, arg_ts) kernels[func_name] = kern @@ -1035,36 +1041,55 @@ def struct_c_type(fmt: AssemblyStructFType): return new_struct +""" +Note: When serializing any struct to C, it will get serialized to a struct with +no indirection. + +When you pass a struct into a kernel that expects a struct pointer, ctypes can +intelligently infer whether we are working with a pointer arg type (pass by +reference) or a non-pointer type (in which case it will immediately apply +indirection) +""" + register_property( - AssemblyStructFType, + MutableStructFType, "c_type", "__attr__", lambda fmt: ctypes.POINTER(struct_c_type(fmt)), ) +register_property( + ImmutableStructFType, "c_type", "__attr__", lambda fmt: struct_c_type(fmt) +) -def struct_c_getattr(fmt: AssemblyStructFType, ctx, obj, attr): - return f"{obj}->{attr}" +register_property( + MutableStructFType, + "c_getattr", + "__attr__", + lambda fmt, ctx, obj, attr: f"{obj}->{attr}", +) register_property( - AssemblyStructFType, + ImmutableStructFType, "c_getattr", "__attr__", - struct_c_getattr, + lambda fmt, ctx, obj, attr: f"{obj}.{attr}", ) -def struct_c_setattr(fmt: AssemblyStructFType, ctx, obj, attr, val): +def struct_mutable_setattr(fmt: AssemblyStructFType, ctx, obj, attr, val): ctx.emit(f"{ctx.feed}{obj}->{attr} = {val};") - return +# the equivalent for immutable is f"{ctx.feed}{obj}.{attr} = {val};" +# but we will not include that because it's bad. + register_property( - AssemblyStructFType, + MutableStructFType, "c_setattr", "__attr__", - struct_c_setattr, + struct_mutable_setattr, ) @@ -1092,18 +1117,23 @@ def serialize_tuple_to_c(fmt, obj): "__attr__", serialize_tuple_to_c, ) + + +def tuple_construct_from_c(fmt: TupleFType, c_struct): + args = [getattr(c_struct, name) for name in fmt.struct_fieldnames] + return tuple(args) + + register_property( TupleFType, "construct_from_c", "__attr__", - lambda fmt, obj, c_tuple: tuple(c_tuple), + tuple_construct_from_c, ) register_property( TupleFType, "c_type", "__attr__", - lambda fmt: ctypes.POINTER( - struct_c_type(asm.NamedTupleFType("CTuple", fmt.struct_fields)) - ), + lambda fmt: struct_c_type(asm.NamedTupleFType("CTuple", fmt.struct_fields)), ) diff --git a/src/finchlite/finch_assembly/__init__.py b/src/finchlite/finch_assembly/__init__.py index a2417d22..293510f6 100644 --- a/src/finchlite/finch_assembly/__init__.py +++ b/src/finchlite/finch_assembly/__init__.py @@ -1,9 +1,5 @@ from .buffer import Buffer, BufferFType, element_type, length_type -from .cfg_builder import ( - AssemblyCFGBuilder, - assembly_build_cfg, - assembly_number_uses, -) +from .cfg_builder import AssemblyCFGBuilder, assembly_build_cfg, assembly_number_uses from .dataflow import AssemblyCopyPropagation, assembly_copy_propagation from .interpreter import AssemblyInterpreter, AssemblyInterpreterKernel from .nodes import ( @@ -36,7 +32,13 @@ Variable, WhileLoop, ) -from .struct import AssemblyStructFType, NamedTupleFType, TupleFType +from .struct import ( + AssemblyStructFType, + ImmutableStructFType, + MutableStructFType, + NamedTupleFType, + TupleFType, +) from .type_checker import AssemblyTypeChecker, AssemblyTypeError, assembly_check_types __all__ = [ @@ -61,10 +63,12 @@ "GetAttr", "If", "IfElse", + "ImmutableStructFType", "Length", "Literal", "Load", "Module", + "MutableStructFType", "NamedTupleFType", "Print", "Repack", diff --git a/src/finchlite/finch_assembly/struct.py b/src/finchlite/finch_assembly/struct.py index b855f87e..f2bb138e 100644 --- a/src/finchlite/finch_assembly/struct.py +++ b/src/finchlite/finch_assembly/struct.py @@ -45,7 +45,24 @@ def struct_attrtype(self, attr: str) -> Any: return dict(self.struct_fields)[attr] -class NamedTupleFType(AssemblyStructFType): +class ImmutableStructFType(AssemblyStructFType): + @property + def is_mutable(self) -> bool: + return False + + +class MutableStructFType(AssemblyStructFType): + """ + Class for a mutable assembly struct type. + It is currently not used anywhere, but maybe it will be useful in the future? + """ + + @property + def is_mutable(self) -> bool: + return True + + +class NamedTupleFType(ImmutableStructFType): def __init__(self, struct_name, struct_fields): self._struct_name = struct_name self._struct_fields = struct_fields @@ -79,7 +96,7 @@ def __call__(self, *args): return namedtuple(self.struct_name, self.struct_fieldnames)(args) -class TupleFType(AssemblyStructFType): +class TupleFType(ImmutableStructFType): def __init__(self, struct_name, struct_formats): self._struct_name = struct_name self._struct_formats = struct_formats diff --git a/tests/scripts/safebufferaccess.py b/tests/scripts/safebufferaccess.py index af133cce..9e279e03 100755 --- a/tests/scripts/safebufferaccess.py +++ b/tests/scripts/safebufferaccess.py @@ -9,7 +9,6 @@ """ import argparse -import ctypes import numpy as np @@ -25,22 +24,22 @@ subparser = parser.add_subparsers(required=True, dest="subparser_name") load = subparser.add_parser("load", help="attempt to load some element") -load.add_argument("index", type=int, help="the index to load") +load.add_argument("index", type=np.intp, help="the index to load") store = subparser.add_parser("store", help="attempt to store into some element") -store.add_argument("index", type=int, help="the index to load") -store.add_argument("value", type=int, help="the value to store") +store.add_argument("index", type=np.intp, help="the index to load") +store.add_argument("value", type=np.int64, help="the value to store") args = parser.parse_args() -a = np.array(range(args.size), dtype=ctypes.c_int64) +a = np.array(range(args.size), dtype=np.int64) ab = NumpyBuffer(a) ab_safe = SafeBuffer(ab) ab_v = asm.Variable("a", ab_safe.ftype) ab_slt = asm.Slot("a_", ab_safe.ftype) -idx = asm.Variable("idx", ctypes.c_size_t) -val = asm.Variable("val", ctypes.c_int64) +idx = asm.Variable("idx", np.intp) +val = asm.Variable("val", np.int64) res_var = asm.Variable("val", ab_safe.ftype.element_type) res_var2 = asm.Variable("val2", ab_safe.ftype.element_type) @@ -64,6 +63,7 @@ res_var2, asm.Load(ab_slt, idx), ), + asm.Repack(ab_slt), asm.Return(res_var), ) ), @@ -79,7 +79,8 @@ idx, val, ), - asm.Return(asm.Literal(ctypes.c_int64(0))), + asm.Repack(ab_slt), + asm.Return(asm.Literal(0)), ) ), ), @@ -91,8 +92,8 @@ match args.subparser_name: case "load": - print(access(ab_safe, ctypes.c_size_t(args.index)).value) + print(access(ab_safe, args.index)) case "store": - change(ab_safe, ctypes.c_size_t(args.index), ctypes.c_int64(args.value)) + change(ab_safe, args.index, args.value) arr = [str(ab_safe.load(i)) for i in range(args.size)] print(f"[{' '.join(arr)}]") diff --git a/tests/test_codegen.py b/tests/test_codegen.py index 202c406e..a06e22a5 100644 --- a/tests/test_codegen.py +++ b/tests/test_codegen.py @@ -260,7 +260,7 @@ def test_malloc_resize(new_size): ) ) mod = CCompiler()(prgm) - assert mod.length(ab).value == new_size + assert mod.length(ab) == new_size assert ab.length() == new_size for i in range(new_size): assert ab.load(i) == 0 if i >= len(a) else a[i]