1414 return_type ,
1515)
1616from ..codegen import NumpyBufferFType
17- from ..compile import BufferizedNDArrayFType , ExtentFType , dimension
17+ from ..compile import ExtentFType , dimension
1818from ..finch_assembly import TupleFType
1919from ..finch_logic import (
2020 Aggregate ,
@@ -204,7 +204,7 @@ def __call__(
204204 return ntn .Assign (
205205 ntn .Variable (
206206 name ,
207- BufferizedNDArrayFType (
207+ type ( val ) (
208208 NumpyBufferFType (val .dtype ),
209209 val .ndim ,
210210 TupleFType .from_tuple (val .shape_type ),
@@ -484,12 +484,12 @@ def find_suitable_rep(root, table_vars) -> TensorFType:
484484 )
485485 )
486486
487- return BufferizedNDArrayFType (
488- buf_t = NumpyBufferFType ( dtype ),
489- ndim = np . intp ( len ( result_fields )),
490- strides_t = TupleFType . from_tuple (
491- tuple ( field_type_map [ f ] for f in result_fields )
492- ),
487+ # TODO: infer result rep from args
488+ result_rep = type ( args_suitable_reps_fields [ 0 ][ 0 ])
489+ return result_rep (
490+ NumpyBufferFType ( dtype ),
491+ np . intp ( len ( result_fields )),
492+ TupleFType . from_tuple ( tuple ( field_type_map [ f ] for f in result_fields ) ),
493493 )
494494 case Aggregate (Literal (op ), init , arg , idxs ):
495495 init_suitable_rep = find_suitable_rep (init , table_vars )
@@ -504,10 +504,8 @@ def find_suitable_rep(root, table_vars) -> TensorFType:
504504 for f , st in zip (arg .fields , arg_suitable_rep .shape_type , strict = True )
505505 if f not in idxs
506506 )
507- return BufferizedNDArrayFType (
508- buf_t = buf_t ,
509- ndim = np .intp (len (strides_t )),
510- strides_t = TupleFType .from_tuple (strides_t ),
507+ return type (arg_suitable_rep )(
508+ buf_t , np .intp (len (strides_t )), TupleFType .from_tuple (strides_t )
511509 )
512510 case LogicTree () as tree :
513511 for child in tree .children :
0 commit comments