Skip to content

Commit 19f933e

Browse files
committed
Fix more long-kernel bugs
1 parent bc4dcdc commit 19f933e

File tree

5 files changed

+39
-15
lines changed

5 files changed

+39
-15
lines changed

tensorforge/backend/instructions/builders/multilinear_builder.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def __init__(self,
4242
self._dest_regs = None
4343

4444
self._use_registers_always = self._context.get_vm().get_hw_descr().vendor in ['amd']
45-
self._preload_registers = False
45+
self._preload_registers = False #self._context.get_vm().get_hw_descr().vendor in ['amd']
4646
self._deferred_stores = {}
4747
self._temporaries = {}
4848

@@ -112,9 +112,10 @@ def _make_load_op(self, i):
112112
self._loaders_cache[self._mem_regions[i]] = load_op
113113
self._instructions.append(load_op)
114114
else:
115-
if self._preload_registers:
115+
if self._preload_registers and self._ops[i].symbol.obj.is_dense() and not (self._ops[i].symbol in self._loaders_cache.keys()):
116+
# only register-preload dense matrices for now
116117
self._mem_regions[i], load_op = self._make_loader_and_symbol_reg(self._ops[i].symbol, is_transpose=self._descr.permute[i])
117-
self._loaders_cache[self._mem_regions[i]] = load_op
118+
self._loaders_cache[self._ops[i].symbol] = load_op
118119
self._instructions.append(load_op)
119120
else:
120121
# Note: operand will reside in glb. mem for gemm operation
@@ -156,7 +157,7 @@ def _make_loader_and_symbol_reg(self, operand, is_transpose) -> Tuple[Symbol, Gl
156157
threads = self._num_threads
157158
lead_dim = [0] # [t for t in self._descr.target[0] if t >= 0]
158159

159-
for d, dim in enumerate(operand.bbox.sizes()):
160+
for d, dim in enumerate(operand.data_view._bbox.sizes()):
160161
if d not in lead_dim or threads == 0:
161162
regsize *= dim
162163
else:
@@ -251,7 +252,8 @@ def _make_store(self):
251252
# dest=dest_symbol,
252253
# shr_mem=self._shr_mem,
253254
# num_threads=self._num_threads))
254-
pass # see note below
255+
# see note below (but update to the new temp regs)
256+
self._deferred_stores[dest_symbol.name] = (self._temp_regs, dest_symbol)
255257
elif dest_symbol.stype == SymbolType.Global:
256258
if self._use_registers_always:
257259
self._deferred_stores[dest_symbol.name] = (self._temp_regs, dest_symbol)

tensorforge/backend/instructions/compute/multilinear.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from tensorforge.common.basic_types import Datatype
1111
from tensorforge.backend.writer import Writer
1212

13+
from tensorforge.common.matrix.tensor import Tensor
14+
1315
from .primitives import nvidia as nvidia
1416
from .primitives import amd as amd
1517

@@ -80,8 +82,11 @@ def _analyze(self):
8082
opdim = [''] * op.bbox.rank()
8183
for j in range(op.bbox.rank()):
8284
# TODO: check adding the data_view box here again
83-
lower = op.bbox.lower()[j] #+ op.symbol.data_view._bbox.lower()[j]
84-
upper = op.bbox.upper()[j] #+ op.symbol.data_view._bbox.lower()[j]
85+
lower = op.bbox.lower()[j] + op.symbol.data_view._bbox.lower()[j]
86+
upper = op.bbox.upper()[j] + op.symbol.data_view._bbox.lower()[j]
87+
if op.symbol.obj and isinstance(op.symbol.obj, Tensor):
88+
lower -= op.symbol.obj.get_bbox().lower()[j]
89+
upper -= op.symbol.obj.get_bbox().lower()[j]
8590
#if self._target[i][j] != 0:
8691
# lower -= op.offset[j]
8792
# upper -= op.offset[j]

tensorforge/backend/instructions/compute/primitives/nvidia.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -486,13 +486,16 @@ def threadrange(start, size):
486486
with writer.AnonymousScope():
487487
writer('__syncwarp();')
488488
with threadrange(ii, atom.m):
489-
for kkk in range(0, atom.k):
490-
writer(f'{shmptr}[{aoffs} + (threadIdx.x - {ii}) % {atom.m} + {kkk * atom.m}] = {Areg}_{kkk};')
489+
# for kkk in range(0, atom.k):
490+
# writer(f'{shmptr}[{aoffs} + (threadIdx.x - {ii}) % {atom.m} + {kkk * atom.m}] = {Areg}_{kkk};')
491+
for kkk in range(0, atom.k, ktile):
492+
writer(f'*(float4*)&{shmptr}[{aoffs} + ((threadIdx.x - {ii}) % {atom.m}) * {ktile} + {kkk * atom.m}] = make_float4({Areg}_{kkk}, {Areg}_{kkk+1}, {Areg}_{kkk+2}, {Areg}_{kkk+3});')
491493
writer('__syncwarp();')
492494

493495
for kk in range(0, kregs):
494496
for iii in range(0, mregs):
495-
writer(f'{atom.d.ctype()} {Areg2}_{iii + kk * mregs} = {shmptr}[{aoffs} + (threadIdx.x / {ktile}) + (threadIdx.x % {ktile} + {kk * ktile}) * {atom.m} + {iii * mtile}];')
497+
#writer(f'{atom.d.ctype()} {Areg2}_{iii + kk * mregs} = {shmptr}[{aoffs} + (threadIdx.x / {ktile}) + (threadIdx.x % {ktile} + {kk * ktile}) * {atom.m} + {iii * mtile}];')
498+
writer(f'{atom.d.ctype()} {Areg2}_{iii + kk * mregs} = {shmptr}[{aoffs} + threadIdx.x + {(iii + kk * mregs) * 32}];')
496499

497500
atom.generate(writer, ctx, [f'{Areg2}_{i}' for i in range (aregs)], [f'{Breg2}_{i}' for i in range (bregs)], [f'{Creg}[{i}][{ii // atom.m}]' for i in range (cregs)])
498501

tensorforge/backend/instructions/memory/store.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from tensorforge.common.matrix.tensor import Tensor
44
from tensorforge.common.matrix.boundingbox import BoundingBox
55
from tensorforge.backend.data_types import RegMemObject
6-
from tensorforge.backend.symbol import Symbol, SymbolType, DataView, LeadIndex, write_loops, LeadLoop, Loop
6+
from tensorforge.backend.symbol import Symbol, SymbolType, DataView, LeadIndex, write_loops, LeadLoop, Loop, Immediate
77
from tensorforge.common.exceptions import InternalError
88
from tensorforge.backend.writer import Writer
99
from . import AbstractShrMemWrite, MemoryInstruction
@@ -175,15 +175,25 @@ def gen_code(self, writer: Writer) -> None:
175175

176176
writer(f'// {self}')
177177
src_bbox = self._src.data_view.get_bbox()
178+
dest_bbox = self._dest.data_view.get_bbox()
178179
with writer.Scope():
180+
manual = [False]
179181
loops = []
180182
loops += [LeadLoop('i0', src_bbox.lower()[0], src_bbox.upper()[0], self._num_threads, 1)]
181183
for i in range(1, src_bbox.rank()):
182-
loops += [Loop(f'i{i}', src_bbox.lower()[i], src_bbox.upper()[i], 1)]
184+
unroll = (src_bbox.lower()[i], src_bbox.upper()[i]) != (dest_bbox.lower()[i], dest_bbox.upper()[i])
185+
lower = min(src_bbox.lower()[i], dest_bbox.lower()[i])
186+
upper = max(src_bbox.upper()[i], dest_bbox.upper()[i])
187+
loops += [Loop(f'i{i}', lower, upper, 1, unroll)]
188+
manual += [unroll]
183189

184190
def inner(indices):
185-
self._src.load(writer, self._context, 'value', indices, False)
186-
self._dest.store(writer, self._context, 'value', indices, allow_nontemporal)
191+
needsLoad = all(not isinstance(index, Immediate) or (src_bbox.lower()[i] <= index._value and src_bbox.upper()[i] > index._value) for i,index in enumerate(indices))
192+
if needsLoad:
193+
self._src.load(writer, self._context, 'value', indices, False)
194+
self._dest.store(writer, self._context, 'value', indices, allow_nontemporal)
195+
else:
196+
self._dest.store(writer, self._context, '0', indices, allow_nontemporal)
187197

188198
write_loops(self._context, writer, loops, inner)
189199

tensorforge/generators/generator.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,9 +360,13 @@ def _emit_ir(self, descr_list):
360360
self._num_threads)
361361
# builder.build_prologue()
362362

363+
def get_symbol_view(op):
364+
symbol = self._scopes.get_symbol(op.tensor)
365+
return SymbolView(symbol, op.bbox, op.offset)
366+
363367
for gemm_descr in descr_list:
364368
if isinstance(gemm_descr, MultilinearDescr):
365-
builder.build(ops=[SymbolView(self._scopes.get_symbol(op.tensor), op.bbox, op.offset) for op in gemm_descr.ops],
369+
builder.build(ops=[get_symbol_view(op) for op in gemm_descr.ops],
366370
dest_obj=gemm_descr.dest,
367371
descr=gemm_descr)
368372
self._section.ir.extend(builder.get_instructions())

0 commit comments

Comments
 (0)