Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
e77420a
codegen: BAddrInterleave
jfactory07 Jan 7, 2026
cec7d49
codegen: implement align-k
jfactory07 Jan 9, 2026
5ef3eb7
refine
jfactory07 Jan 9, 2026
448c4d6
refine macro
jfactory07 Jan 12, 2026
c049b0c
codegen: do NOT overwrite the original stride SGPRs in-place.
jfactory07 Jan 12, 2026
9003fda
codeGen : refine default value
jfactory07 Jan 13, 2026
8b4554d
codegen : refine get
jfactory07 Jan 13, 2026
b3a42a7
host restriction: If n divided by MT1 is not a power of two, address …
jfactory07 Jan 13, 2026
63d386a
host restriction: change to :
jfactory07 Jan 13, 2026
74c6e7a
codegen: remove BInterleaveG guard from kernel's runtime
jfactory07 Jan 13, 2026
e682def
host restriction: add AssertKRingShiftAlignedK
jfactory07 Jan 13, 2026
238fdc7
add: AssertKRingShiftTailWrapOnly
jfactory07 Jan 15, 2026
10b03bb
codegen: shift = (-baseOffsetElems) mod cacheLineElements
jfactory07 Jan 16, 2026
2f6cc19
refine macro
jfactory07 Jan 19, 2026
e871991
codegen: refine tail for krs
jfactory07 Jan 20, 2026
c1c0aa7
tailStartChunk = ceil(KRingShift / chunkElems)
jfactory07 Jan 20, 2026
0e59506
fix error
jfactory07 Jan 20, 2026
51f1fe7
clean code
jfactory07 Jan 20, 2026
0ecb0db
clean code
jfactory07 Jan 20, 2026
60e14e4
clean code
jfactory07 Jan 20, 2026
db885c5
refine restriction
jfactory07 Jan 20, 2026
b718bda
refine comments
jfactory07 Jan 21, 2026
67b5d1a
enable
jfactory07 Jan 21, 2026
11dfdd0
add test
jfactory07 Jan 22, 2026
b6a3b8f
add test case
jfactory07 Jan 22, 2026
be6e173
Merge branch 'develop' into users/jzhou/address-interleave
jfactory07 Jan 23, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,15 @@ def makeValidMatrixInstructions():
"DirectToVgprA": [False, True],
"DirectToVgprB": [False, True],
"DirectToVgprSparseMetadata": [False, True],
# B address interleave (restricted): non-contiguous tile columns for TN/NN-like B (TLUB == False),
# with runtime G chosen as the largest power-of-two factor of (N/MT1), capped by LVCB.
# Requires SizeJ % MT1 == 0 at runtime; otherwise falls back to original mapping.
"BAddrInterleave": [False, True],

# K ring-shift (restricted): apply a per-WG shift along the summation (K) dimension so that
# the B-side base K address for each workgroup is cacheline-aligned/congruent, while preserving
# correctness via full-loop ring wrap. Intended for TN/NN-like B (TLUB == False).
"KRingShift": [False, True],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you add the default value of these two new parameters?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, their default values are currently set to false.

# Attempt to load directly from global memory into LDS.
# Assembly only
# Requires BufferLoad, assembler support for lds modifier on buffer
Expand Down
111 changes: 101 additions & 10 deletions projects/hipblaslt/tensilelite/Tensile/KernelWriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
DSLoadU8, DSStore2B32, DSStore2B64, DSStoreB128, DSStoreB16, DSStoreB256, \
DSStoreB32, DSStoreB64, DSStoreB8, DSStoreInstruction, FlatLoadB128, FlatLoadB32, \
FlatLoadB64, FlatStoreB128, FlatStoreB32, FlatStoreB64, Instruction, MacroInstruction, \
MFMAInstruction, SBarrier, SBranch, SCBranchSCC0, SCBranchSCC1, SCBranchVCCNZ, SCmpLeU32, \
MFMAInstruction, SBarrier, SBranch, SCBranchSCC0, SCBranchSCC1, SCBranchVCCNZ, SCmpEQU32, SCmpLeU32, \
SMFMAInstruction, SNop, SSetPrior, SSetRegIMM32B32, SSubU32, SWaitCnt, SWaitAlu, \
SLongBranchPositive, VFmaMixF32, VMadMixF32, VMovB32
from rocisa.register import RegisterPool
Expand Down Expand Up @@ -3223,6 +3223,8 @@ def kernelBody( self, kernel, tensorParametersA, tensorParametersB ):
module.addComment1("remove stagger offsets for tail loop")
module.add(self.removeStagger(kernel, tensorParametersA))
module.add(self.removeStagger(kernel, tensorParametersB))
# KRS: Tail offset patching is now emitted just-in-time immediately before each tail global read,
# to allow instruction interleaving (apply -> load) and avoid a large apply-only block here.

