Skip to content

Commit e5df7e6

Browse files
authored
fix branches in amd_asm_matmul (tinygrad#14369)
1 parent 0ced258 commit e5df7e6

File tree

2 files changed

+21
-19
lines changed

2 files changed

+21
-19
lines changed

extra/assembly/amd/dsl.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ def set(self, raw: int, val) -> int:
116116
def __get__(self, obj, objtype=None):
117117
if obj is None: return self
118118
return self.decode((obj._raw >> self.lo) & self.mask)
119+
def __set__(self, obj, val): obj._raw = self.set(obj._raw, val)
119120

120121
class FixedBitField(BitField):
121122
def set(self, raw: int, val=None) -> int:

extra/gemm/amd_asm_matmul.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -168,12 +168,14 @@ def target_ab(i):
168168
# =============================================================================
169169

170170
class Kernel:
171-
def __init__(self, arch='gfx1100'):
172-
self.instructions, self.labels, self.branch_targets, self.arch = [], {}, {}, arch
171+
def __init__(self, arch='gfx1100'): self.instructions, self.labels, self.pos, self.arch = [], {}, 0, arch
172+
def label(self, name): self.labels[name] = self.pos
173173

174-
def emit(self, inst): self.instructions.append(inst); return inst
175-
def label(self, name): self.labels[name] = len(self.instructions)
176-
def branch_to(self, label): self.branch_targets[len(self.instructions) - 1] = label
174+
def emit(self, inst, target=None):
175+
self.instructions.append(inst)
176+
inst._target, inst._pos = target, self.pos
177+
self.pos += inst.size()
178+
return inst
177179

178180
def waitcnt(self, lgkm=None, vm=None):
179181
"""Wait for memory operations. lgkm=N waits until N lgkm ops remain, vm=N waits until N vmem ops remain."""
@@ -182,16 +184,15 @@ def waitcnt(self, lgkm=None, vm=None):
182184
self.emit(s_waitcnt(simm16=waitcnt))
183185

184186
def to_asm(self):
185-
import re
186-
# Instruction stream with labels
187-
label_at = {pos: name for name, pos in self.labels.items()}
188-
body = []
189-
for i, inst in enumerate(self.instructions):
190-
if i in label_at: body.append(f'.{label_at[i]}:')
191-
asm = inst.disasm()
192-
if i in self.branch_targets:
193-
asm = re.sub(r'(s_cbranch_\w+|s_branch)\s+\S+', rf'\1 .{self.branch_targets[i]}', asm)
194-
body.append('\t' + asm)
187+
# Patch branch offsets: simm16 = (target_pos - branch_end_pos) / 4
188+
for inst in self.instructions:
189+
if inst._target is None: continue
190+
offset_dwords = (self.labels[inst._target] - inst._pos - inst.size()) // 4
191+
if not -32768 <= offset_dwords <= 32767: raise ValueError(f"branch to '{inst._target}' offset {offset_dwords} exceeds simm16 range")
192+
inst.simm16 = offset_dwords
193+
194+
# TODO: replace this with direct ELF
195+
body = ['\t' + inst.disasm() for inst in self.instructions]
195196

196197
# limit wave occupancy by using more LDS
197198
lds_size = max(LDS_SIZE, 65536//getenv("LIMIT_OCC", 65536))
@@ -315,7 +316,7 @@ def build_kernel(arch='gfx1100'):
315316
k.emit(s_add_i32(s[S_LOOP_BOUND], s[S_DIM_N], -8))
316317

317318
# S_LOOP_CTR is already 0 from prologue initialization
318-
k.emit(s_branch(simm16=0)); k.branch_to('LOOP_ENTRY')
319+
k.emit(s_branch(), target='LOOP_ENTRY')
319320

320321
# ===========================================================================
321322
# MAIN GEMM LOOP
@@ -326,12 +327,12 @@ def build_kernel(arch='gfx1100'):
326327
k.label('LOOP_INC')
327328
k.emit(s_add_i32(s[S_LOOP_CTR], s[S_LOOP_CTR], 8))
328329
k.emit(s_cmp_ge_i32(s[S_LOOP_CTR], s[S_DIM_N]))
329-
k.emit(s_cbranch_scc1(simm16=0)); k.branch_to('EPILOGUE')
330+
k.emit(s_cbranch_scc1(), target='EPILOGUE')
330331

331332
k.label('LOOP_ENTRY')
332333
k.emit(s_cmp_lt_i32(s[S_LOOP_CTR], s[S_LOOP_BOUND]))
333334
k.emit(s_cselect_b32(s[S_PREFETCH_FLAG], -1, 0)) # s_cselect doesn't modify SCC
334-
k.emit(s_cbranch_scc0(simm16=0)); k.branch_to('SKIP_PREFETCH') # branch if loop_ctr >= loop_bound
335+
k.emit(s_cbranch_scc0(), target='SKIP_PREFETCH') # branch if loop_ctr >= loop_bound
335336

336337
if not NO_GLOBAL:
337338
# Advance prefetch pointers (VGPR)
@@ -402,7 +403,7 @@ def build_kernel(arch='gfx1100'):
402403
offset = i * 64
403404
k.emit(ds_store_b32(addr=v[V_LDS_B_ADDR], data0=v[V_LDS_B_DATA[i]], offset0=offset & 0xFF, offset1=offset >> 8))
404405

405-
k.emit(s_branch(simm16=0)); k.branch_to('LOOP_INC')
406+
k.emit(s_branch(), target='LOOP_INC')
406407

407408
# ===========================================================================
408409
# EPILOGUE: Permute and store results

0 commit comments

Comments
 (0)