@@ -168,12 +168,14 @@ def target_ab(i):
168168# =============================================================================
169169
170170class Kernel :
171- def __init__ (self , arch = 'gfx1100' ):
172- self . instructions , self . labels , self .branch_targets , self . arch = [], {}, {}, arch
171+ def __init__ (self , arch = 'gfx1100' ): self . instructions , self . labels , self . pos , self . arch = [], {}, 0 , arch
def label(self, name):
  """Bind `name` to the current position so branches can target it.

  Re-declaring a label at the position it already has is a no-op. Declaring
  it again at a different position is rejected, because branch patching looks
  labels up by name and would otherwise silently use the last definition.

  Raises:
    ValueError: if `name` is already bound to a different position.
  """
  existing = self.labels.get(name)
  if existing is not None and existing != self.pos:
    raise ValueError(f"label '{name}' already defined at position {existing}, cannot redefine at {self.pos}")
  self.labels[name] = self.pos
173173
174- def emit (self , inst ): self .instructions .append (inst ); return inst
175- def label (self , name ): self .labels [name ] = len (self .instructions )
176- def branch_to (self , label ): self .branch_targets [len (self .instructions ) - 1 ] = label
def emit(self, inst, target=None):
  """Append `inst` to the instruction stream and return it.

  The instruction is tagged with `_target` (the label name it should branch
  to, or None) and `_pos` (the stream position at which it starts). The
  running position then advances by `inst.size()`.
  """
  start = self.pos
  self.instructions.append(inst)
  inst._target = target
  inst._pos = start
  # Advance past this instruction so the next emit()/label() sees its end.
  self.pos = start + inst.size()
  return inst
177179
178180 def waitcnt (self , lgkm = None , vm = None ):
179181 """Wait for memory operations. lgkm=N waits until N lgkm ops remain, vm=N waits until N vmem ops remain."""
@@ -182,16 +184,15 @@ def waitcnt(self, lgkm=None, vm=None):
182184 self .emit (s_waitcnt (simm16 = waitcnt ))
183185
184186 def to_asm (self ):
185- import re
186- # Instruction stream with labels
187- label_at = {pos : name for name , pos in self .labels .items ()}
188- body = []
189- for i , inst in enumerate (self .instructions ):
190- if i in label_at : body .append (f'.{ label_at [i ]} :' )
191- asm = inst .disasm ()
192- if i in self .branch_targets :
193- asm = re .sub (r'(s_cbranch_\w+|s_branch)\s+\S+' , rf'\1 .{ self .branch_targets [i ]} ' , asm )
194- body .append ('\t ' + asm )
187+ # Patch branch offsets: simm16 = (target_pos - branch_end_pos) / 4
188+ for inst in self .instructions :
189+ if inst ._target is None : continue
190+ offset_dwords = (self .labels [inst ._target ] - inst ._pos - inst .size ()) // 4
191+ if not - 32768 <= offset_dwords <= 32767 : raise ValueError (f"branch to '{ inst ._target } ' offset { offset_dwords } exceeds simm16 range" )
192+ inst .simm16 = offset_dwords
193+
194+ # TODO: replace this with direct ELF
195+ body = ['\t ' + inst .disasm () for inst in self .instructions ]
195196
196197 # limit wave occupancy by using more LDS
197198 lds_size = max (LDS_SIZE , 65536 // getenv ("LIMIT_OCC" , 65536 ))
@@ -315,7 +316,7 @@ def build_kernel(arch='gfx1100'):
315316 k .emit (s_add_i32 (s [S_LOOP_BOUND ], s [S_DIM_N ], - 8 ))
316317
317318 # S_LOOP_CTR is already 0 from prologue initialization
318- k .emit (s_branch (simm16 = 0 )); k . branch_to ( 'LOOP_ENTRY' )
319+ k .emit (s_branch (), target = 'LOOP_ENTRY' )
319320
320321 # ===========================================================================
321322 # MAIN GEMM LOOP
@@ -326,12 +327,12 @@ def build_kernel(arch='gfx1100'):
326327 k .label ('LOOP_INC' )
327328 k .emit (s_add_i32 (s [S_LOOP_CTR ], s [S_LOOP_CTR ], 8 ))
328329 k .emit (s_cmp_ge_i32 (s [S_LOOP_CTR ], s [S_DIM_N ]))
329- k .emit (s_cbranch_scc1 (simm16 = 0 )); k . branch_to ( 'EPILOGUE' )
330+ k .emit (s_cbranch_scc1 (), target = 'EPILOGUE' )
330331
331332 k .label ('LOOP_ENTRY' )
332333 k .emit (s_cmp_lt_i32 (s [S_LOOP_CTR ], s [S_LOOP_BOUND ]))
333334 k .emit (s_cselect_b32 (s [S_PREFETCH_FLAG ], - 1 , 0 )) # s_cselect doesn't modify SCC
334- k .emit (s_cbranch_scc0 (simm16 = 0 )); k . branch_to ( 'SKIP_PREFETCH' ) # branch if loop_ctr >= loop_bound
335+ k .emit (s_cbranch_scc0 (), target = 'SKIP_PREFETCH' ) # branch if loop_ctr >= loop_bound
335336
336337 if not NO_GLOBAL :
337338 # Advance prefetch pointers (VGPR)
@@ -402,7 +403,7 @@ def build_kernel(arch='gfx1100'):
402403 offset = i * 64
403404 k .emit (ds_store_b32 (addr = v [V_LDS_B_ADDR ], data0 = v [V_LDS_B_DATA [i ]], offset0 = offset & 0xFF , offset1 = offset >> 8 ))
404405
405- k .emit (s_branch (simm16 = 0 )); k . branch_to ( 'LOOP_INC' )
406+ k .emit (s_branch (), target = 'LOOP_INC' )
406407
407408 # ===========================================================================
408409 # EPILOGUE: Permute and store results
0 commit comments