@@ -1,9 +1,15 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Generic, TypeVar
+from typing import Generic, NamedTuple, TypeVar
 
-from finchlite.algebra import Tensor, TensorFType
-from finchlite.symbolic import FType, FTyped
+import numpy as np
+
+from .. import finch_assembly as asm
+from .. import finch_notation as ntn
+from ..algebra import Tensor, register_property
+from ..codegen.numpy_buffer import NumpyBuffer
+from ..compile.lower import FinchTensorFType
+from ..symbolic import FType, FTyped
 
 
 class LevelFType(FType, ABC):
@@ -68,7 +74,7 @@ class Level(FTyped, ABC):
 
     @property
     @abstractmethod
-    def shape(self):
+    def shape(self) -> tuple:
         """
         Shape of the fibers in the structure.
         """
@@ -161,8 +167,14 @@ def buffer_factory(self):
         return self.lvl.buffer_factory
 
 
+class FiberTensorFields(NamedTuple):
+    stride: tuple[asm.Variable, ...]
+    buf: asm.Variable
+    buf_s: asm.Slot
+
+
 @dataclass(unsafe_hash=True)
-class FiberTensorFType(TensorFType):
+class FiberTensorFType(FinchTensorFType):
     """
     An abstract base class representing the ftype of a fiber tensor.
 
@@ -177,11 +189,17 @@ def __post_init__(self):
         if self._position_type is None:
             self._position_type = self.lvl.position_type
 
-    def __call__(self, shape):
+    def __call__(self, *, shape=None, val=None):
         """
         Creates an instance of a FiberTensor with the given arguments.
         """
-        return FiberTensor(self.lvl(shape), self.lvl.position_type(1))
+        if shape is None:
+            shape = val.shape
+            val = NumpyBuffer(val.reshape(-1))
+        return FiberTensor(self.lvl(shape, val), self.lvl.position_type(1))
+
+    def __str__(self):
+        return f"FiberTensorFType({self.lvl})"
 
     @property
     def shape(self):
@@ -209,24 +227,82 @@ def position_type(self):
 
     @property
     def buffer_factory(self):
+        return self.lvl.buffer_factory
+
+    def unfurl(self, ctx, tns, ext, mode, proto):
+        op = None
+        if isinstance(mode, ntn.Update):
+            op = mode.op
+        tns = ctx.resolve(tns).obj
+        obj = self.lvl.get_fields_class(tns, 0, asm.Literal(self.position_type(0)), op)
+        return self.lvl.unfurl(ctx, ntn.Stack(obj, self.lvl), ext, mode, proto)
+
+    def lower_freeze(self, ctx, tns, op):
+        return tns
+
+    def lower_thaw(self, ctx, tns, op):
+        raise NotImplementedError
+
+    def lower_unwrap(self, ctx, obj):
+        raise NotImplementedError
+
+    def lower_increment(self, ctx, obj, val):
+        raise NotImplementedError
+
+    def lower_declare(self, ctx, tns, init, op, shape):
+        i_var = asm.Variable("i", self.buffer_factory.length_type)
+        body = asm.Store(
+            tns.obj.buf_s,
+            i_var,
+            asm.Literal(init.val),
+        )
+        ctx.exec(
+            asm.ForLoop(i_var, asm.Literal(np.intp(0)), asm.Length(tns.obj.buf_s), body)
+        )
+        return
+
+    def asm_unpack(self, ctx, var_n, val):
         """
-        Returns the ftype of the buffer used for the fibers.
-        This is typically a NumpyBufferFType or similar.
+        Unpack the tensor into the asm context.
         """
-        return self.lvl.buffer_factory
+        stride = []
+        shape_type = self.shape_type
+        for i in range(self.ndim):
+            stride_i = asm.Variable(f"{var_n}_stride_{i}", shape_type[i])
+            stride.append(stride_i)
+            stride_e = asm.GetAttr(val, asm.Literal("strides"))
+            stride_i_e = asm.GetAttr(stride_e, asm.Literal(f"element_{i}"))
+            ctx.exec(asm.Assign(stride_i, stride_i_e))
+        buf = asm.Variable(f"{var_n}_buf", self.buffer_factory)
+        buf_e = asm.GetAttr(val, asm.Literal("buf"))
+        ctx.exec(asm.Assign(buf, buf_e))
+        buf_s = asm.Slot(f"{var_n}_buf_slot", self.buffer_factory)
+        ctx.exec(asm.Unpack(buf_s, buf))
+
+        return FiberTensorFields(tuple(stride), buf, buf_s)
+
+    def asm_repack(self, ctx, lhs, obj):
+        """
+        Repack the buffer from the context.
+        """
+        ctx.exec(asm.Repack(obj.buf_s))
+        return
 
 
-def tensor(lvl: LevelFType, position_type: type | None = None):
+def fiber_tensor(lvl: LevelFType, position_type: type | None = None):
     """
     Creates a FiberTensorFType with the given level ftype and position type.
 
     Args:
         lvl: The level ftype to be used for the tensor.
-        pos_type: The type of positions within the tensor. Defaults to None.
+        position_type: The type of positions within the tensor. Defaults to None.
 
     Returns:
         An instance of FiberTensorFType.
     """
     # mypy does not understand that dataclasses generate __hash__ and __eq__
     # https://github.com/python/mypy/issues/19799
     return FiberTensorFType(lvl, position_type)  # type: ignore[abstract]
+
+
+register_property(FiberTensor, "asarray", "__attr__", lambda x: x)
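
For orientation, a minimal usage sketch of the constructor API this change introduces. `some_level_ftype` below is a hypothetical placeholder for whichever LevelFType the caller already has (it is not defined in this diff), and the sketch assumes levels accept `val=None` when only a shape is given; the `fiber_tensor` helper and the keyword-only `shape`/`val` arguments of `FiberTensorFType.__call__` are the parts taken from the change itself.

    import numpy as np

    # Hypothetical LevelFType obtained elsewhere in finchlite; not part of this diff.
    fmt = fiber_tensor(some_level_ftype)

    # Build a fiber tensor from an explicit shape (val stays None) ...
    a = fmt(shape=(4, 5))

    # ... or from an existing ndarray: __call__ reads the shape from the array
    # and flattens its values into a NumpyBuffer.
    b = fmt(val=np.zeros((4, 5)))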