Skip to content

Commit 9284e55

Browse files
committed
Fix multilinear (and visco)
1 parent be4b5e9 commit 9284e55

File tree

2 files changed

+25
-12
lines changed

2 files changed

+25
-12
lines changed

tensorforge/backend/instructions/builders/multilinear_builder.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,8 @@ def _alloc_register_array(self):
232232

233233
def _get_target_symbol(self, prev=False):
234234
dest_symbol = self._scopes.get_symbol(self._dest_obj.tensor)
235+
if dest_symbol is None:
236+
return None
235237
if dest_symbol.name in self._deferred_stores:
236238
dest_registers,_,_ = self._deferred_stores[dest_symbol.name]
237239
return dest_registers
@@ -253,6 +255,7 @@ def _make_compute(self):
253255
prefer_align=False,#self._descr.prefer_align,
254256
num_threads=self._num_threads,
255257
prev=self._get_target_symbol(True) if self._add else None,
258+
next=self._get_target_symbol(True),
256259
productOperation=MulOperator(),
257260
sumOperation=AddOperator()))
258261

tensorforge/backend/instructions/compute/multilinear.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,16 @@
1515
from .primitives import nvidia as nvidia
1616
from .primitives import amd as amd
1717

18+
from copy import copy
19+
1820
class MultilinearInstruction(ComputeInstruction):
1921
def __init__(self,
2022
context: Context,
2123
dest: Symbol,
2224
ops: List[SymbolView],
2325
target: List[List[int]],
2426
prev: Union[None, Symbol],
27+
next: Union[None, Symbol],
2528
productOperation: ReductionOperator,
2629
sumOperation: ReductionOperator,
2730
prefer_align: bool,
@@ -40,6 +43,7 @@ def __init__(self,
4043
self._num_threads = num_threads
4144
self._blockcount = blockcount
4245
self._prev = prev
46+
self._next = next
4347

4448
assert num_threads % blockcount == 0
4549

@@ -120,12 +124,18 @@ def _analyze(self):
120124
# i.e.: what can be loaded in early/late, do
121125

122126
# TODO: handle offsets
127+
128+
self._idest = copy(self._dest)
129+
self._idest.data_view = DataView(shape = [u - l for l,u in self._ns], permute=[i for i in range(targetrank)])
130+
self._idest.data_view._bbox._lower = [l for l,_ in self._ns]
131+
self._idest.data_view._bbox._upper = [u for _,u in self._ns]
132+
123133
if self._prev is not None:
124134
self._dest.data_view = self._prev.data_view
135+
if self._next is not None:
136+
self._dest.data_view = self._next.data_view
125137
if self._dest.data_view is None:
126-
self._dest.data_view = DataView(shape = [u - l for l,u in self._ns], permute=[i for i in range(targetrank)])
127-
self._dest.data_view._bbox._lower = [l for l,_ in self._ns]
128-
self._dest.data_view._bbox._upper = [u for _,u in self._ns]
138+
self._dest.data_view = self._idest.data_view
129139

130140
self._lead_dims = [0]#[t for t in self._target[0] if t >= 0]
131141

@@ -201,9 +211,9 @@ def nonlead_writer(varlist):
201211
if len(self._ops) > 0 and len(prod) == len(self._ops):
202212
for p in prod:
203213
writer(p)
204-
self._dest.load(writer, self._context, 'value', [varlist[loopmap[f'n{i}']] for i,_ in enumerate(self._ns)], False)
214+
self._idest.load(writer, self._context, 'value', [varlist[loopmap[f'n{i}']] for i,_ in enumerate(self._ns)], False)
205215
writer(f'{self._fp_as_str} newvalue = {self._sumOperation.format("value", f"prod{len(self._ops)-1}")};')
206-
self._dest.store(writer, self._context, 'newvalue', [varlist[loopmap[f'n{i}']] for i,_ in enumerate(self._ns)], False)
216+
self._idest.store(writer, self._context, 'newvalue', [varlist[loopmap[f'n{i}']] for i,_ in enumerate(self._ns)], False)
207217

208218
write_loops(self._context, writer, loopstack, nonlead_writer)
209219

@@ -280,7 +290,7 @@ def unwindOp(i, j, k, opid, full):
280290
return idx
281291

282292
def C(writer, var, i, j):
283-
self._dest.store(writer, self._context, var, unwindOp(i, j, 0, None, False), False)
293+
self._idest.store(writer, self._context, var, unwindOp(i, j, 0, None, False), False)
284294

285295
if self._ops[1].symbol.obj and (not self._ops[1].symbol.obj.is_dense() or self._ops[1].symbol.data_view.shape[0] < 16):
286296
def sparse(k, j):
@@ -306,9 +316,9 @@ def A(writer, var, i, k):
306316
return res
307317

308318
if self._context.get_vm().get_hw_descr().vendor == 'amd':
309-
amd.matmul(writer, C, A, B, M, N, K, kx, self._num_threads, self._dest.datatype, sparse, self._context)
319+
amd.matmul(writer, C, A, B, M, N, K, kx, self._num_threads, self._idest.datatype, sparse, self._context)
310320
elif self._context.get_vm().get_hw_descr().vendor == 'nvidia':
311-
return nvidia.matmul(writer, C, A, B, Mx, N, K, kx, self._num_threads, self._dest.datatype, sparse, self._context, 'tempShrMem', self.temp_shmem())
321+
return nvidia.matmul(writer, C, A, B, Mx, N, K, kx, self._num_threads, self._idest.datatype, sparse, self._context, 'tempShrMem', self.temp_shmem())
312322
return True
313323
return False
314324

@@ -422,9 +432,9 @@ def nonlead_writer(varlist):
422432
if prodc == len(self._ops):
423433
for prod in prods:
424434
writer(prod)
425-
self._dest.load(writer, self._context, 'value', [varlist[loopmap[f'n{i}']] for i,_ in enumerate(self._ns)], False)
435+
self._idest.load(writer, self._context, 'value', [varlist[loopmap[f'n{i}']] for i,_ in enumerate(self._ns)], False)
426436
writer(f'{self._fp_as_str} newvalue = {self._sumOperation.format("value", f"prod{prodc - 1}")};')
427-
self._dest.store(writer, self._context, 'newvalue', [varlist[loopmap[f'n{i}']] for i,_ in enumerate(self._ns)], False)
437+
self._idest.store(writer, self._context, 'newvalue', [varlist[loopmap[f'n{i}']] for i,_ in enumerate(self._ns)], False)
428438

429439
write_loops(self._context, writer, loopstack, nonlead_writer)
430440

@@ -517,13 +527,13 @@ def _leading_dim(self, writer: Writer):
517527
loop.__enter__()
518528
loopstack += [loop]
519529

520-
self._dest.load(writer, self._context, 'value', [self._vm.get_lexic().thread_idx_x] + [f'n{i+1}' for i,_ in enumerate(self._ns[1:])], False)
530+
self._idest.load(writer, self._context, 'value', [self._vm.get_lexic().thread_idx_x] + [f'n{i+1}' for i,_ in enumerate(self._ns[1:])], False)
521531
#writer(f'auto* shmAddr = &{self._shr_mem.name}[{self._shr_mem_offset}];')
522532
self._reduction(writer)
523533
write(f'value = tensorforge::reduction<tensorforge::ReductionOperation<{self._fp_as_str}, tensorforge::Op::Sum>, {self._num_threads}, 1, {self._fp_as_str}>(value);')
524534
# self._butterfly_reduction_loop(writer, max_array_length = 32, amd = False)
525535
#writer(f'{self._fp_as_str} newvalue = shmAddr[{sublane_address}];')
526-
self._dest.store(writer, self._context, 'value', [self._vm.get_lexic().thread_idx_x] + [f'n{i+1}' for i,_ in enumerate(self._ns[1:])], False)
536+
self._idest.store(writer, self._context, 'value', [self._vm.get_lexic().thread_idx_x] + [f'n{i+1}' for i,_ in enumerate(self._ns[1:])], False)
527537

528538
for loop in loopstack[::-1]:
529539
loop.__exit__(None, None, None)

0 commit comments

Comments (0)