Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 50 additions & 14 deletions src/finchlite/codegen/c.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def construct_from_c(fmt, c_obj):
return fmt.construct_from_c(c_obj)
try:
return query_property(fmt, "construct_from_c", "__attr__", c_obj)
except NotImplementedError:
except AttributeError:
return fmt(c_obj)


Expand Down Expand Up @@ -204,12 +204,24 @@ def construct_from_c(fmt, c_obj):
register_property(t, "construct_from_c", "__attr__", lambda fmt, c_value: c_value)
register_property(t, "numba_type", "__attr__", lambda t: t)


def scalar_to_ctypes_copy(fmt, obj):
"""
This hack is required because it turns out that scalars don't own memory or smth
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reported it here: numpy/numpy#30354

"""
arr = np.array([obj], dtype=obj.dtype, copy=True)
scalar_ctype = np.ctypeslib.as_ctypes_type(obj.dtype)
ptr_ctype = ctypes.POINTER(scalar_ctype)
return arr.ctypes.data_as(ptr_ctype).contents


register_property(
np.generic,
"serialize_to_c",
"__attr__",
lambda fmt, obj: np.ctypeslib.as_ctypes(obj),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we have it as np.ctypeslib.as_ctypes(np.array(obj)), instead? It looks that 0-d arrays print your example correctly. Then we shouldn't need scalar_to_ctypes_copy.

scalar_to_ctypes_copy,
)

# pass by value -> no op
register_property(
np.generic,
Expand Down Expand Up @@ -254,11 +266,9 @@ def __call__(self, *args):
self.argtypes, args, serial_args, strict=False
):
deserialize_from_c(type_, arg, serial_arg)
if hasattr(self.ret_type, "construct_from_c"):
return construct_from_c(res.ftype, res)
if self.ret_type is type(None):
return None
return self.ret_type(res)
return construct_from_c(self.ret_type, res)


class CModule:
Expand Down Expand Up @@ -315,7 +325,7 @@ def __call__(self, prgm):
for func in prgm.funcs:
match func:
case asm.Function(asm.Variable(func_name, return_t), args, _):
return_t = c_type(return_t)
# return_t = c_type(return_t)
arg_ts = [arg.result_format for arg in args]
kern = CKernel(getattr(lib, func_name), return_t, arg_ts)
kernels[func_name] = kern
Expand Down Expand Up @@ -1035,16 +1045,35 @@ def struct_c_type(fmt: AssemblyStructFType):
return new_struct


def struct_c_type_wrapper(fmt: AssemblyStructFType):
"""
C type decider for struct types. Serialization actually ensures that before
crossing the FFI boundary, all serialized structs are structs, not
pointers.

The reason why we have this method is that ctypes can intelligently infer
whether we are working with a pointer arg type (pass by reference) or a
non-pointer type (pass by value)
"""
t = struct_c_type(fmt)
if fmt.is_mutable:
return ctypes.POINTER(t)
return t


register_property(
AssemblyStructFType,
"c_type",
"__attr__",
lambda fmt: ctypes.POINTER(struct_c_type(fmt)),
struct_c_type_wrapper,
)


def struct_c_getattr(fmt: AssemblyStructFType, ctx, obj, attr):
return f"{obj}->{attr}"
if fmt.is_mutable:
# we are passing things in as a pointer (reference c_type_wrapper)
return f"{obj}->{attr}"
return f"{obj}.{attr}"


register_property(
Expand All @@ -1056,8 +1085,10 @@ def struct_c_getattr(fmt: AssemblyStructFType, ctx, obj, attr):


def struct_c_setattr(fmt: AssemblyStructFType, ctx, obj, attr, val):
ctx.emit(f"{ctx.feed}{obj}->{attr} = {val};")
return
if fmt.is_mutable:
ctx.emit(f"{ctx.feed}{obj}->{attr} = {val};")
else:
ctx.emit(f"{ctx.feed}{obj}.{attr} = {val};")


register_property(
Expand Down Expand Up @@ -1092,18 +1123,23 @@ def serialize_tuple_to_c(fmt, obj):
"__attr__",
serialize_tuple_to_c,
)


def tuple_construct_from_c(fmt: TupleFType, c_struct):
args = [getattr(c_struct, name) for name in fmt.struct_fieldnames]
return tuple(args)


register_property(
TupleFType,
"construct_from_c",
"__attr__",
lambda fmt, obj, c_tuple: tuple(c_tuple),
tuple_construct_from_c,
)

register_property(
TupleFType,
"c_type",
"__attr__",
lambda fmt: ctypes.POINTER(
struct_c_type(asm.NamedTupleFType("CTuple", fmt.struct_fields))
),
lambda fmt: struct_c_type_wrapper(asm.NamedTupleFType("CTuple", fmt.struct_fields)),
)
21 changes: 11 additions & 10 deletions tests/scripts/safebufferaccess.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
"""

import argparse
import ctypes

import numpy as np

Expand All @@ -25,22 +24,22 @@
subparser = parser.add_subparsers(required=True, dest="subparser_name")

load = subparser.add_parser("load", help="attempt to load some element")
load.add_argument("index", type=int, help="the index to load")
load.add_argument("index", type=np.intp, help="the index to load")

store = subparser.add_parser("store", help="attempt to store into some element")
store.add_argument("index", type=int, help="the index to load")
store.add_argument("value", type=int, help="the value to store")
store.add_argument("index", type=np.intp, help="the index to load")
store.add_argument("value", type=np.int64, help="the value to store")

args = parser.parse_args()


a = np.array(range(args.size), dtype=ctypes.c_int64)
a = np.array(range(args.size), dtype=np.int64)
ab = NumpyBuffer(a)
ab_safe = SafeBuffer(ab)
ab_v = asm.Variable("a", ab_safe.ftype)
ab_slt = asm.Slot("a_", ab_safe.ftype)
idx = asm.Variable("idx", ctypes.c_size_t)
val = asm.Variable("val", ctypes.c_int64)
idx = asm.Variable("idx", np.intp)
val = asm.Variable("val", np.int64)

res_var = asm.Variable("val", ab_safe.ftype.element_type)
res_var2 = asm.Variable("val2", ab_safe.ftype.element_type)
Expand All @@ -64,6 +63,7 @@
res_var2,
asm.Load(ab_slt, idx),
),
asm.Repack(ab_slt),
asm.Return(res_var),
)
),
Expand All @@ -79,7 +79,8 @@
idx,
val,
),
asm.Return(asm.Literal(ctypes.c_int64(0))),
asm.Repack(ab_slt),
asm.Return(asm.Literal(0)),
)
),
),
Expand All @@ -91,8 +92,8 @@

match args.subparser_name:
case "load":
print(access(ab_safe, ctypes.c_size_t(args.index)).value)
print(access(ab_safe, args.index))
case "store":
change(ab_safe, ctypes.c_size_t(args.index), ctypes.c_int64(args.value))
change(ab_safe, args.index, args.value)
arr = [str(ab_safe.load(i)) for i in range(args.size)]
print(f"[{' '.join(arr)}]")
2 changes: 1 addition & 1 deletion tests/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def test_malloc_resize(new_size):
)
)
mod = CCompiler()(prgm)
assert mod.length(ab).value == new_size
assert mod.length(ab) == new_size
assert ab.length() == new_size
for i in range(new_size):
assert ab.load(i) == 0 if i >= len(a) else a[i]
Expand Down