Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 50 additions & 20 deletions src/finchlite/codegen/c.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@

from .. import finch_assembly as asm
from ..algebra import query_property, register_property
from ..finch_assembly import AssemblyStructFType, BufferFType, TupleFType
from ..finch_assembly import (
AssemblyStructFType,
BufferFType,
ImmutableStructFType,
MutableStructFType,
TupleFType,
)
from ..symbolic import Context, Namespace, ScopedDict, fisinstance, ftype
from ..util import config
from ..util.cache import file_cache
Expand Down Expand Up @@ -140,7 +146,7 @@ def construct_from_c(fmt, c_obj):
return fmt.construct_from_c(c_obj)
try:
return query_property(fmt, "construct_from_c", "__attr__", c_obj)
except NotImplementedError:
except AttributeError:
return fmt(c_obj)


Expand Down Expand Up @@ -204,12 +210,14 @@ def construct_from_c(fmt, c_obj):
register_property(t, "construct_from_c", "__attr__", lambda fmt, c_value: c_value)
register_property(t, "numba_type", "__attr__", lambda t: t)


register_property(
np.generic,
"serialize_to_c",
"__attr__",
lambda fmt, obj: np.ctypeslib.as_ctypes(obj),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we have it as np.ctypeslib.as_ctypes(np.array(obj)), instead? It looks that 0-d arrays print your example correctly. Then we shouldn't need scalar_to_ctypes_copy.

lambda fmt, obj: np.ctypeslib.as_ctypes(np.array(obj)),
)

# pass by value -> no op
register_property(
np.generic,
Expand Down Expand Up @@ -254,11 +262,9 @@ def __call__(self, *args):
self.argtypes, args, serial_args, strict=False
):
deserialize_from_c(type_, arg, serial_arg)
if hasattr(self.ret_type, "construct_from_c"):
return construct_from_c(res.ftype, res)
if self.ret_type is type(None):
return None
return self.ret_type(res)
return construct_from_c(self.ret_type, res)


class CModule:
Expand Down Expand Up @@ -315,7 +321,7 @@ def __call__(self, prgm):
for func in prgm.funcs:
match func:
case asm.Function(asm.Variable(func_name, return_t), args, _):
return_t = c_type(return_t)
# return_t = c_type(return_t)
arg_ts = [arg.result_format for arg in args]
kern = CKernel(getattr(lib, func_name), return_t, arg_ts)
kernels[func_name] = kern
Expand Down Expand Up @@ -1035,36 +1041,55 @@ def struct_c_type(fmt: AssemblyStructFType):
return new_struct


"""
Note: When serializing any struct to C, it will get serialized to a struct with
no indirection.

When you pass a struct into a kernel that expects a struct pointer, ctypes can
intelligently infer whether we are working with a pointer arg type (pass by
reference) or a non-pointer type (in which case it will immediately apply
indirection)
"""

register_property(
AssemblyStructFType,
MutableStructFType,
"c_type",
"__attr__",
lambda fmt: ctypes.POINTER(struct_c_type(fmt)),
)

register_property(
ImmutableStructFType, "c_type", "__attr__", lambda fmt: struct_c_type(fmt)
)

def struct_c_getattr(fmt: AssemblyStructFType, ctx, obj, attr):
return f"{obj}->{attr}"

register_property(
MutableStructFType,
"c_getattr",
"__attr__",
lambda fmt, ctx, obj, attr: f"{obj}->{attr}",
)

register_property(
AssemblyStructFType,
ImmutableStructFType,
"c_getattr",
"__attr__",
struct_c_getattr,
lambda fmt, ctx, obj, attr: f"{obj}.{attr}",
)


def struct_c_setattr(fmt: AssemblyStructFType, ctx, obj, attr, val):
def struct_mutable_setattr(fmt: AssemblyStructFType, ctx, obj, attr, val):
ctx.emit(f"{ctx.feed}{obj}->{attr} = {val};")
return


# the equivalent for immutable is f"{ctx.feed}{obj}.{attr} = {val};"
# but we will not include that because it's bad.

register_property(
AssemblyStructFType,
MutableStructFType,
"c_setattr",
"__attr__",
struct_c_setattr,
struct_mutable_setattr,
)


Expand Down Expand Up @@ -1092,18 +1117,23 @@ def serialize_tuple_to_c(fmt, obj):
"__attr__",
serialize_tuple_to_c,
)


def tuple_construct_from_c(fmt: TupleFType, c_struct):
args = [getattr(c_struct, name) for name in fmt.struct_fieldnames]
return tuple(args)


register_property(
TupleFType,
"construct_from_c",
"__attr__",
lambda fmt, obj, c_tuple: tuple(c_tuple),
tuple_construct_from_c,
)

