77from tensorforge .common .matrix .spp import FullSPP , BoundingBoxSPP , ListSPP
88from tensorforge .common .matrix .boundingbox import BoundingBox as BBox
99from tensorforge .generators .generator import Generator as TensorForgeGenerator
10- from tensorforge .generators .descriptions import MultilinearDescr , ElementwiseDescr , GridBarrierDescr , GridFenceDescr
10+ from tensorforge .generators .descriptions import MultilinearDescr , ElementwiseDescr , GridBarrierDescr , GridFenceDescr , RegionDescription
1111
1212from tensorforge .ir .data .variable import TensorView , TensorAlloc
1313from tensorforge .ir .data .variable import TensorData
@@ -27,6 +27,9 @@ def __init__(self, arch):
2727 self ._ir_list = []
2828 self ._tensor_list = {}
2929
30+ # TODO: maybe remove again
31+ self ._prefix = ""
32+
3033 def add_operation (self , dest , ops , target , permute , add ):
3134 self ._cache_matrices (dest , ops , target , permute )
3235 can_be_aligned = self ._can_be_aligned (dest , ops , target , permute )
@@ -96,9 +99,9 @@ def get_tensor(self, op, can_be_aligned, dims):
9699 if isinstance (op , (float , int )):
97100 return SubTensor (tensor = Tensor ([], Addressing .SCALAR , data = [op ]))
98101 elif self .is_scalar (op ):
99- return SubTensor (self ._cache [op .name ()])
102+ return SubTensor (self ._cache [f' { self . _prefix } { op .name ()} ' ])
100103 else :
101- tensor = self ._cache [op .name ]
104+ tensor = self ._cache [f' { self . _prefix } { op .name } ' ]
102105 currentPreShape = BBox ([s for s , _ in op .eqspp .nnzbounds ()], [e + 1 for _ , e in op .eqspp .nnzbounds ()])
103106
104107 tml = op .memoryLayout
@@ -136,10 +139,10 @@ def assigner(pretensor):
136139 if self .is_scalar (pretensor ):
137140 self .make_tensor (pretensor , False , None )
138141 indicesIndexed [pretensor .name ()] = []
139- subTensor = SubTensor (self ._cache [pretensor .name ()], BBox ([], []))
142+ subTensor = SubTensor (self ._cache [f' { self . _prefix } { pretensor .name ()} ' ], BBox ([], []))
140143 else :
141144 bbox = BBox ([s for s , _ in pretensor .eqspp ().nnzbounds ()], [e + 1 for _ , e in pretensor .eqspp ().nnzbounds ()])
142- subTensor = SubTensor (self ._cache [pretensor .name ()], bbox )
145+ subTensor = SubTensor (self ._cache [f' { self . _prefix } { pretensor .name ()} ' ], bbox )
143146 return subTensor , indicesIndexed [pretensor .name ()]
144147
145148 for statement in statements :
@@ -203,13 +206,17 @@ def make_tensor(self, op, can_be_aligned, dims):
203206 entry = self ._get_tensorforge_matrix (op )
204207 entry_name = op .name
205208
209+ entry_name = f'{ self ._prefix } { entry_name } '
210+
206211 if not (entry_name in self ._cache and entry .is_same (self ._cache [entry_name ])):
207212 self ._cache [entry_name ] = entry
208213
209214 def tensor_ref (self , d ):
210215 name = d ['name' ]
211216 eqspp = d ['spp' ]
212217
218+ name = f'{ self ._prefix } { name } '
219+
213220 assert (name in self ._cache )
214221
215222 return SubTensor (self ._cache [name ], self ._cache [name ].bbox )
@@ -226,6 +233,8 @@ def tensor_ref_new(self, d):
226233
227234 def add_tensor (self , d ):
228235 name = d ['name' ]
236+ name = f'{ self ._prefix } { name } '
237+
229238 datatype = Datatype .ytt2enum (d ['datatype' ])
230239
231240 datatype_new = BaseDatatype .ytt2enum (d ['datatype' ])
@@ -276,16 +285,17 @@ def _cache_matrices(self, dest, ops, target, permute):
276285
277286 if dest .is_temporary : # (dest is never a scalar---for the time being)
278287 self .make_tensor (dest , can_be_aligned , [i for i in range (len (dest .indices ))])
279- self ._tmp_matrices [dest .name ] = self ._cache [dest .name ]
288+ self ._tmp_matrices [f' { self . _prefix } { dest .name } ' ] = self ._cache [f' { self . _prefix } { dest .name } ' ]
280289 else :
281290 self .make_tensor (dest , can_be_aligned , [i for i in range (len (dest .indices ))])
282291
283292
284293
285294 def _add_scalar (self , scalar ):
286- tensor = Tensor ([], Addressing .SCALAR , alias = scalar .name (), datatype = self ._datatype (scalar .datatype ))
287- self ._tmp_matrices [scalar .name ()] = tensor # SubTensor(tensor, tensor.bbox)
288- return self ._tmp_matrices [scalar .name ()]
295+ name = f'{ self ._prefix } { scalar .name ()} '
296+ tensor = Tensor ([], Addressing .SCALAR , alias = name , datatype = self ._datatype (scalar .datatype ))
297+ self ._tmp_matrices [name ] = tensor # SubTensor(tensor, tensor.bbox)
298+ return self ._tmp_matrices [name ]
289299
290300 def deduce_addresing (self , term ):
291301 if term .is_compute_constant :
@@ -323,7 +333,7 @@ def _get_tensorforge_matrix(self, tensor):
323333 return yi .gen_matrix (shape ,
324334 bboxrange ,
325335 addressing = addr_mode ,
326- name = tensor .name ,
336+ name = f' { self . _prefix } { tensor .name } ' ,
327337 is_tmp = tensor .is_temporary ,
328338 permute = None ,
329339 pattern = pattern ,
@@ -345,28 +355,35 @@ def _gen_call_site(self, generator):
345355 if matrix .is_tmp or matrix .addressing == Addressing .NONE :
346356 offset_name_map [name ] = '0'
347357 else :
348- offset_name_map [name ] = f'extraOffset_{ name } '
358+ parts = name .split ('.' )
359+ assert len (parts ) <= 2
360+ varname = f'extraOffset_{ parts [- 1 ]} '
361+ if len (parts ) == 2 :
362+ offset_name_map [name ] = f'{ parts [0 ]} .{ varname } '
363+ else :
364+ offset_name_map [name ] = varname
349365
350366 return generator .generate_call_site (mat_name_map ,
351- offset_name_map ,
352- 'numElements' ,
353- 'flags' ,
354- 'streamPtr' )
367+ offset_name_map )
355368
356369 def _append_operation (self , op ):
357370 if isinstance (op , (float , int )):
358371 return Tensor ([], Addressing .SCALAR , data = op )
359372 elif self .is_scalar (op ):
360- return self ._cache [op .name ()]
373+ return self ._cache [f' { self . _prefix } { op .name ()} ' ]
361374 else :
362- return self ._cache [op .name ]
375+ return self ._cache [f' { self . _prefix } { op .name } ' ]
363376
364377 def switch_region (self , barrier ):
365378 if barrier :
366379 self ._descr_list += [GridBarrierDescr ()]
367380 else :
368381 self ._descr_list += [GridFenceDescr ()]
369382
383+ def set_region_name (self , name ):
384+ self ._prefix = f"{ name } ."
385+ self ._descr_list += [RegionDescription (name )]
386+
370387class TensorForgeWriter :
371388 def __init__ (self , tensorforge_generator , headers ):
372389 self ._headers = list (headers ) + list (tensorforge_generator .get_helper_headers ())
@@ -410,6 +427,9 @@ def region_switch(self, barrier):
410427 self .generator .switch_region (barrier )
411428 return 0
412429
430+ def set_region_name (self , name ):
431+ self .generator .set_region_name (name )
432+
413433 def add_operation (self , description ):
414434 return self .generator .add_operation_new (description )
415435
0 commit comments