# if swapGlobalRoad is true, swap the order of global read (B->A)
tensorParameters1st = tensorParametersA
Expand Down Expand Up @@ -3271,16 +3273,100 @@ def kernelBody( self, kernel, tensorParametersA, tensorParametersB ):
module.addComment1("Tail global read %s"%tc1)
if tailLoopOpt1st and (globalReadMode1st == 2):
module.add(self.doTailLoopOpt(kernel, tensorParameters1st))
module.addComment1("Update M0 for DTLDS")
moduleTmp = self.directToLdsM0Update(kernel, 1, tensorParameters2nd, True)
module.add(replaceHolder(moduleTmp, 0))
module.addComment1("Tail global read %s"%tc2)
if tailLoopOpt2nd and (globalReadMode2nd == 2):
module.add(self.doTailLoopOpt(kernel, tensorParameters2nd))
else:
# Keep per-tensor tail branching for tc2 when tc1 uses tailLoopOpt.
if kernel.get("KRingShift", False) and kernel["BufferLoad"] and tc2 in ("A", "B"):
labelNoKRS = Label(self.labels.getNameInc(f"KRS_tail_noop_{tc2}"), "")
labelDoneKRS = Label(self.labels.getNameInc(f"KRS_tail_done_{tc2}"), "")
labelNoKRS.comment = f"KRS: tail no-KRS path for {tc2} (sgprKRingShift==0)"
labelDoneKRS.comment = f"KRS: tail KRS branch join for {tc2}"
module.add(SCmpEQU32(src0=sgpr("KRingShift"), src1=0, comment="KRS: sgprKRingShift==0 ?"))
module.add(SCBranchSCC1(labelName=labelNoKRS.getLabelName(), comment="KRS: take no-KRS tail loads"))
module.add(self.globalReadDo(kernel, globalReadMode2nd, tensorParameters2nd, krTailForceDisable=False))
module.add(SBranch(labelName=labelDoneKRS.getLabelName(), comment="KRS: skip no-KRS tail loads"))
module.add(labelNoKRS)
module.add(self.globalReadDo(kernel, globalReadMode2nd, tensorParameters2nd, krTailForceDisable=True))
module.add(labelDoneKRS)
else:
module.add(self.globalReadDo(kernel, globalReadMode2nd, tensorParameters2nd))
else:
module.add(self.globalReadDo(kernel, globalReadMode1st, tensorParameters1st))
module.addComment1("Update M0 for DTLDS")
moduleTmp = self.directToLdsM0Update(kernel, 1, tensorParameters2nd, True)
module.add(replaceHolder(moduleTmp, 0))
module.addComment1("Tail global read %s"%tc2)
if tailLoopOpt2nd and (globalReadMode2nd == 2):
module.add(self.doTailLoopOpt(kernel, tensorParameters2nd))
else:
module.add(self.globalReadDo(kernel, globalReadMode2nd, tensorParameters2nd))
# KRS: If both tail global-read blocks (A/B) are eligible for KRS, do ONE runtime branch and
# share ONE set of labels for A/B. When sgprKRingShift==0, force both tail blocks down the
# original "load-only" path (no KRS_TAIL_OFFSET_* at all).
krsTailBranchable1 = kernel.get("KRingShift", False) and kernel["BufferLoad"] and tc1 in ("A", "B")
krsTailBranchable2 = kernel.get("KRingShift", False) and kernel["BufferLoad"] and tc2 in ("A", "B") \
and not (tailLoopOpt2nd and (globalReadMode2nd == 2))
if krsTailBranchable1 and krsTailBranchable2:
labelNoKRS = Label(self.labels.getNameInc("KRS_tail_noop_AB"), "")
labelDoneKRS = Label(self.labels.getNameInc("KRS_tail_done_AB"), "")
labelNoKRS.comment = "KRS: tail no-KRS path for A/B (sgprKRingShift==0)"
labelDoneKRS.comment = "KRS: tail KRS branch join for A/B"

module.add(SCmpEQU32(src0=sgpr("KRingShift"), src1=0, comment="KRS: sgprKRingShift==0 ?"))
module.add(SCBranchSCC1(labelName=labelNoKRS.getLabelName(), comment="KRS: take no-KRS tail loads (A+B)"))

# KRS-enabled path: A then B
module.add(self.globalReadDo(kernel, globalReadMode1st, tensorParameters1st, krTailForceDisable=False))
module.addComment1("Update M0 for DTLDS")
moduleTmp = self.directToLdsM0Update(kernel, 1, tensorParameters2nd, True)
module.add(replaceHolder(moduleTmp, 0))
module.addComment1("Tail global read %s"%tc2)
module.add(self.globalReadDo(kernel, globalReadMode2nd, tensorParameters2nd, krTailForceDisable=False))
module.add(SBranch(labelName=labelDoneKRS.getLabelName(), comment="KRS: skip no-KRS tail loads (A+B)"))