register_property(
TupleFType,
"c_type",
"__attr__",
lambda fmt: ctypes.POINTER(
struct_c_type(asm.NamedTupleFType("CTuple", fmt.struct_fields))
),
lambda fmt: struct_c_type(asm.NamedTupleFType("CTuple", fmt.struct_fields)),
)
16 changes: 10 additions & 6 deletions src/finchlite/finch_assembly/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
from .buffer import Buffer, BufferFType, element_type, length_type
from .cfg_builder import (
AssemblyCFGBuilder,
assembly_build_cfg,
assembly_number_uses,
)
from .cfg_builder import AssemblyCFGBuilder, assembly_build_cfg, assembly_number_uses
from .dataflow import AssemblyCopyPropagation, assembly_copy_propagation
from .interpreter import AssemblyInterpreter, AssemblyInterpreterKernel
from .nodes import (
Expand Down Expand Up @@ -36,7 +32,13 @@
Variable,
WhileLoop,
)
from .struct import AssemblyStructFType, NamedTupleFType, TupleFType
from .struct import (
AssemblyStructFType,
ImmutableStructFType,
MutableStructFType,
NamedTupleFType,
TupleFType,
)
from .type_checker import AssemblyTypeChecker, AssemblyTypeError, assembly_check_types

__all__ = [
Expand All @@ -61,10 +63,12 @@
"GetAttr",
"If",
"IfElse",
"ImmutableStructFType",
"Length",
"Literal",
"Load",
"Module",
"MutableStructFType",
"NamedTupleFType",
"Print",
"Repack",
Expand Down
21 changes: 19 additions & 2 deletions src/finchlite/finch_assembly/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,24 @@ def struct_attrtype(self, attr: str) -> Any:
return dict(self.struct_fields)[attr]


class NamedTupleFType(AssemblyStructFType):
class ImmutableStructFType(AssemblyStructFType):
@property
def is_mutable(self) -> bool:
return False


class MutableStructFType(AssemblyStructFType):
"""
Class for a mutable assembly struct type.
It is currently not used anywhere, but maybe it will be useful in the future?
"""

@property
def is_mutable(self) -> bool:
return True


class NamedTupleFType(ImmutableStructFType):
def __init__(self, struct_name, struct_fields):
self._struct_name = struct_name
self._struct_fields = struct_fields
Expand Down Expand Up @@ -79,7 +96,7 @@ def __call__(self, *args):
return namedtuple(self.struct_name, self.struct_fieldnames)(args)


class TupleFType(AssemblyStructFType):
class TupleFType(ImmutableStructFType):
def __init__(self, struct_name, struct_formats):
self._struct_name = struct_name
self._struct_formats = struct_formats
Expand Down
21 changes: 11 additions & 10 deletions tests/scripts/safebufferaccess.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
"""

import argparse
import ctypes

import numpy as np

Expand All @@ -25,22 +24,22 @@
subparser = parser.add_subparsers(required=True, dest="subparser_name")

load = subparser.add_parser("load", help="attempt to load some element")
load.add_argument("index", type=int, help="the index to load")
load.add_argument("index", type=np.intp, help="the index to load")

store = subparser.add_parser("store", help="attempt to store into some element")
store.add_argument("index", type=int, help="the index to load")
store.add_argument("value", type=int, help="the value to store")
store.add_argument("index", type=np.intp, help="the index to load")
store.add_argument("value", type=np.int64, help="the value to store")

args = parser.parse_args()


a = np.array(range(args.size), dtype=ctypes.c_int64)
a = np.array(range(args.size), dtype=np.int64)
ab = NumpyBuffer(a)
ab_safe = SafeBuffer(ab)
ab_v = asm.Variable("a", ab_safe.ftype)
ab_slt = asm.Slot("a_", ab_safe.ftype)
idx = asm.Variable("idx", ctypes.c_size_t)
val = asm.Variable("val", ctypes.c_int64)
idx = asm.Variable("idx", np.intp)
val = asm.Variable("val", np.int64)

res_var = asm.Variable("val", ab_safe.ftype.element_type)
res_var2 = asm.Variable("val2", ab_safe.ftype.element_type)
Expand All @@ -64,6 +63,7 @@
res_var2,
asm.Load(ab_slt, idx),
),
asm.Repack(ab_slt),
asm.Return(res_var),
)
),
Expand All @@ -79,7 +79,8 @@
idx,
val,
),
asm.Return(asm.Literal(ctypes.c_int64(0))),
asm.Repack(ab_slt),
asm.Return(asm.Literal(0)),
)
),
),
Expand All @@ -91,8 +92,8 @@

match args.subparser_name:
case "load":
print(access(ab_safe, ctypes.c_size_t(args.index)).value)
print(access(ab_safe, args.index))
case "store":
change(ab_safe, ctypes.c_size_t(args.index), ctypes.c_int64(args.value))
change(ab_safe, args.index, args.value)
arr = [str(ab_safe.load(i)) for i in range(args.size)]
print(f"[{' '.join(arr)}]")
2 changes: 1 addition & 1 deletion tests/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def test_malloc_resize(new_size):
)
)
mod = CCompiler()(prgm)
assert mod.length(ab).value == new_size
assert mod.length(ab) == new_size
assert ab.length() == new_size
for i in range(new_size):
assert ab.load(i) == 0 if i >= len(a) else a[i]
Expand Down