1515from .primitives import nvidia as nvidia
1616from .primitives import amd as amd
1717
18+ from copy import copy
19+
1820class MultilinearInstruction (ComputeInstruction ):
1921 def __init__ (self ,
2022 context : Context ,
2123 dest : Symbol ,
2224 ops : List [SymbolView ],
2325 target : List [List [int ]],
2426 prev : Union [None , Symbol ],
27+ next : Union [None , Symbol ],
2528 productOperation : ReductionOperator ,
2629 sumOperation : ReductionOperator ,
2730 prefer_align : bool ,
@@ -40,6 +43,7 @@ def __init__(self,
4043 self ._num_threads = num_threads
4144 self ._blockcount = blockcount
4245 self ._prev = prev
46+ self ._next = next
4347
4448 assert num_threads % blockcount == 0
4549
@@ -120,12 +124,18 @@ def _analyze(self):
120124 # i.e.: what can be loaded in early/late, do
121125
122126 # TODO: handle offsets
127+
128+ self ._idest = copy (self ._dest )
129+ self ._idest .data_view = DataView (shape = [u - l for l ,u in self ._ns ], permute = [i for i in range (targetrank )])
130+ self ._idest .data_view ._bbox ._lower = [l for l ,_ in self ._ns ]
131+ self ._idest .data_view ._bbox ._upper = [u for _ ,u in self ._ns ]
132+
123133 if self ._prev is not None :
124134 self ._dest .data_view = self ._prev .data_view
135+ if self ._next is not None :
136+ self ._dest .data_view = self ._next .data_view
125137 if self ._dest .data_view is None :
126- self ._dest .data_view = DataView (shape = [u - l for l ,u in self ._ns ], permute = [i for i in range (targetrank )])
127- self ._dest .data_view ._bbox ._lower = [l for l ,_ in self ._ns ]
128- self ._dest .data_view ._bbox ._upper = [u for _ ,u in self ._ns ]
138+ self ._dest .data_view = self ._idest .data_view
129139
130140 self ._lead_dims = [0 ]#[t for t in self._target[0] if t >= 0]
131141
@@ -201,9 +211,9 @@ def nonlead_writer(varlist):
201211 if len (self ._ops ) > 0 and len (prod ) == len (self ._ops ):
202212 for p in prod :
203213 writer (p )
204- self ._dest .load (writer , self ._context , 'value' , [varlist [loopmap [f'n{ i } ' ]] for i ,_ in enumerate (self ._ns )], False )
214+ self ._idest .load (writer , self ._context , 'value' , [varlist [loopmap [f'n{ i } ' ]] for i ,_ in enumerate (self ._ns )], False )
205215 writer (f'{ self ._fp_as_str } newvalue = { self ._sumOperation .format ("value" , f"prod{ len (self ._ops )- 1 } " )} ;' )
206- self ._dest .store (writer , self ._context , 'newvalue' , [varlist [loopmap [f'n{ i } ' ]] for i ,_ in enumerate (self ._ns )], False )
216+ self ._idest .store (writer , self ._context , 'newvalue' , [varlist [loopmap [f'n{ i } ' ]] for i ,_ in enumerate (self ._ns )], False )
207217
208218 write_loops (self ._context , writer , loopstack , nonlead_writer )
209219
@@ -280,7 +290,7 @@ def unwindOp(i, j, k, opid, full):
280290 return idx
281291
282292 def C (writer , var , i , j ):
283- self ._dest .store (writer , self ._context , var , unwindOp (i , j , 0 , None , False ), False )
293+ self ._idest .store (writer , self ._context , var , unwindOp (i , j , 0 , None , False ), False )
284294
285295 if self ._ops [1 ].symbol .obj and (not self ._ops [1 ].symbol .obj .is_dense () or self ._ops [1 ].symbol .data_view .shape [0 ] < 16 ):
286296 def sparse (k , j ):
@@ -306,9 +316,9 @@ def A(writer, var, i, k):
306316 return res
307317
308318 if self ._context .get_vm ().get_hw_descr ().vendor == 'amd' :
309- amd .matmul (writer , C , A , B , M , N , K , kx , self ._num_threads , self ._dest .datatype , sparse , self ._context )
319+ amd .matmul (writer , C , A , B , M , N , K , kx , self ._num_threads , self ._idest .datatype , sparse , self ._context )
310320 elif self ._context .get_vm ().get_hw_descr ().vendor == 'nvidia' :
311- return nvidia .matmul (writer , C , A , B , Mx , N , K , kx , self ._num_threads , self ._dest .datatype , sparse , self ._context , 'tempShrMem' , self .temp_shmem ())
321+ return nvidia .matmul (writer , C , A , B , Mx , N , K , kx , self ._num_threads , self ._idest .datatype , sparse , self ._context , 'tempShrMem' , self .temp_shmem ())
312322 return True
313323 return False
314324
@@ -422,9 +432,9 @@ def nonlead_writer(varlist):
422432 if prodc == len (self ._ops ):
423433 for prod in prods :
424434 writer (prod )
425- self ._dest .load (writer , self ._context , 'value' , [varlist [loopmap [f'n{ i } ' ]] for i ,_ in enumerate (self ._ns )], False )
435+ self ._idest .load (writer , self ._context , 'value' , [varlist [loopmap [f'n{ i } ' ]] for i ,_ in enumerate (self ._ns )], False )
426436 writer (f'{ self ._fp_as_str } newvalue = { self ._sumOperation .format ("value" , f"prod{ prodc - 1 } " )} ;' )
427- self ._dest .store (writer , self ._context , 'newvalue' , [varlist [loopmap [f'n{ i } ' ]] for i ,_ in enumerate (self ._ns )], False )
437+ self ._idest .store (writer , self ._context , 'newvalue' , [varlist [loopmap [f'n{ i } ' ]] for i ,_ in enumerate (self ._ns )], False )
428438
429439 write_loops (self ._context , writer , loopstack , nonlead_writer )
430440
@@ -517,13 +527,13 @@ def _leading_dim(self, writer: Writer):
517527 loop .__enter__ ()
518528 loopstack += [loop ]
519529
520- self ._dest .load (writer , self ._context , 'value' , [self ._vm .get_lexic ().thread_idx_x ] + [f'n{ i + 1 } ' for i ,_ in enumerate (self ._ns [1 :])], False )
530+ self ._idest .load (writer , self ._context , 'value' , [self ._vm .get_lexic ().thread_idx_x ] + [f'n{ i + 1 } ' for i ,_ in enumerate (self ._ns [1 :])], False )
521531 #writer(f'auto* shmAddr = &{self._shr_mem.name}[{self._shr_mem_offset}];')
522532 self ._reduction (writer )
523533 write (f'value = tensorforge::reduction<tensorforge::ReductionOperation<{ self ._fp_as_str } , tensorforge::Op::Sum>, { self ._num_threads } , 1, { self ._fp_as_str } >(value);' )
524534 # self._butterfly_reduction_loop(writer, max_array_length = 32, amd = False)
525535 #writer(f'{self._fp_as_str} newvalue = shmAddr[{sublane_address}];')
526- self ._dest .store (writer , self ._context , 'value' , [self ._vm .get_lexic ().thread_idx_x ] + [f'n{ i + 1 } ' for i ,_ in enumerate (self ._ns [1 :])], False )
536+ self ._idest .store (writer , self ._context , 'value' , [self ._vm .get_lexic ().thread_idx_x ] + [f'n{ i + 1 } ' for i ,_ in enumerate (self ._ns [1 :])], False )
527537
528538 for loop in loopstack [::- 1 ]:
529539 loop .__exit__ (None , None , None )
0 commit comments