# no-KRS path: A then B (load-only)
module.add(labelNoKRS)
module.add(self.globalReadDo(kernel, globalReadMode1st, tensorParameters1st, krTailForceDisable=True))
module.addComment1("Update M0 for DTLDS")
moduleTmp = self.directToLdsM0Update(kernel, 1, tensorParameters2nd, True)
module.add(replaceHolder(moduleTmp, 0))
module.addComment1("Tail global read %s"%tc2)
module.add(self.globalReadDo(kernel, globalReadMode2nd, tensorParameters2nd, krTailForceDisable=True))
module.add(labelDoneKRS)
else:
# Fallback: keep per-tensor tail branching.
if krsTailBranchable1:
labelNoKRS = Label(self.labels.getNameInc(f"KRS_tail_noop_{tc1}"), "")
labelDoneKRS = Label(self.labels.getNameInc(f"KRS_tail_done_{tc1}"), "")
labelNoKRS.comment = f"KRS: tail no-KRS path for {tc1} (sgprKRingShift==0)"
labelDoneKRS.comment = f"KRS: tail KRS branch join for {tc1}"
module.add(SCmpEQU32(src0=sgpr("KRingShift"), src1=0, comment="KRS: sgprKRingShift==0 ?"))
module.add(SCBranchSCC1(labelName=labelNoKRS.getLabelName(), comment="KRS: take no-KRS tail loads"))
module.add(self.globalReadDo(kernel, globalReadMode1st, tensorParameters1st, krTailForceDisable=False))
module.add(SBranch(labelName=labelDoneKRS.getLabelName(), comment="KRS: skip no-KRS tail loads"))
module.add(labelNoKRS)
module.add(self.globalReadDo(kernel, globalReadMode1st, tensorParameters1st, krTailForceDisable=True))
module.add(labelDoneKRS)
else:
module.add(self.globalReadDo(kernel, globalReadMode1st, tensorParameters1st))

module.addComment1("Update M0 for DTLDS")
moduleTmp = self.directToLdsM0Update(kernel, 1, tensorParameters2nd, True)
module.add(replaceHolder(moduleTmp, 0))
module.addComment1("Tail global read %s"%tc2)
if tailLoopOpt2nd and (globalReadMode2nd == 2):
module.add(self.doTailLoopOpt(kernel, tensorParameters2nd))
else:
if kernel.get("KRingShift", False) and kernel["BufferLoad"] and tc2 in ("A", "B"):
labelNoKRS = Label(self.labels.getNameInc(f"KRS_tail_noop_{tc2}"), "")
labelDoneKRS = Label(self.labels.getNameInc(f"KRS_tail_done_{tc2}"), "")
labelNoKRS.comment = f"KRS: tail no-KRS path for {tc2} (sgprKRingShift==0)"
labelDoneKRS.comment = f"KRS: tail KRS branch join for {tc2}"
module.add(SCmpEQU32(src0=sgpr("KRingShift"), src1=0, comment="KRS: sgprKRingShift==0 ?"))
module.add(SCBranchSCC1(labelName=labelNoKRS.getLabelName(), comment="KRS: take no-KRS tail loads"))
module.add(self.globalReadDo(kernel, globalReadMode2nd, tensorParameters2nd, krTailForceDisable=False))
module.add(SBranch(labelName=labelDoneKRS.getLabelName(), comment="KRS: skip no-KRS tail loads"))
module.add(labelNoKRS)
module.add(self.globalReadDo(kernel, globalReadMode2nd, tensorParameters2nd, krTailForceDisable=True))
module.add(labelDoneKRS)
else:
module.add(self.globalReadDo(kernel, globalReadMode2nd, tensorParameters2nd))

doA = False
doB = False
Expand Down Expand Up @@ -3446,6 +3532,11 @@ def kernelBody( self, kernel, tensorParametersA, tensorParametersB ):
if item[0] != -1:
module.add(self.tailLoopFreeVgpr(item[0], item[1]))

# KRS: tail is finished; sgprKRingShift must not be remapped (e.g. to 0) and can be released now.
# Emit an explicit UNDEF here so it lands right after the tail VALU vgpr UNDEF block.
if kernel.get("KRingShift", False) and kernel["BufferLoad"]:
module.add(TextBlock(".set sgprKRingShift, UNDEF\n"))

# Check in VGPR for DTV
for item in vDtvResources:
if item[0] != -1:
Expand Down
Loading