11from abc import ABC , abstractmethod
22from dataclasses import dataclass
3- from typing import Generic , TypeVar
3+ from typing import Generic , NamedTuple , TypeVar
44
5- from finchlite .algebra import Tensor , TensorFType
6- from finchlite .symbolic import FType , FTyped
5+ import numpy as np
6+
7+ from .. import finch_assembly as asm
8+ from .. import finch_notation as ntn
9+ from ..algebra import Tensor , register_property
10+ from ..codegen .numpy_buffer import NumpyBuffer
11+ from ..compile .lower import FinchTensorFType
12+ from ..symbolic import FType , FTyped
713
814
915class LevelFType (FType , ABC ):
@@ -161,8 +167,14 @@ def buffer_factory(self):
161167 return self .lvl .buffer_factory
162168
163169
170+ class FiberTensorFields (NamedTuple ):
171+ stride : tuple [asm .Variable , ...]
172+ buf : asm .Variable
173+ buf_s : asm .Slot
174+
175+
164176@dataclass (unsafe_hash = True )
165- class FiberTensorFType (TensorFType ):
177+ class FiberTensorFType (FinchTensorFType ):
166178 """
167179 An abstract base class representing the ftype of a fiber tensor.
168180
@@ -177,11 +189,14 @@ def __post_init__(self):
177189 if self ._position_type is None :
178190 self ._position_type = self .lvl .position_type
179191
180- def __call__ (self , shape ):
192+ def __call__ (self , * , shape = None , val = None ):
181193 """
182194 Creates an instance of a FiberTensor with the given arguments.
183195 """
184- return FiberTensor (self .lvl (shape ), self .lvl .position_type (1 ))
196+ if shape is None :
197+ shape = val .shape
198+ val = NumpyBuffer (val .reshape [- 1 ])
199+ return FiberTensor (self .lvl (shape , val ), self .lvl .position_type (1 ))
185200
186201 @property
187202 def shape (self ):
@@ -209,24 +224,82 @@ def position_type(self):
209224
210225 @property
211226 def buffer_factory (self ):
212- """
213- Returns the ftype of the buffer used for the fibers.
214- This is typically a NumpyBufferFType or similar.
215- """
216227 return self .lvl .buffer_factory
217228
218-
219- def tensor (lvl : LevelFType , position_type : type | None = None ):
229+ def unfurl (self , ctx , tns , ext , mode , proto ):
230+ op = None
231+ if isinstance (mode , ntn .Update ):
232+ op = mode .op
233+ tns = ctx .resolve (tns ).obj
234+ acc_t = self .lvl .unfurl (
235+ ctx ,
236+ )
237+
238+ # acc_t = BufferizedNDArrayAccessorFType(self, 0, self.buf_t.length_type, op)
239+ # obj = BufferizedNDArrayAccessorFields(
240+ # tns, 0, asm.Literal(self.buf_t.length_type(0)), op
241+ # )
242+ return acc_t .unfurl (ctx , ntn .Stack (obj , acc_t ), ext , mode , proto )
243+
244+ def lower_freeze (self , ctx , tns , op ):
245+ raise NotImplementedError
246+
247+ def lower_thaw (self , ctx , tns , op ):
248+ raise NotImplementedError
249+
250+ def lower_unwrap (self , ctx , obj ):
251+ raise NotImplementedError
252+
253+ def lower_increment (self , ctx , obj , val ):
254+ raise NotImplementedError
255+
256+ def lower_declare (self , ctx , tns , init , op , shape ):
257+ i_var = asm .Variable ("i" , self .buffer_factory .length_type )
258+ body = asm .Store (
259+ tns .obj .buf_s ,
260+ i_var ,
261+ asm .Literal (init .val ),
262+ )
263+ ctx .exec (
264+ asm .ForLoop (i_var , asm .Literal (np .intp (0 )), asm .Length (tns .obj .buf_s ), body )
265+ )
266+ return
267+
268+ def asm_unpack (self , ctx , var_n , val ):
269+ """
270+ Unpack the into asm context.
271+ """
272+ stride = []
273+ shape_type = self .shape_type
274+ for i in range (self .ndim ):
275+ stride_i = asm .Variable (f"{ var_n } _stride_{ i } " , shape_type [i ])
276+ stride .append (stride_i )
277+ stride_e = asm .GetAttr (val , asm .Literal ("strides" ))
278+ stride_i_e = asm .GetAttr (stride_e , asm .Literal (f"element_{ i } " ))
279+ ctx .exec (asm .Assign (stride_i , stride_i_e ))
280+ buf = asm .Variable (f"{ var_n } _buf" , self .buffer_factory )
281+ buf_e = asm .GetAttr (val , asm .Literal ("buf" ))
282+ ctx .exec (asm .Assign (buf , buf_e ))
283+ buf_s = asm .Slot (f"{ var_n } _buf_slot" , self .buffer_factory )
284+ ctx .exec (asm .Unpack (buf_s , buf ))
285+
286+ return FiberTensorFields (tuple (stride ), buf , buf_s )
287+
288+
289+ def fiber_tensor (lvl : LevelFType , position_type : type | None = None ):
220290 """
221291 Creates a FiberTensorFType with the given level ftype and position type.
222292
223293 Args:
224294 lvl: The level ftype to be used for the tensor.
225- pos_type : The type of positions within the tensor. Defaults to None.
295+ position_type : The type of positions within the tensor. Defaults to None.
226296
227297 Returns:
228298 An instance of FiberTensorFType.
229299 """
230300 # mypy does not understand that dataclasses generate __hash__ and __eq__
231301 # https://github.com/python/mypy/issues/19799
232302 return FiberTensorFType (lvl , position_type ) # type: ignore[abstract]
303+
304+
305+ register_property (FiberTensor , "asarray" , "__attr__" , lambda x : x )
0 commit comments