-
Notifications
You must be signed in to change notification settings - Fork 8
Dense Level #183
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Dense Level #183
Changes from all commits
2c9fb85
b45493b
ca5db23
4f45e9a
7f09812
06ef043
2f88deb
7530d0c
9746a1c
4ad31e1
f62eda6
34b3414
9e62bb8
eefd2ee
8e451e0
ff366f6
2aae134
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,8 +14,7 @@ | |
| return_type, | ||
| ) | ||
| from ..codegen import NumpyBufferFType | ||
| from ..compile import BufferizedNDArrayFType, ExtentFType, dimension | ||
| from ..finch_assembly import TupleFType | ||
| from ..compile import ExtentFType, dimension | ||
| from ..finch_logic import ( | ||
| Aggregate, | ||
| Alias, | ||
|
|
@@ -214,11 +213,7 @@ def __call__( | |
| return ntn.Assign( | ||
| ntn.Variable( | ||
| name, | ||
| BufferizedNDArrayFType( | ||
| NumpyBufferFType(val.dtype), | ||
| val.ndim, | ||
| TupleFType.from_tuple(val.shape_type), | ||
| ), | ||
| val.from_kwargs(val.to_kwargs()), | ||
| ), | ||
| compile_logic_constant(tns), | ||
| ) | ||
|
|
@@ -498,13 +493,19 @@ def find_suitable_rep(root, table_vars) -> TensorFType: | |
| ) | ||
| ) | ||
|
|
||
| return BufferizedNDArrayFType( | ||
| buf_t=NumpyBufferFType(dtype), | ||
| # TODO: properly infer result rep from args | ||
| result_rep, fields = args_suitable_reps_fields[0] | ||
| levels_to_add = [ | ||
| idx for idx, f in enumerate(result_fields) if f not in fields | ||
| ] | ||
| result_rep = result_rep.add_levels(levels_to_add) | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As in the comment above - temporary selection of format (including removing/adding levels), this will be moved to a separate module. |
||
| kwargs = result_rep.to_kwargs() | ||
| kwargs.update( | ||
| element_type=NumpyBufferFType(dtype), | ||
| ndim=np.intp(len(result_fields)), | ||
| strides_t=TupleFType.from_tuple( | ||
| tuple(field_type_map[f] for f in result_fields) | ||
| ), | ||
| dimension_type=tuple(field_type_map[f] for f in result_fields), | ||
| ) | ||
| return result_rep.from_kwargs(**kwargs) | ||
| case Aggregate(Literal(op), init, arg, idxs): | ||
| init_suitable_rep = find_suitable_rep(init, table_vars) | ||
| arg_suitable_rep = find_suitable_rep(arg, table_vars) | ||
|
|
@@ -513,16 +514,24 @@ def find_suitable_rep(root, table_vars) -> TensorFType: | |
| op, init_suitable_rep.element_type, arg_suitable_rep.element_type | ||
| ) | ||
| ) | ||
| strides_t = tuple( | ||
| st | ||
| for f, st in zip(arg.fields, arg_suitable_rep.shape_type, strict=True) | ||
| if f not in idxs | ||
| ) | ||
| return BufferizedNDArrayFType( | ||
| buf_t=buf_t, | ||
| # TODO: properly infer result rep from args | ||
| levels_to_remove = [] | ||
| strides_t = [] | ||
| for idx, (f, st) in enumerate( | ||
| zip(arg.fields, arg_suitable_rep.shape_type, strict=True) | ||
| ): | ||
| if f not in idxs: | ||
| strides_t.append(st) | ||
| else: | ||
| levels_to_remove.append(idx) | ||
| arg_suitable_rep = arg_suitable_rep.remove_levels(levels_to_remove) | ||
| kwargs = arg_suitable_rep.to_kwargs() | ||
| kwargs.update( | ||
| buffer_type=buf_t, | ||
| ndim=np.intp(len(strides_t)), | ||
| strides_t=TupleFType.from_tuple(strides_t), | ||
| dimension_type=tuple(strides_t), | ||
| ) | ||
| return arg_suitable_rep.from_kwargs(**kwargs) | ||
| case LogicTree() as tree: | ||
| for child in tree.children: | ||
| suitable_rep = find_suitable_rep(child, table_vars) | ||
|
|
@@ -555,11 +564,13 @@ class LogicCompiler: | |
| def __init__(self): | ||
| self.ll = LogicLowerer() | ||
|
|
||
| def __call__(self, prgm: LogicNode) -> tuple[ntn.NotationNode, dict[Alias, Table]]: | ||
| def __call__( | ||
| self, prgm: LogicNode | ||
| ) -> tuple[ntn.NotationNode, dict[Alias, ntn.Variable], dict[Alias, Table]]: | ||
| prgm, table_vars, slot_vars, dim_size_vars, tables, field_relabels = ( | ||
| record_tables(prgm) | ||
| ) | ||
| lowered_prgm = self.ll( | ||
| prgm, table_vars, slot_vars, dim_size_vars, field_relabels | ||
| ) | ||
| return merge_blocks(lowered_prgm), tables | ||
| return merge_blocks(lowered_prgm), table_vars, tables | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -716,11 +716,11 @@ def struct_numba_setattr(fmt: AssemblyStructFType, ctx, obj, attr, val): | |
|
|
||
|
|
||
| def struct_construct_from_numba(fmt: AssemblyStructFType, numba_struct): | ||
| args = [ | ||
| construct_from_numba(field_type, getattr(numba_struct, name)) | ||
| kwargs = { | ||
| name: construct_from_numba(field_type, getattr(numba_struct, name)) | ||
| for (name, field_type) in fmt.struct_fields | ||
| ] | ||
| return fmt(*args) | ||
| } | ||
| return fmt(**kwargs) | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. formats now have "multiple constructors", whether we construct from Numba or in the facing API |
||
|
|
||
|
|
||
| register_property( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Proper handling of the format inference is a separate larger task (to map
traits.jl), here I only convert it to a "dict of attributes" where I can override some of them (for example promoted dtype from two inputs).