diff --git a/src/finchlite/__init__.py b/src/finchlite/__init__.py index ec456651..3cc61683 100644 --- a/src/finchlite/__init__.py +++ b/src/finchlite/__init__.py @@ -135,6 +135,9 @@ DenseLevelFType, ElementLevelFType, FiberTensorFType, + dense, + element, + fiber_tensor, ) __all__ = [ @@ -188,10 +191,12 @@ "cos", "cosh", "defer", + "dense", "dimension", "divide", "einop", "einsum", + "element", "element_type", "elementwise", "equal", @@ -199,6 +204,7 @@ "expand_dims", "expm1", "extent", + "fiber_tensor", "fill_value", "fisinstance", "flatten", diff --git a/src/finchlite/algebra/tensor.py b/src/finchlite/algebra/tensor.py index aeeaa24b..4d901421 100644 --- a/src/finchlite/algebra/tensor.py +++ b/src/finchlite/algebra/tensor.py @@ -34,6 +34,27 @@ def shape_type(self) -> tuple[type, ...]: e.g. dtypes, formats, or types, and so that we can easily index it.""" ... + @abstractmethod + def __init__(self, *args): + """TensorFType instance initializer.""" + ... + + # TODO: Remove and properly infer result rep + def add_levels(self, idxs: list[int]): + raise Exception("TODO: to remove") + + # TODO: Remove and properly infer result rep + def remove_levels(self, idxs: list[int]): + raise Exception("TODO: to remove") + + # TODO: Remove and properly infer result rep + def to_kwargs(self) -> dict[str, Any]: + raise Exception("TODO: to remove") + + # TODO: Remove and properly infer result rep + def from_kwargs(self, **kwargs): + raise Exception("TODO: to remove") + class Tensor(FTyped, ABC): """ diff --git a/src/finchlite/autoschedule/compiler.py b/src/finchlite/autoschedule/compiler.py index f0f1ddb0..c997df89 100644 --- a/src/finchlite/autoschedule/compiler.py +++ b/src/finchlite/autoschedule/compiler.py @@ -14,8 +14,7 @@ return_type, ) from ..codegen import NumpyBufferFType -from ..compile import BufferizedNDArrayFType, ExtentFType, dimension -from ..finch_assembly import TupleFType +from ..compile import ExtentFType, dimension from ..finch_logic import ( Aggregate, Alias, @@ -214,11 +213,7 @@ def __call__( return ntn.Assign( ntn.Variable( name, - BufferizedNDArrayFType( - NumpyBufferFType(val.dtype), - val.ndim, - TupleFType.from_tuple(val.shape_type), - ), + val.from_kwargs(val.to_kwargs()), ), compile_logic_constant(tns), ) @@ -498,13 +493,19 @@ def find_suitable_rep(root, table_vars) -> TensorFType: ) ) - return BufferizedNDArrayFType( - buf_t=NumpyBufferFType(dtype), + # TODO: properly infer result rep from args + result_rep, fields = args_suitable_reps_fields[0] + levels_to_add = [ + idx for idx, f in enumerate(result_fields) if f not in fields + ] + result_rep = result_rep.add_levels(levels_to_add) + kwargs = result_rep.to_kwargs() + kwargs.update( + element_type=NumpyBufferFType(dtype), ndim=np.intp(len(result_fields)), - strides_t=TupleFType.from_tuple( - tuple(field_type_map[f] for f in result_fields) - ), + dimension_type=tuple(field_type_map[f] for f in result_fields), ) + return result_rep.from_kwargs(**kwargs) case Aggregate(Literal(op), init, arg, idxs): init_suitable_rep = find_suitable_rep(init, table_vars) arg_suitable_rep = find_suitable_rep(arg, table_vars) @@ -513,16 +514,24 @@ def find_suitable_rep(root, table_vars) -> TensorFType: op, init_suitable_rep.element_type, arg_suitable_rep.element_type ) ) - strides_t = tuple( - st - for f, st in zip(arg.fields, arg_suitable_rep.shape_type, strict=True) - if f not in idxs - ) - return BufferizedNDArrayFType( - buf_t=buf_t, + # TODO: properly infer result rep from args + levels_to_remove = [] + strides_t = [] + for idx, (f, st) in 
enumerate( + zip(arg.fields, arg_suitable_rep.shape_type, strict=True) + ): + if f not in idxs: + strides_t.append(st) + else: + levels_to_remove.append(idx) + arg_suitable_rep = arg_suitable_rep.remove_levels(levels_to_remove) + kwargs = arg_suitable_rep.to_kwargs() + kwargs.update( + buffer_type=buf_t, ndim=np.intp(len(strides_t)), - strides_t=TupleFType.from_tuple(strides_t), + dimension_type=tuple(strides_t), ) + return arg_suitable_rep.from_kwargs(**kwargs) case LogicTree() as tree: for child in tree.children: suitable_rep = find_suitable_rep(child, table_vars) @@ -555,11 +564,13 @@ class LogicCompiler: def __init__(self): self.ll = LogicLowerer() - def __call__(self, prgm: LogicNode) -> tuple[ntn.NotationNode, dict[Alias, Table]]: + def __call__( + self, prgm: LogicNode + ) -> tuple[ntn.NotationNode, dict[Alias, ntn.Variable], dict[Alias, Table]]: prgm, table_vars, slot_vars, dim_size_vars, tables, field_relabels = ( record_tables(prgm) ) lowered_prgm = self.ll( prgm, table_vars, slot_vars, dim_size_vars, field_relabels ) - return merge_blocks(lowered_prgm), tables + return merge_blocks(lowered_prgm), table_vars, tables diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index 77b7fa59..74a7987c 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -29,9 +29,9 @@ def compile_plan( return None case lgc.Query( lgc.Alias(name), - lgc.Aggregate(lgc.Literal(operation), lgc.Literal(init), arg, _), + lgc.Aggregate(lgc.Literal(operation), lgc.Literal(init), arg, _) as agg, ): - einidxs = tuple(ein.Index(field.name) for field in node.rhs.fields) + einidxs = tuple(ein.Index(field.name) for field in agg.fields) my_bodies = [] if init != init_value(operation, type(init)): my_bodies.append( @@ -52,11 +52,12 @@ def compile_plan( ) return ein.Plan(tuple(my_bodies)) case lgc.Query(lgc.Alias(name), rhs): + assert isinstance(rhs, lgc.LogicExpression) einarg = self.compile_operand(rhs) return ein.Einsum( op=ein.Literal(overwrite), tns=ein.Alias(name), - idxs=tuple(ein.Index(field.name) for field in node.rhs.fields), + idxs=tuple(ein.Index(field.name) for field in rhs.fields), arg=einarg, ) diff --git a/src/finchlite/codegen/numba_backend.py b/src/finchlite/codegen/numba_backend.py index 9ff9f8e9..ca2d41c8 100644 --- a/src/finchlite/codegen/numba_backend.py +++ b/src/finchlite/codegen/numba_backend.py @@ -716,11 +716,11 @@ def struct_numba_setattr(fmt: AssemblyStructFType, ctx, obj, attr, val): def struct_construct_from_numba(fmt: AssemblyStructFType, numba_struct): - args = [ - construct_from_numba(field_type, getattr(numba_struct, name)) + kwargs = { + name: construct_from_numba(field_type, getattr(numba_struct, name)) for (name, field_type) in fmt.struct_fields - ] - return fmt(*args) + } + return fmt(**kwargs) register_property( diff --git a/src/finchlite/codegen/numpy_buffer.py b/src/finchlite/codegen/numpy_buffer.py index 0e3f308b..140a638d 100644 --- a/src/finchlite/codegen/numpy_buffer.py +++ b/src/finchlite/codegen/numpy_buffer.py @@ -76,6 +76,10 @@ def __str__(self): arr_str = str(self.arr).replace("\n", "") return f"np_buf({arr_str})" + def __repr__(self): + arr_repr = repr(self.arr).replace("\n", "") + return f"NumpyBuffer({arr_repr})" + class NumpyBufferFType(CBufferFType, NumbaBufferFType, CStackFType): """ diff --git a/src/finchlite/compile/bufferized_ndarray.py b/src/finchlite/compile/bufferized_ndarray.py index 531f5758..d14500ea 100644 --- a/src/finchlite/compile/bufferized_ndarray.py +++ 
b/src/finchlite/compile/bufferized_ndarray.py @@ -16,24 +16,24 @@ class BufferizedNDArray(Tensor): def __init__( self, - arr: np.ndarray | NumpyBuffer, + val: np.ndarray | NumpyBuffer, shape: tuple[np.integer, ...] | None = None, strides: tuple[np.integer, ...] | None = None, ): self._shape: tuple[np.integer, ...] self.strides: tuple[np.integer, ...] - if shape is None and strides is None and isinstance(arr, np.ndarray): - itemsize = arr.dtype.itemsize - for stride in arr.strides: + if shape is None and strides is None and isinstance(val, np.ndarray): + itemsize = val.dtype.itemsize + for stride in val.strides: if stride % itemsize != 0: raise ValueError("Array must be aligned to multiple of itemsize") - self.strides = tuple(np.intp(stride // itemsize) for stride in arr.strides) - self._shape = tuple(np.intp(s) for s in arr.shape) - self.buf = NumpyBuffer(arr.reshape(-1, copy=False)) - elif shape is not None and strides is not None and isinstance(arr, NumpyBuffer): + self.strides = tuple(np.intp(stride // itemsize) for stride in val.strides) + self._shape = tuple(np.intp(s) for s in val.shape) + self.val = NumpyBuffer(val.reshape(-1, copy=False)) + elif shape is not None and strides is not None and isinstance(val, NumpyBuffer): self.strides = strides self._shape = shape - self.buf = arr + self.val = val else: raise Exception("Invalid constructor arguments") @@ -42,14 +42,18 @@ def to_numpy(self): Convert the bufferized NDArray to a NumPy array. This is used to get the underlying NumPy array from the bufferized NDArray. """ - return self.buf.arr.reshape(self._shape, copy=False) + return self.val.arr.reshape(self._shape, copy=False) @property def ftype(self): """ Returns the ftype of the buffer, which is a BufferizedNDArrayFType. """ - return BufferizedNDArrayFType(ftype(self.buf), self.ndim, ftype(self.strides)) + return BufferizedNDArrayFType( + buffer_type=ftype(self.val), + ndim=self.ndim, + dimension_type=ftype(self.strides), + ) @property def shape(self): @@ -74,8 +78,8 @@ def declare(self, init, op, shape): raise ValueError( f"Invalid dimension end value {dim.end} for ndarray declaration." 
) - for i in range(self.buf.length()): - self.buf.store(i, init) + for i in range(self.val.length()): + self.val.store(i, init) return self def freeze(self, op): @@ -94,7 +98,7 @@ def __getitem__(self, index): """ if isinstance(index, tuple): index = 0 if index == () else np.dot(index, self.strides) - return self.buf.load(index) + return self.val.load(index) def __setitem__(self, index, value): """ @@ -103,7 +107,7 @@ def __setitem__(self, index, value): """ if isinstance(index, tuple): index = np.ravel_multi_index(index, self._shape) - self.buf.store(index, value) + self.val.store(index, value) def __str__(self): return f"BufferizedNDArray(shape={self.shape})" @@ -141,24 +145,30 @@ def str_format(types): @property def struct_fields(self): return [ - ("buf", self.buf_t), + ("val", self.buf_t), ("shape", self.shape_t), ("strides", self.strides_t), ] - def __init__(self, buf_t: NumpyBufferFType, ndim: np.intp, strides_t: TupleFType): - self.buf_t = buf_t + def __init__( + self, + *, + buffer_type: NumpyBufferFType, + ndim: np.intp, + dimension_type: TupleFType, + ): + self.buf_t = buffer_type self._ndim = ndim - self.shape_t = strides_t # assuming shape is the same type as strides - self.strides_t = strides_t + self.shape_t = dimension_type # assuming shape is the same type as strides + self.strides_t = dimension_type def __eq__(self, other): if not isinstance(other, BufferizedNDArrayFType): return False - return self.buf_t == other.buf_t and self._ndim == other._ndim + return self.buf_t == other.buf_t and self.ndim == other.ndim def __hash__(self): - return hash((self.buf_t, self._ndim)) + return hash((self.buf_t, self.ndim)) def __str__(self): return str(self.struct_name) @@ -170,6 +180,35 @@ def __repr__(self): def ndim(self) -> np.intp: return self._ndim + @ndim.setter + def ndim(self, val): + self._ndim = val + + def from_kwargs(self, **kwargs) -> "BufferizedNDArrayFType": + b_t = kwargs.get("buffer_type", self.buf_t) + ndim = kwargs.get("ndim", self.ndim) + if "shape_type" in kwargs: + s_t = kwargs["shape_type"] + d_t = s_t if isinstance(s_t, TupleFType) else TupleFType.from_tuple(s_t) + else: + d_t = self.shape_t + return BufferizedNDArrayFType(buffer_type=b_t, ndim=ndim, dimension_type=d_t) + + def to_kwargs(self): + return { + "buffer_type": self.buf_t, + "ndim": self.ndim, + "shape_type": self.shape_t, + } + + # TODO: temporary approach for suitable rep and traits + def add_levels(self, idxs: list[int]): + return self + + # TODO: temporary approach for suitable rep and traits + def remove_levels(self, idxs: list[int]): + return self + @property def fill_value(self) -> Any: return np.zeros((), dtype=self.buf_t.element_type)[()] @@ -180,7 +219,7 @@ def element_type(self): @property def shape_type(self) -> tuple: - return tuple(np.intp for _ in range(self._ndim)) + return tuple(np.intp for _ in range(self.ndim)) def lower_declare(self, ctx, tns, init, op, shape): i_var = asm.Variable("i", self.buf_t.length_type) @@ -220,14 +259,14 @@ def asm_unpack(self, ctx, var_n, val): Unpack the into asm context. 
""" stride = [] - for i in range(self._ndim): + for i in range(self.ndim): stride_i = asm.Variable(f"{var_n}_stride_{i}", self.buf_t.length_type) stride.append(stride_i) stride_e = asm.GetAttr(val, asm.Literal("strides")) stride_i_e = asm.GetAttr(stride_e, asm.Literal(f"element_{i}")) ctx.exec(asm.Assign(stride_i, stride_i_e)) buf = asm.Variable(f"{var_n}_buf", self.buf_t) - buf_e = asm.GetAttr(val, asm.Literal("buf")) + buf_e = asm.GetAttr(val, asm.Literal("val")) ctx.exec(asm.Assign(buf, buf_e)) buf_s = asm.Slot(f"{var_n}_buf_slot", self.buf_t) ctx.exec(asm.Unpack(buf_s, buf)) @@ -243,11 +282,11 @@ def asm_repack(self, ctx, lhs, obj): def __call__( self, - buf: NumpyBuffer, - shape: tuple[np.integer, ...], - strides: tuple[np.integer, ...], + val: NumpyBuffer, + shape: tuple[np.integer, ...] | None = None, + strides: tuple[np.integer, ...] | None = None, ) -> BufferizedNDArray: - return BufferizedNDArray(buf, shape, strides) + return BufferizedNDArray(val, shape, strides) class BufferizedNDArrayAccessor(Tensor): @@ -288,7 +327,7 @@ def unwrap(self): This is used to get the original tensor from a tensor view. """ assert self.ndim == 0, "Cannot unwrap a tensor view with non-zero dimension." - return self.tns.buf.load(self.pos) + return self.tns.val.load(self.pos) def increment(self, val): """ @@ -298,7 +337,7 @@ def increment(self, val): if self.op is None: raise ValueError("No operation defined for increment.") assert self.ndim == 0, "Cannot unwrap a tensor view with non-zero dimension." - self.tns.buf.store(self.pos, self.op(self.tns.buf.load(self.pos), val)) + self.tns.val.store(self.pos, self.op(self.tns.val.load(self.pos), val)) return self @@ -377,7 +416,7 @@ def asm_repack(self, ctx, lhs, obj): """ Repack the buffer from C context. """ - (self.tns.asm_repack(ctx, lhs.tns, obj.tns),) + self.tns.asm_repack(ctx, lhs.tns, obj.tns) ctx.exec( asm.Block( asm.SetAttr(lhs, "tns", obj.tns), diff --git a/src/finchlite/compile/lower.py b/src/finchlite/compile/lower.py index 90655040..3ed337b1 100644 --- a/src/finchlite/compile/lower.py +++ b/src/finchlite/compile/lower.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from dataclasses import dataclass +from dataclasses import asdict, dataclass from pprint import pprint from typing import Any @@ -15,14 +15,14 @@ class FinchTensorFType(TensorFType, ABC): @abstractmethod - def lower_unwrap(tns): + def lower_unwrap(self, ctx, obj): """ Unwrap a tensor view to get the underlying tensor. This is used to get the original tensor from a tensor view. """ @abstractmethod - def lower_increment(tns, val): + def lower_increment(self, ctx, obj, val): """ Increment a tensor view with an operation and value. This updates the tensor at the specified index with the operation and value. @@ -47,7 +47,10 @@ def lower_thaw(self, ctx, tns, op): """ @abstractmethod - def unfurl(self, ctx, tns, ext, mode, proto): ... + def unfurl(self, ctx, tns, ext, mode, proto): + """ + Unfurl a tensor. 
+ """ @dataclass(eq=True, frozen=True) @@ -158,6 +161,14 @@ def struct_name(self): def struct_fields(self): return [("start", np.intp), ("end", np.intp)] + def from_kwargs(self, **kwargs) -> "ExtentFType": + start = kwargs.get("start", self.start) + end = kwargs.get("end", self.end) + return ExtentFType(start, end) # type: ignore[abstract] + + def to_kwargs(self): + return asdict(self) + def __call__(self, *args): raise TypeError(f"{self.struct_name} is not callable") @@ -281,6 +292,7 @@ def __init__( epilogue=None, bindings=None, slots=None, + access_modes=None, types=None, func_state=None, ): @@ -289,10 +301,13 @@ def __init__( bindings = ScopedDict() if slots is None: slots = ScopedDict() + if access_modes is None: + access_modes = ScopedDict() if types is None: types = ScopedDict() self.bindings = bindings self.slots = slots + self.access_modes = access_modes self.types = types self.func_state = func_state @@ -304,6 +319,7 @@ def block(self): blk = super().block() blk.bindings = self.bindings blk.slots = self.slots + blk.access_modes = self.access_modes blk.types = self.types blk.func_state = self.func_state return blk @@ -315,6 +331,7 @@ def scope(self): blk = self.block() blk.bindings = self.bindings.scope() blk.slots = self.slots.scope() + blk.access_modes = self.access_modes.scope() blk.types = self.types.scope() return blk @@ -340,6 +357,21 @@ def resolve(self, node): case _: raise ValueError(f"Expected Slot or Stack, got: {type(node)}") + def _freeze_tensor(self, tns_var: str, op: ntn.Literal | None) -> None: + if op is None: + assert tns_var not in self.access_modes + else: + assert self.access_modes[tns_var] == ntn.Update(op) + self.access_modes[tns_var] = ntn.Read() + + def _thaw_tensor(self, tns_var: str, op: ntn.Literal) -> None: + assert self.access_modes[tns_var] == ntn.Read() + self.access_modes[tns_var] = ntn.Update(op) + + def _rm_tensor_from_accesses(self, tns_var: str) -> None: + assert self.access_modes[tns_var] == ntn.Read() + del self.access_modes[tns_var] + def __call__(self, prgm): """ Lower Finch Notation to Finch Assembly. 
First we check for early @@ -387,6 +419,7 @@ def __call__(self, prgm): var = asm.Variable(var_n, var_t) self.exec(asm.Assign(var, val_code)) self.types[var_n] = var_t + self._freeze_tensor(var_n, op=None) self.slots[var_n] = var_t.asm_unpack( self, var_n, asm.Variable(var_n, var_t) ) @@ -397,6 +430,7 @@ def __call__(self, prgm): if var_t != self.types[var_n]: raise TypeError(f"Type mismatch: {var_t} != {self.types[var_n]}") obj = self.slots[var_n] + self._rm_tensor_from_accesses(var_n) var_t.asm_repack(self, var_n, obj) return None case ntn.Unwrap(ntn.Access(tns, mode, _)): @@ -419,16 +453,19 @@ def __call__(self, prgm): ext.result_format.lower_loop(self, idx, self(ext), body) return None case ntn.Declare(tns, init, op, shape): + self._thaw_tensor(tns.name, op) tns = self.resolve(tns) init_e = self(init) op_e = self(op) shape_e = [self(s) for s in shape] return tns.result_format.lower_declare(self, tns, init_e, op_e, shape_e) case ntn.Freeze(tns, op): + self._freeze_tensor(tns.name, op) tns = self.resolve(tns) op_e = self(op) return tns.result_format.lower_freeze(self, tns, op_e) case ntn.Thaw(tns, op): + self._thaw_tensor(tns.name, op) tns = self.resolve(tns) op_e = self(op) return tns.result_format.lower_thaw(self, tns, op_e) diff --git a/src/finchlite/finch_assembly/nodes.py b/src/finchlite/finch_assembly/nodes.py index bd1503bd..0479a68e 100644 --- a/src/finchlite/finch_assembly/nodes.py +++ b/src/finchlite/finch_assembly/nodes.py @@ -694,7 +694,7 @@ def __call__(self, prgm: AssemblyNode): raise NotImplementedError(f"Unrecognized lhs type: {lhs}") return None case GetAttr(obj, attr): - return f"getattr({obj}, {attr})" + return f"{obj}.{attr}" case SetAttr(obj, attr, val): return f"setattr({obj}, {attr})" case Call(Literal(_) as lit, args): diff --git a/src/finchlite/finch_assembly/struct.py b/src/finchlite/finch_assembly/struct.py index b855f87e..706e5768 100644 --- a/src/finchlite/finch_assembly/struct.py +++ b/src/finchlite/finch_assembly/struct.py @@ -30,6 +30,27 @@ def struct_setattr(self, obj, attr, value) -> None: setattr(obj, attr, value) return + @abstractmethod + def from_kwargs(self, **kwargs) -> "AssemblyStructFType": + """ + Protocol for constructing Finch tensors from keyword arguments. + Here are currently supported arguments. They are all optional, + each implementor decides which fields to select: + - lvl_t: LevelFType + - fill_value: np.number + - element_type: type + - position_type: type + - dimension_type: type + - shape_type: tuple[type, ...] + - buffer_factory: type + - buffer_type: BufferFType + - ndim: int + """ + ... + + @abstractmethod + def to_kwargs(self) -> dict: ... 
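+    # Hypothetical round-trip sketch (illustration only, not code added by this
+    # change): a concrete ftype can be rebuilt from its own kwargs, e.g.
+    #
+    #     fmt = BufferizedNDArrayFType(
+    #         buffer_type=NumpyBufferFType(np.dtype(np.float64)),
+    #         ndim=np.intp(2),
+    #         dimension_type=TupleFType.from_tuple((np.intp, np.intp)),
+    #     )
+    #     assert fmt.from_kwargs(**fmt.to_kwargs()) == fmt
+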
+ @property def struct_fieldnames(self) -> list[str]: return [name for (name, _) in self.struct_fields] @@ -71,6 +92,12 @@ def struct_name(self): def struct_fields(self): return self._struct_fields + def from_kwargs(self, **kwargs) -> "TupleFType": + raise NotImplementedError + + def to_kwargs(self) -> dict: + raise NotImplementedError + def __call__(self, *args): assert all( isinstance(a, f) @@ -114,12 +141,18 @@ def struct_name(self): def struct_fields(self): return [(f"element_{i}", fmt) for i, fmt in enumerate(self._struct_formats)] - def __call__(self, *args): + def from_kwargs(self, **kwargs) -> "TupleFType": + raise NotImplementedError + + def to_kwargs(self) -> dict: + raise NotImplementedError + + def __call__(self, **kwargs): assert all( isinstance(a, f) - for a, f in zip(args, self.struct_fieldformats, strict=False) + for a, f in zip(kwargs.values(), self.struct_fieldformats, strict=False) ) - return tuple(args) + return tuple(kwargs.values()) @staticmethod @lru_cache diff --git a/src/finchlite/interface/fuse.py b/src/finchlite/interface/fuse.py index 715f3631..349b0aca 100644 --- a/src/finchlite/interface/fuse.py +++ b/src/finchlite/interface/fuse.py @@ -60,7 +60,7 @@ from ..algebra import Tensor, TensorPlaceholder from ..autoschedule import DefaultLogicOptimizer, LogicCompiler from ..codegen import NumbaCompiler -from ..compile import BufferizedNDArray, NotationCompiler +from ..compile import NotationCompiler from ..finch_logic import ( Alias, Field, @@ -103,9 +103,9 @@ def set_default_scheduler( ntn_interp = ntn.NotationInterpreter() def fn_compile(plan): - prgm, tables = optimizer(plan) + prgm, table_vars, tables = optimizer(plan) mod = ntn_interp(prgm) - args = provision_tensors(prgm, tables) + args = provision_tensors(prgm, table_vars, tables) return (mod.func(*args),) _DEFAULT_SCHEDULER = fn_compile @@ -116,10 +116,10 @@ def fn_compile(plan): asm_interp = asm.AssemblyInterpreter() def fn_compile(plan): - ntn_prgm, tables = optimizer(plan) + ntn_prgm, table_vars, tables = optimizer(plan) asm_prgm = notation_compiler(ntn_prgm) mod = asm_interp(asm_prgm) - args = provision_tensors(asm_prgm, tables) + args = provision_tensors(asm_prgm, table_vars, tables) return (mod.func(*args),) _DEFAULT_SCHEDULER = fn_compile @@ -132,12 +132,12 @@ def fn_compile(plan): def fn_compile(plan): # TODO: proper logging # print("Logic: \n", plan) - ntn_prgm, tables = optimizer(plan) + ntn_prgm, table_vars, tables = optimizer(plan) # print("Notation: \n", ntn_prgm) asm_prgm = notation_compiler(ntn_prgm) # print("Assembler: \n", asm_prgm) mod = numba_compiler(asm_prgm) - args = provision_tensors(asm_prgm, tables) + args = provision_tensors(asm_prgm, table_vars, tables) return (mod.func(*args),) _DEFAULT_SCHEDULER = fn_compile @@ -157,16 +157,19 @@ def get_default_scheduler(): return _DEFAULT_SCHEDULER -def provision_tensors(prgm: Any, tables: dict[Alias, Table]) -> list[Tensor]: +def provision_tensors( + prgm: Any, table_vars: dict[Alias, ntn.Variable], tables: dict[Alias, Table] +) -> list[Tensor]: args: list[Tensor] = [] dims_dict: dict[Field, int] = {} for arg in prgm.funcs[0].args: table = tables[Alias(arg.name)] + table_var = table_vars[Alias(arg.name)] match table: case Table(Literal(val), idxs): if isinstance(val, TensorPlaceholder): shape = tuple(dims_dict[field] for field in idxs) - tensor = BufferizedNDArray(np.zeros(dtype=val.dtype, shape=shape)) + tensor = table_var.type_(val=np.zeros(dtype=val.dtype, shape=shape)) else: for idx, field in enumerate(table.idxs): dims_dict[field] = 
val.shape[idx] diff --git a/src/finchlite/interface/lazy.py b/src/finchlite/interface/lazy.py index 638affaa..bc177ab3 100644 --- a/src/finchlite/interface/lazy.py +++ b/src/finchlite/interface/lazy.py @@ -331,27 +331,28 @@ def __ne__(self, other): register_property(np.ndarray, "asarray", "__attr__", lambda x: BufferizedNDArray(x)) +register_property(BufferizedNDArray, "asarray", "__attr__", lambda x: x) register_property(LazyTensor, "asarray", "__attr__", lambda x: x) -def asarray(arg: Any, format="bufferized") -> Any: +def asarray(arg: Any, format=None) -> Any: """ Convert given argument and return wrapper type instance. If input argument is already array type, return unchanged. Args: arg: The object to be converted. + format: The format for the result array. Returns: The array type result of the given object. """ - if format != "bufferized": - raise Exception(f"Only bufferized format is now supported, got: {format}") + if format is None: + if hasattr(arg, "asarray"): + return arg.asarray() + return query_property(arg, "asarray", "__attr__") - if hasattr(arg, "asarray"): - return arg.asarray() - - return query_property(arg, "asarray", "__attr__") + return format(arg) def defer(arr) -> LazyTensor: diff --git a/src/finchlite/tensor/__init__.py b/src/finchlite/tensor/__init__.py index 9c45ad8c..be73271f 100644 --- a/src/finchlite/tensor/__init__.py +++ b/src/finchlite/tensor/__init__.py @@ -1,4 +1,4 @@ -from .fiber_tensor import FiberTensor, FiberTensorFType, Level, LevelFType, tensor +from .fiber_tensor import FiberTensor, FiberTensorFType, Level, LevelFType, fiber_tensor from .level import ( DenseLevel, DenseLevelFType, @@ -19,5 +19,5 @@ "LevelFType", "dense", "element", - "tensor", + "fiber_tensor", ] diff --git a/src/finchlite/tensor/fiber_tensor.py b/src/finchlite/tensor/fiber_tensor.py index f618ae9b..d42fa302 100644 --- a/src/finchlite/tensor/fiber_tensor.py +++ b/src/finchlite/tensor/fiber_tensor.py @@ -1,12 +1,21 @@ from abc import ABC, abstractmethod +from copy import deepcopy from dataclasses import dataclass -from typing import Generic, TypeVar +from typing import Any, Generic, NamedTuple, TypeVar -from finchlite.algebra import Tensor, TensorFType -from finchlite.symbolic import FType, FTyped +import numpy as np +from finchlite.finch_assembly.struct import TupleFType -class LevelFType(FType, ABC): +from .. import finch_assembly as asm +from .. import finch_notation as ntn +from ..algebra import Tensor, register_property +from ..codegen.numpy_buffer import NumpyBuffer +from ..compile.lower import FinchTensorFType +from ..symbolic import FTyped + + +class LevelFType(FinchTensorFType, ABC): """ An abstract base class representing the ftype of levels. """ @@ -39,7 +48,7 @@ def element_type(self): @abstractmethod def shape_type(self): """ - Tuple of types of the dimensions in the shape + Tuple of types of the dimensions in the shape. """ ... @@ -59,6 +68,18 @@ def buffer_factory(self): """ ... + @property + @abstractmethod + def buffer_type(self): ... + + @property + @abstractmethod + def lvl_t(self): + """ + Property returning nested level + """ + ... + class Level(FTyped, ABC): """ @@ -68,12 +89,20 @@ class Level(FTyped, ABC): @property @abstractmethod - def shape(self): + def shape(self) -> tuple: """ Shape of the fibers in the structure. """ ... + @property + @abstractmethod + def stride(self) -> np.integer: ... + + @property + @abstractmethod + def val(self) -> Any: ... 
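+    # Illustrative note (based on the DenseLevel/ElementLevel implementations
+    # in this change): for a fiber_tensor(dense(dense(element(0.0)))) tensor
+    # holding a 3x4 row-major array in a single flat buffer, `shape` is (3, 4),
+    # the outer dense level's `stride` is 4, the inner dense level's `stride`
+    # is 1, and `val` is the shared flat buffer.
+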
+ @property def ndim(self): return self.ftype.ndim @@ -98,6 +127,10 @@ def position_type(self): def buffer_factory(self): return self.ftype.buffer_factory + @property + def buffer_type(self): + return self.ftype.buffer_type + Tp = TypeVar("Tp") @@ -132,6 +165,14 @@ def ftype(self): def shape(self): return self.lvl.shape + @property + def stride(self): + return self.lvl.stride + + @property + def val(self): + return self.lvl.val + @property def ndim(self): return self.lvl.ndim @@ -160,9 +201,18 @@ def buffer_factory(self): """ return self.lvl.buffer_factory + def to_numpy(self) -> np.ndarray: + # TODO: temporary for dense only. TBD in sparse_level PR + return np.reshape(self.lvl.val.arr, self.shape, copy=False) + + +class FiberTensorFields(NamedTuple): + lvl: asm.Variable # TODO: lvl is misleading - rename it + buf_s: asm.Slot + @dataclass(unsafe_hash=True) -class FiberTensorFType(TensorFType): +class FiberTensorFType(FinchTensorFType, asm.AssemblyStructFType): """ An abstract base class representing the ftype of a fiber tensor. @@ -170,59 +220,145 @@ class FiberTensorFType(TensorFType): lvl: a fiber allocator that manages the fibers in the tensor. """ - lvl: LevelFType - _position_type: type | None = None + lvl_t: LevelFType + position_type: type | None = None + + @property + def struct_name(self): + # TODO: include dt = np.dtype(self.buf_t.element_type) + return "FiberTensorFType" + + @property + def struct_fields(self): + return [ + ("lvl", self.lvl_t), + ("shape", TupleFType.from_tuple(self.shape_type)), + ] def __post_init__(self): - if self._position_type is None: - self._position_type = self.lvl.position_type + if self.position_type is None: + self.position_type = self.lvl_t.position_type - def __call__(self, shape): + def __call__(self, *, lvl=None, shape=None, val=None): """ Creates an instance of a FiberTensor with the given arguments. 
""" - return FiberTensor(self.lvl(shape), self.lvl.position_type(1)) - - @property - def shape(self): - return self.lvl.shape + if lvl is not None: + return FiberTensor(lvl, self.lvl_t.position_type(1)) + if shape is None: + shape = val.shape + val = NumpyBuffer(val.reshape(-1)) + return FiberTensor( + self.lvl_t(shape=shape, val=val), self.lvl_t.position_type(1) + ) + + def __str__(self): + return f"FiberTensorFType({self.lvl_t})" @property def ndim(self): - return self.lvl.ndim + return self.lvl_t.ndim @property def shape_type(self): - return self.lvl.shape_type + return self.lvl_t.shape_type @property def element_type(self): - return self.lvl.element_type + return self.lvl_t.element_type @property def fill_value(self): - return self.lvl.fill_value + return self.lvl_t.fill_value @property - def position_type(self): - return self._position_type + def buffer_factory(self): + return self.lvl_t.buffer_factory @property - def buffer_factory(self): + def buffer_type(self): + return self.lvl_t.buffer_type + + def from_kwargs(self, **kwargs) -> "FiberTensorFType": + pos_t = kwargs.get("position_type", self.position_type) + return FiberTensorFType(self.lvl_t.from_kwargs(**kwargs), pos_t) # type: ignore[abstract] + + def to_kwargs(self): + return { + "position_type": self.position_type, + "shape_type": self.shape_type, + } | self.lvl_t.to_kwargs() + + # TODO: temporary approach for suitable rep and traits + def add_levels(self, idxs: list[int]): + from .level.dense_level import dense + + copy = deepcopy(self) + lvl = copy + for idx in range(max(idxs) + 1): + if idx in idxs: + lvl.lvl_t = dense(lvl.lvl_t, dimension_type=np.intp) + lvl = lvl.lvl_t # type: ignore[assignment] + return copy + + # TODO: temporary approach for suitable rep and traits + def remove_levels(self, idxs: list[int]): + copy = deepcopy(self) + lvl = copy + for i in range(self.ndim): + if i in idxs: + lvl.lvl_t = lvl.lvl_t.lvl_t + lvl = lvl.lvl_t # type: ignore[assignment] + return copy + + def unfurl(self, ctx, tns, ext, mode, proto): + op = None + if isinstance(mode, ntn.Update): + op = mode.op + tns = ctx.resolve(tns).obj + obj = self.lvl_t.get_fields_class( + tns.lvl, tns.buf_s, 0, asm.Literal(self.position_type(0)), op + ) + return self.lvl_t.unfurl(ctx, ntn.Stack(obj, self.lvl_t), ext, mode, proto) + + def lower_freeze(self, ctx, tns, op): + return self.lvl_t.lower_freeze(ctx, tns.obj.buf_s, op) + + def lower_thaw(self, ctx, tns, op): + return self.lvl_t.lower_thaw(ctx, tns.obj.buf_s, op) + + def lower_unwrap(self, ctx, obj): + raise NotImplementedError + + def lower_increment(self, ctx, obj, val): + raise NotImplementedError + + def lower_declare(self, ctx, tns, init, op, shape): + return self.lvl_t.lower_declare(ctx, tns.obj.buf_s, init, op, shape) + + def asm_unpack(self, ctx, var_n, val): """ - Returns the ftype of the buffer used for the fibers. - This is typically a NumpyBufferFType or similar. + Unpack the into asm context. """ - return self.lvl.buffer_factory + val_lvl = asm.GetAttr(val, asm.Literal("lvl")) + buf_s = self.lvl_t.asm_unpack(ctx, var_n, val_lvl) + return FiberTensorFields(val_lvl, buf_s) + + def asm_repack(self, ctx, lhs, obj): + """ + Repack the buffer from the context. + """ + ctx.exec(asm.Repack(obj.buf_s)) + return -def tensor(lvl: LevelFType, position_type: type | None = None): +def fiber_tensor(lvl: LevelFType, position_type: type | None = None): """ Creates a FiberTensorFType with the given level ftype and position type. Args: lvl: The level ftype to be used for the tensor. 
- pos_type: The type of positions within the tensor. Defaults to None. + position_type: The type of positions within the tensor. Defaults to None. Returns: An instance of FiberTensorFType. @@ -230,3 +366,6 @@ def tensor(lvl: LevelFType, position_type: type | None = None): # mypy does not understand that dataclasses generate __hash__ and __eq__ # https://github.com/python/mypy/issues/19799 return FiberTensorFType(lvl, position_type) # type: ignore[abstract] + + +register_property(FiberTensor, "asarray", "__attr__", lambda x: x) diff --git a/src/finchlite/tensor/level/dense_level.py b/src/finchlite/tensor/level/dense_level.py index d57ca9e5..aec0b042 100644 --- a/src/finchlite/tensor/level/dense_level.py +++ b/src/finchlite/tensor/level/dense_level.py @@ -1,22 +1,47 @@ -from abc import ABC +import operator from dataclasses import dataclass -from typing import Any +from typing import Any, NamedTuple import numpy as np +from ... import finch_assembly as asm +from ... import finch_notation as ntn +from ...codegen import NumpyBufferFType +from ...compile import looplets as lplt from ..fiber_tensor import Level, LevelFType +class DenseLevelFields(NamedTuple): + lvl: asm.Variable + buf_s: NumpyBufferFType + nind: int + pos: asm.AssemblyNode + op: Any + + @dataclass(unsafe_hash=True) -class DenseLevelFType(LevelFType, ABC): - lvl: Any +class DenseLevelFType(LevelFType, asm.AssemblyStructFType): + _lvl_t: LevelFType dimension_type: Any = None + op: Any = None + + @property + def struct_name(self): + return "DenseLevelFType" + + @property + def struct_fields(self): + return [ + ("lvl", self.lvl_t), + ("dimension", self.dimension_type), + ("stride", self.dimension_type), + ] def __post_init__(self): if self.dimension_type is None: self.dimension_type = np.intp - def __call__(self, shape): + def __call__(self, *, lvl=None, dimension=None, stride=None, shape=None, val=None): """ Creates an instance of DenseLevel with the given ftype. Args: @@ -24,44 +49,138 @@ def __call__(self, shape): Returns: An instance of DenseLevel. """ - lvl = self.lvl(shape=shape[1:]) + if lvl is not None and dimension is not None: + return DenseLevel(self, lvl, dimension) + lvl = self.lvl_t(shape=shape[1:], val=val) return DenseLevel(self, lvl, self.dimension_type(shape[0])) + def __str__(self): + return f"DenseLevelFType({self.lvl_t})" + @property def ndim(self): - return 1 + self.lvl.ndim + return 1 + self.lvl_t.ndim + + @property + def lvl_t(self): + return self._lvl_t + + @lvl_t.setter + def lvl_t(self, value): + self._lvl_t = value @property def fill_value(self): - return self.lvl.fill_value + return self.lvl_t.fill_value @property def element_type(self): """ Returns the type of elements stored in the fibers. """ - return self.lvl.element_type + return self.lvl_t.element_type @property def shape_type(self): """ Returns the type of the shape of the fibers. """ - return (self.dimension_type, *self.lvl.shape_type) + return (self.dimension_type, *self.lvl_t.shape_type) @property def position_type(self): """ Returns the type of positions within the levels. """ - return self.lvl.position_type + return self.lvl_t.position_type + + @property + def buffer_type(self): + return self.lvl_t.buffer_type @property def buffer_factory(self): """ Returns the ftype of the buffer used for the fibers. 
""" - return self.lvl.buffer_factory + return self.lvl_t.buffer_factory + + def from_kwargs(self, **kwargs) -> "DenseLevelFType": + dimension_type = kwargs.get("dimension_type", self.position_type) + if "shape_type" in kwargs: + shape_type = kwargs["shape_type"] + dimension_type = shape_type[0] + kwargs["shape_type"] = shape_type[1:] + op = kwargs.get("op", self.op) + return DenseLevelFType(self.lvl_t.from_kwargs(**kwargs), dimension_type, op) # type: ignore[abstract] + + def to_kwargs(self): + return { + "dimension_type": self.position_type, + "op": self.op, + } | self.lvl_t.to_kwargs() + + def asm_unpack(self, ctx, var_n, val): + val_lvl = asm.GetAttr(val, asm.Literal("lvl")) + return self.lvl_t.asm_unpack(ctx, var_n, val_lvl) + + def get_fields_class(self, tns, buf_s, nind, pos, op): + return DenseLevelFields(tns, buf_s, nind, pos, op) + + def lower_declare(self, ctx, tns, init, op, shape): + return self.lvl_t.lower_declare(ctx, tns, init, op, shape) + + def lower_freeze(self, ctx, tns, op): + return self.lvl_t.lower_freeze(ctx, tns, op) + + def lower_thaw(self, ctx, tns, op): + return self.lvl_t.lower_thaw(ctx, tns, op) + + def lower_increment(self, ctx, obj, val): + raise NotImplementedError("DenseLevelFType does not support lower_increment.") + + def lower_unwrap(self, ctx, obj): + raise NotImplementedError("DenseLevelFType does not support lower_unwrap.") + + def unfurl(self, ctx, tns, ext, mode, proto): + def child_accessor(ctx, idx): + pos_2 = asm.Variable( + ctx.freshen(ctx.idx, f"_pos_{self.ndim - 1}"), self.position_type + ) + ctx.exec( + asm.Assign( + pos_2, + asm.Call( + asm.Literal(operator.add), + [ + tns.obj.pos, + asm.Call( + asm.Literal(operator.mul), + [ + asm.GetAttr(tns.obj.lvl, asm.Literal("stride")), + asm.Variable(ctx.idx.name, ctx.idx.type_), + ], + ), + ], + ), + ) + ) + return ntn.Stack( + self.lvl_t.get_fields_class( + asm.GetAttr(tns.obj.lvl, asm.Literal("lvl")), + tns.obj.buf_s, + tns.obj.nind + 1, + pos_2, + tns.obj.op, + ), + self.lvl_t, + ) + + return lplt.Lookup( + body=lambda ctx, idx: lplt.Leaf( + body=lambda ctx: child_accessor(ctx, idx), + ) + ) def dense(lvl, dimension_type=None): @@ -71,17 +190,29 @@ def dense(lvl, dimension_type=None): @dataclass class DenseLevel(Level): """ - A class representing the leaf level of Finch tensors. + A class representing dense level. """ _format: DenseLevelFType - lvl: Any - dimension: Any + lvl: Level + dimension: np.intp + pos: asm.AssemblyNode | None = None @property - def shape(self): + def shape(self) -> tuple: return (self.dimension, *self.lvl.shape) @property - def ftype(self): + def stride(self) -> np.integer: + stride = self.lvl.stride + if self.lvl.ndim == 0: + return stride + return self.lvl.shape[0] * stride + + @property + def ftype(self) -> DenseLevelFType: return self._format + + @property + def val(self) -> Any: + return self.lvl.val diff --git a/src/finchlite/tensor/level/element_level.py b/src/finchlite/tensor/level/element_level.py index d2d858a8..e15bb2f2 100644 --- a/src/finchlite/tensor/level/element_level.py +++ b/src/finchlite/tensor/level/element_level.py @@ -1,66 +1,129 @@ -from dataclasses import dataclass -from typing import Any +from dataclasses import asdict, dataclass +from typing import Any, NamedTuple import numpy as np +from ... 
import finch_assembly as asm from ...codegen import NumpyBufferFType from ...symbolic import FType, ftype from ..fiber_tensor import Level, LevelFType +class ElementLevelFields(NamedTuple): + lvl: asm.Variable + buf_s: NumpyBufferFType + nind: int + pos: asm.AssemblyNode + op: Any + + @dataclass(unsafe_hash=True) -class ElementLevelFType(LevelFType): - _fill_value: Any - _element_type: type | FType | None = None - _position_type: type | FType | None = None - _buffer_factory: Any = NumpyBufferFType - val_format: Any = None +class ElementLevelFType(LevelFType, asm.AssemblyStructFType): + fill_value: Any = None + element_type: type | FType | None = None + position_type: type | FType | None = None + buffer_factory: Any = NumpyBufferFType + buffer_type: Any = None + + @property + def struct_name(self): + return "ElementLevelFType" + + @property + def struct_fields(self): + return [ + ("val", self.buffer_type), + ] def __post_init__(self): - if self._element_type is None: - self._element_type = ftype(self._fill_value) - if self.val_format is None: - self.val_format = self._buffer_factory(self._element_type) - if self._position_type is None: - self._position_type = np.intp - self._element_type = self.val_format.element_type - self._fill_value = self._element_type(self._fill_value) - - def __call__(self, shape=()): + if self.element_type is None: + self.element_type = ftype(self.fill_value) + if self.buffer_type is None: + self.buffer_type = self.buffer_factory(self.element_type) + if self.position_type is None: + self.position_type = np.intp + self.element_type = self.buffer_type.element_type + self.fill_value = self.element_type(self.fill_value) + + def __call__(self, shape=(), val=None): """ Creates an instance of ElementLevel with the given ftype. Args: - fmt: The ftype to be used for the level. + shape: Should be always `()`, used for validation. + val: The value to store in the ElementLevel instance. Returns: An instance of ElementLevel. 
""" if len(shape) != 0: raise ValueError("ElementLevelFType must be called with an empty shape.") - return ElementLevel(self) + return ElementLevel(self, val) + + def __str__(self): + return f"ElementLevelFType(fv={self.fill_value})" @property def ndim(self): return 0 @property - def fill_value(self): - return self._fill_value + def lvl_t(self): + raise Exception("ElementLevel is a leaf level.") - @property - def element_type(self): - return self._element_type + def from_kwargs(self, **kwargs) -> "ElementLevelFType": + f_v = kwargs.get("fill_value", self.fill_value) + e_t = kwargs.get("element_type", self.element_type) + p_t = kwargs.get("position_type", self.position_type) + b_f = kwargs.get("buffer_factory", self.buffer_factory) + v_f = kwargs.get("buffer_type", self.buffer_type) + return ElementLevelFType(f_v, e_t, p_t, b_f, v_f) # type: ignore[abstract] - @property - def position_type(self): - return self._position_type + def to_kwargs(self): + return asdict(self) @property def shape_type(self): return () - @property - def buffer_factory(self): - return self._buffer_factory + def asm_unpack(self, ctx, var_n, val): + buf = asm.Variable(f"{var_n}_buf", self.buffer_type) + buf_e = asm.GetAttr(val, asm.Literal("val")) + ctx.exec(asm.Assign(buf, buf_e)) + buf_s = asm.Slot(f"{var_n}_buf_slot", self.buffer_type) + ctx.exec(asm.Unpack(buf_s, buf)) + return buf_s + + def get_fields_class(self, tns, buf_s, nind, pos, op): + return ElementLevelFields(tns, buf_s, nind, pos, op) + + def lower_declare(self, ctx, tns, init, op, shape): + i_var = asm.Variable("i", self.buffer_type.length_type) + body = asm.Store(tns, i_var, asm.Literal(init.val)) + ctx.exec(asm.ForLoop(i_var, asm.Literal(np.intp(0)), asm.Length(tns), body)) + + def lower_unwrap(self, ctx, obj): + return asm.Load(obj.buf_s, obj.pos) + + def lower_increment(self, ctx, obj, val): + lowered_pos = asm.Variable(obj.pos.name, obj.pos.type) + ctx.exec( + asm.Store( + obj.buf_s, + lowered_pos, + asm.Call( + asm.Literal(obj.op.val), + [asm.Load(obj.buf_s, lowered_pos), val], + ), + ) + ) + + def lower_freeze(self, ctx, tns, op): + return tns + + def lower_thaw(self, ctx, tns, op): + return tns + + def unfurl(self, ctx, tns, ext, mode, proto): + raise NotImplementedError("ElementLevelFType does not support unfurl.") def element( @@ -68,7 +131,7 @@ def element( element_type=None, position_type=None, buffer_factory=None, - val_format=None, + buffer_type=None, ): """ Creates an ElementLevelFType with the given parameters. @@ -78,16 +141,17 @@ def element( element_type: The type of elements stored in the level. position_type: The type of positions within the level. buffer_factory: The factory used to create buffers for the level. + buffer_type: Format of the value stored in the level. Returns: An instance of ElementLevelFType. 
""" return ElementLevelFType( - _fill_value=fill_value, - _element_type=element_type, - _position_type=position_type, - _buffer_factory=buffer_factory, - val_format=val_format, + fill_value=fill_value, + element_type=element_type, + position_type=position_type, + buffer_factory=buffer_factory, + buffer_type=buffer_type, ) @@ -98,16 +162,26 @@ class ElementLevel(Level): """ _format: ElementLevelFType - val: Any | None = None + _val: Any | None = None def __post_init__(self): - if self.val is None: - self.val = self._format.val_format(len=0, dtype=self._format.element_type()) + if self._val is None: + self._val = self._format.buffer_type( + len=0, dtype=self._format.element_type() + ) @property - def shape(self): + def shape(self) -> tuple: return () @property - def ftype(self): + def stride(self) -> np.integer: + return np.intp(1) # TODO: add dimension_type to element_level.py + + @property + def ftype(self) -> ElementLevelFType: return self._format + + @property + def val(self) -> Any: + return self._val diff --git a/src/finchlite/tensor/level/sparse_list_level.py b/src/finchlite/tensor/level/sparse_list_level.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/reference/test_matrix_multiplication_regression.txt b/tests/reference/test_matrix_multiplication_regression.txt index e66d63e1..1e718225 100644 --- a/tests/reference/test_matrix_multiplication_regression.txt +++ b/tests/reference/test_matrix_multiplication_regression.txt @@ -3,29 +3,29 @@ def matmul(C: BufferizedNDArray_f64_shape_i64_i64_strides_i64_i64, A: Bufferized n: ExtentFType(start=int64, end=int64) = dimension(B, 1) p: ExtentFType(start=int64, end=int64) = dimension(A, 1) A_: BufferizedNDArray_f64_shape_i64_i64_strides_i64_i64 = A - A__stride_0: int64 = getattr(getattr(A_, strides), element_0) - A__stride_1: int64 = getattr(getattr(A_, strides), element_1) - A__buf: np_buf_t(float64) = getattr(A_, buf) + A__stride_0: int64 = A_.strides.element_0 + A__stride_1: int64 = A_.strides.element_1 + A__buf: np_buf_t(float64) = A_.val A__buf_slot: np_buf_t(float64) = unpack(A__buf) B_: BufferizedNDArray_f64_shape_i64_i64_strides_i64_i64 = B - B__stride_0: int64 = getattr(getattr(B_, strides), element_0) - B__stride_1: int64 = getattr(getattr(B_, strides), element_1) - B__buf: np_buf_t(float64) = getattr(B_, buf) + B__stride_0: int64 = B_.strides.element_0 + B__stride_1: int64 = B_.strides.element_1 + B__buf: np_buf_t(float64) = B_.val B__buf_slot: np_buf_t(float64) = unpack(B__buf) C_: BufferizedNDArray_f64_shape_i64_i64_strides_i64_i64 = C - C__stride_0: int64 = getattr(getattr(C_, strides), element_0) - C__stride_1: int64 = getattr(getattr(C_, strides), element_1) - C__buf: np_buf_t(float64) = getattr(C_, buf) + C__stride_0: int64 = C_.strides.element_0 + C__stride_1: int64 = C_.strides.element_1 + C__buf: np_buf_t(float64) = C_.val C__buf_slot: np_buf_t(float64) = unpack(C__buf) for i in range(0, length(slot(C__buf_slot, np_buf_t(float64)))): store(slot(C__buf_slot, np_buf_t(float64)), i, 0.0) - for i in range(getattr(m, start), getattr(m, end)): + for i in range(m.start, m.end): i__pos: int64 = add(0, mul(A__stride_0, i)) i__pos_2: int64 = add(0, mul(C__stride_0, i)) - for k in range(getattr(p, start), getattr(p, end)): + for k in range(p.start, p.end): k__pos: int64 = add(i__pos, mul(A__stride_1, k)) k__pos_2: int64 = add(0, mul(B__stride_0, k)) - for j in range(getattr(n, start), getattr(n, end)): + for j in range(n.start, n.end): j__pos: int64 = add(k__pos_2, mul(B__stride_1, j)) j__pos_2: int64 = add(i__pos_2, 
mul(C__stride_1, j)) a_ik: float64 = load(slot(A__buf_slot, np_buf_t(float64)), k__pos) diff --git a/tests/test_codegen.py b/tests/test_codegen.py index 202c406e..14d79760 100644 --- a/tests/test_codegen.py +++ b/tests/test_codegen.py @@ -13,7 +13,7 @@ import finchlite import finchlite.finch_assembly as asm -from finchlite import ftype +from finchlite import dense, element, fiber_tensor, ftype from finchlite.codegen import ( CCompiler, CGenerator, @@ -30,6 +30,7 @@ deserialize_from_numba, serialize_to_numba, ) +from finchlite.compile import BufferizedNDArray def test_add_function(): @@ -909,15 +910,29 @@ def test_np_numba_serialization(value, np_type): assert deserialize_from_numba(np_type, constructed, serialized) is None -def test_e2e_numba(): +@pytest.mark.parametrize( + "fmt_fn", + [ + lambda x: BufferizedNDArray, + lambda dtype: fiber_tensor( + dense(dense(element(dtype(0), dtype, np.intp, NumpyBufferFType))) + ), + ], +) +@pytest.mark.parametrize("dtype", [np.float64, np.int64]) +def test_e2e_numba(fmt_fn, dtype): ctx = finchlite.get_default_scheduler() # TODO: as fixture finchlite.set_default_scheduler(mode=finchlite.Mode.COMPILE_NUMBA) - a = np.array([[2, 0, 3], [1, 3, -1], [1, 1, 8]], dtype=np.float64) - b = np.array([[4, 1, 9], [2, 2, 4], [4, 4, -5]], dtype=np.float64) + a = np.array([[2, 0, 3], [1, 3, -1], [1, 1, 8]], dtype=dtype) + b = np.array([[4, 1, 9], [2, 2, 4], [4, 4, -5]], dtype=dtype) + + fmt = fmt_fn(dtype) + aa = fmt(val=a) + bb = fmt(val=b) - wa = finchlite.defer(a) - wb = finchlite.defer(b) + wa = finchlite.defer(aa) + wb = finchlite.defer(bb) plan = finchlite.matmul(wa, wb) result = finchlite.compute(plan) diff --git a/tests/test_logic_compiler.py b/tests/test_logic_compiler.py index 03b1be96..81c574fb 100644 --- a/tests/test_logic_compiler.py +++ b/tests/test_logic_compiler.py @@ -118,9 +118,9 @@ def test_logic_compiler(): ) bufferized_ndarray_ftype = BufferizedNDArrayFType( - buf_t=NumpyBufferFType(np.dtype(int)), + buffer_type=NumpyBufferFType(np.dtype(int)), ndim=np.intp(2), - strides_t=TupleFType.from_tuple((np.intp, np.intp)), + dimension_type=TupleFType.from_tuple((np.intp, np.intp)), ) expected_program = Module( @@ -314,13 +314,13 @@ def test_logic_compiler(): ) ) - program, tables = LogicCompiler()(plan) + program, table_vars, tables = LogicCompiler()(plan) assert program == expected_program mod = NotationInterpreter()(program) - args = provision_tensors(program, tables) + args = provision_tensors(program, table_vars, tables) result = mod.func(*args) expected = np.matmul(args[0].to_numpy(), args[1].to_numpy(), dtype=float) diff --git a/tests/test_notation_interpreter.py b/tests/test_notation_interpreter.py index 2fb0fbf0..a9b6a291 100644 --- a/tests/test_notation_interpreter.py +++ b/tests/test_notation_interpreter.py @@ -92,11 +92,11 @@ def test_matrix_multiplication(a, b): i, m, ntn.Loop( - j, - n, + k, + p, ntn.Loop( - k, - p, + j, + n, ntn.Block( ( ntn.Assign( diff --git a/tests/test_regressions.py b/tests/test_regressions.py index ad61291e..c66ce3d0 100644 --- a/tests/test_regressions.py +++ b/tests/test_regressions.py @@ -106,5 +106,5 @@ def test_tree_regression(file_regression): ), ) ) - program, tables = LogicCompiler()(plan) + program, table_vars, tables = LogicCompiler()(plan) file_regression.check(str(program), extension=".txt") diff --git a/tests/test_tensor.py b/tests/test_tensor.py index 05d89ead..477a01bf 100644 --- a/tests/test_tensor.py +++ b/tests/test_tensor.py @@ -4,14 +4,18 @@ DenseLevelFType, ElementLevelFType, FiberTensorFType, + 
NumpyBuffer, NumpyBufferFType, + dense, + element, + fiber_tensor, ) def test_fiber_tensor_attributes(): fmt = FiberTensorFType(DenseLevelFType(ElementLevelFType(0.0))) shape = (3,) - a = fmt(shape) + a = fmt(shape=shape) # Check shape attribute assert a.shape == shape @@ -33,3 +37,11 @@ def test_fiber_tensor_attributes(): # Check buffer_format exists assert a.buffer_factory == NumpyBufferFType + + +def test_fiber_tensor(): + fmt = fiber_tensor( + dense(dense(element(np.int64(0), np.int64, np.intp, NumpyBufferFType))) + ) + + fmt(shape=(3, 4), val=NumpyBuffer(np.arange(12)))
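
A usage sketch based on the new parametrized end-to-end test in tests/test_codegen.py; the top-level import of NumpyBufferFType from finchlite is an assumption inferred from the test imports above.

    import numpy as np
    import finchlite
    from finchlite import NumpyBufferFType, dense, element, fiber_tensor

    # Route execution through the Numba backend, as the new test does.
    finchlite.set_default_scheduler(mode=finchlite.Mode.COMPILE_NUMBA)

    # Dense-of-dense fiber tensor format over float64 elements, built from the
    # newly exported level constructors.
    fmt = fiber_tensor(
        dense(dense(element(np.float64(0), np.float64, np.intp, NumpyBufferFType)))
    )

    a = np.array([[2, 0, 3], [1, 3, -1], [1, 1, 8]], dtype=np.float64)
    b = np.array([[4, 1, 9], [2, 2, 4], [4, 4, -5]], dtype=np.float64)

    # FiberTensorFType's keyword-only constructor wraps a NumPy array via val=.
    wa = finchlite.defer(fmt(val=a))
    wb = finchlite.defer(fmt(val=b))

    result = finchlite.compute(finchlite.matmul(wa, wb))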