Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
e77420a
codegen: BAddrInterleave
jfactory07 Jan 7, 2026
cec7d49
codegen: implement align-k
jfactory07 Jan 9, 2026
5ef3eb7
refine
jfactory07 Jan 9, 2026
448c4d6
refine macro
jfactory07 Jan 12, 2026
c049b0c
codegen: do NOT overwrite the original stride SGPRs in-place.
jfactory07 Jan 12, 2026
9003fda
codeGen : refine default value
jfactory07 Jan 13, 2026
8b4554d
codegen : refine get
jfactory07 Jan 13, 2026
b3a42a7
host restriction: If n divided by MT1 is not a power of two, address …
jfactory07 Jan 13, 2026
63d386a
host restriction: change to :
jfactory07 Jan 13, 2026
74c6e7a
codegen: remove BInterleaveG guard from kernel's runtime
jfactory07 Jan 13, 2026
e682def
host restriction: add AssertKRingShiftAlignedK
jfactory07 Jan 13, 2026
238fdc7
add: AssertKRingShiftTailWrapOnly
jfactory07 Jan 15, 2026
10b03bb
codegen: shift = (-baseOffsetElems) mod cacheLineElements
jfactory07 Jan 16, 2026
2f6cc19
refine macro
jfactory07 Jan 19, 2026
e871991
codegen: refine tail for krs
jfactory07 Jan 20, 2026
c1c0aa7
tailStartChunk = ceil(KRingShift / chunkElems)
jfactory07 Jan 20, 2026
0e59506
fix error
jfactory07 Jan 20, 2026
51f1fe7
clean code
jfactory07 Jan 20, 2026
0ecb0db
clean code
jfactory07 Jan 20, 2026
60e14e4
clean code
jfactory07 Jan 20, 2026
db885c5
refine restriction
jfactory07 Jan 20, 2026
b718bda
refine comments
jfactory07 Jan 21, 2026
67b5d1a
enable
jfactory07 Jan 21, 2026
11dfdd0
add test
jfactory07 Jan 22, 2026
b6a3b8f
add test case
jfactory07 Jan 22, 2026
be6e173
Merge branch 'develop' into users/jzhou/address-interleave
jfactory07 Jan 23, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -347588,6 +347588,7 @@
AssertSummationElementMultiple: 1
AssignedDerivedParameters: true
AssignedProblemIndependentDerivedParameters: true
BAddrInterleave: true
BaseName: Cijk_Alik_Bljk_BBS_BH_BiasSB_HAS_SAV_UserArgs_MT128x192x128_MI1n-ZRL8mxl6uWAXsQRkqoULpPWXdDeqIc8yirhD35fpI=
BufferLoad: true
BufferStore: true
Expand Down Expand Up @@ -347633,6 +347634,7 @@
Kernel: true
KernelLanguage: Assembly
KernelNameMin: Cijk_Alik_Bljk_BBS_BH_Bias_HA_S_SAV_UserArgs_MT128x192x128_MI16x16x1_SN_LDSB1_AFC0_AG0_AFEM1_AFEM1_ASEM1_CLR0_CADS0_DTLA0_DTLB0_DTVA0_DTVB0_EPS0_ELFLR0_EMLLn1_FDSI0_GRPM1_GRVWA8_GRVWB8_GSU0_GSUAMB_GLS0_ISA950_IU1_K1_LDSTI0_LBSPPA2048_LBSPPB256_LBSPPM0_LPA16_LPB16_LPM0_LRVW8_LWPMn1_MIAV0_MIWT8_3_MO40_NTn1_NTA0_NTB0_NTC1_NTD4_NTM0_NEPBS0_NLCA1_NLCB1_ONLL1_PGR2_PLR1_PKA1_SIA3_SS1_SPO0_SRVW0_SSO0_SVW8_SK3_SKFTR0_SKXCCM0_SGRO0_TLDS2_ULSGRO0_USL1_UIOFGRO0_USFGRO0_VSn1_VWA8_VWB1_WSGRA0_WSGRB0_WS64_WG16_16_1
KRingShift: true
LDSTrInst: false
LSCA: 128
LSCB: 128
Expand Down Expand Up @@ -347745,7 +347747,7 @@
SourceSwap: 1
SpaceFillingAlgo: []
StaggerU: 16
StaggerUMapping: 0
StaggerUMapping: 1
StaggerUStride: 256
StorePriorityOpt: 0
StoreRemapVectorWidth: 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -363,12 +363,22 @@
{"DirectToVgprA": [False]},
{"DirectToVgprB": [False]},
{"DirectToVgprSparseMetadata": [False]},
# Restricted address remap features (default off unless explicitly enabled in the solution config):
{"BAddrInterleave": [False]},
{"KRingShift": [False]},
{"DirectToLds": [0]},
{"UseSgprForGRO": [-1]},
{"UseInstOffsetForGRO": [0]},
{"AssertSummationElementMultiple": [1]},
{"AssertFree0ElementMultiple": [1]},
{"AssertFree1ElementMultiple": [1]},
# Address-interleave restriction (default disabled):
# When >0, the solution requires SizeJ % MT1 == 0 and tiles1=(SizeJ/MT1) to have lowbit(tiles1)>1 (i.e. G>1).
{"AssertFree1DivByMT1LowbitGT1": [0]},
# KRingShift wrap restriction (default disabled):
# Encodes a runtime predicate that ensures (k + KRingShift) does not wrap in the main loop
# (wrap is allowed only in the tail loop, where codegen applies the correction).
{"AssertKRingShiftTailWrapOnly": [0]},
{"AssertAIGreaterThanEqual": [-1]},
{"AssertAILessThanEqual": [-1]},
{"StaggerU": [32]}, # recommend [0,32]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,14 @@ def makeValidMatrixInstructions():
"DirectToVgprA": [False, True],
"DirectToVgprB": [False, True],
"DirectToVgprSparseMetadata": [False, True],
# B address interleave (restricted): non-contiguous tile columns for TN/NN-like B (TLUB == False),
# with runtime G chosen as the largest power-of-two factor of (N/MT1), capped by LVCB.
# Requires SizeJ % MT1 == 0 at runtime; shapes that do not satisfy the interleave requirement
# are rejected by the AssertFree1DivByMT1LowbitGT1 predicate.
Copy link
Contributor

@aazz44ss aazz44ss Jan 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does "falls back to original mapping" mean here?
Will it be rejected by the predicate?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, the description here is inaccurate. Now that AssertFree1DivByMT1LowbitGT1 has been added, shapes that don't satisfy the address-interleaving requirement will be rejected by the predicate. I will correct it. Thanks.

"BAddrInterleave": [False, True],
# K ring-shift (restricted): apply a per-WG shift along the summation (K) dimension so that
# the B-side base K address for each workgroup is cacheline-aligned/congruent, while preserving
# correctness via tail-loop ring wrap. Intended for TN/NN-like B (TLUB == False).
"KRingShift": [False, True],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you added default values for these two new parameters?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, their default values are currently set to false.

# Attempt to load directly from global memory into LDS.
# Assembly only
# Requires BufferLoad, assembler support for lds modifier on buffer
Expand Down Expand Up @@ -434,6 +442,16 @@ def makeValidMatrixInstructions():
# - See above AssertFree0ElementMultiple "Load optimizations"
# 1 indicates no assertion (since all sizes are multiples of 1)
"AssertFree1ElementMultiple": [1, 2, 4, 8, 16],
# Address-interleave restriction:
# If >0, require tiles1=(Free1Size / MT1) to have lowbit(tiles1)>1 (i.e. G>1).
# This matches the kernel's initBInterleaveG logic:
# - require Free1Size % MT1 == 0
# - compute lowbit(tiles1)
# - enable only if min(lowbit, LVCB) > 1
"AssertFree1DivByMT1LowbitGT1": -1,
# KRingShift wrap restriction (packed integer; see Solution.py for encoding):
# If >0, require any (k + KRingShift) wrap to occur only in tail loop (no main-loop wrap).
"AssertKRingShiftTailWrapOnly": -1,
# Assertions that require arithmetic intensity to be specified value.
# Arithmetic intensity measures the ratio of computation to memory bandwidth required for a problem.
# These predicates can be used to adjust solution selection compute-bound or memory-bound problems.
Expand Down
10 changes: 10 additions & 0 deletions projects/hipblaslt/tensilelite/Tensile/Contractions.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,16 @@ def FromOriginalKeyPair(cls, pair):
if key == "AssertAILessThanEqual":
return cls("AILessThanEqual", value=value) if value > 0 else None

# Address-interleave restriction:
# Require Free1Size % MT1 == 0 and tiles1 = Free1Size / MT1 to have lowbit(tiles1) > 1 (i.e. G > 1).
if key == "AssertFree1DivByMT1LowbitGT1":
return cls("Free1SizeDivByValueLowbitGT1", index=0, value=value) if value > 0 else None

# KRingShift wrap restriction (packed value; see Solution.py):
# Require that any (k + KRingShift) wrap occurs only in tail loop (no main-loop wrap).
if key == "AssertKRingShiftTailWrapOnly":
return cls("KRingShiftTailWrapOnly", index=-1, value=value) if value > 0 else None

if key.endswith('Multiple'):
if value == 1:
return None
Expand Down
108 changes: 96 additions & 12 deletions projects/hipblaslt/tensilelite/Tensile/KernelWriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
DSLoadU8, DSStore2B32, DSStore2B64, DSStoreB128, DSStoreB16, DSStoreB256, \
DSStoreB32, DSStoreB64, DSStoreB8, DSStoreInstruction, FlatLoadB128, FlatLoadB32, \
FlatLoadB64, FlatStoreB128, FlatStoreB32, FlatStoreB64, Instruction, MacroInstruction, \
MFMAInstruction, SBarrier, SBranch, SCBranchSCC0, SCBranchSCC1, SCBranchVCCNZ, SCmpLeU32, \
MFMAInstruction, SBarrier, SBranch, SCBranchSCC0, SCBranchSCC1, SCBranchVCCNZ, SCmpEQU32, SCmpLeU32, \
SMFMAInstruction, SNop, SSetPrior, SSetRegIMM32B32, SSubU32, SWaitCnt, SWaitAlu, \
SLongBranchPositive, VFmaMixF32, VMadMixF32, VMovB32, VAndB32, VCmpEQU32, VCndMaskB32, VMovB64
from rocisa.register import RegisterPool
Expand Down Expand Up @@ -3387,18 +3387,102 @@ def kernelBody( self, kernel, tensorParametersA, tensorParametersB ):
module.addComment1("Tail global read %s"%tc1)
if tailLoopOpt1st and (globalReadMode1st == 2):
module.add(self.doTailLoopOpt(kernel, tensorParameters1st))
module.addComment1("Update M0 for DTLDS")
moduleTmp = self.directToLdsM0Update(kernel, 1, tensorParameters2nd, True)
module.add(replaceHolder(moduleTmp, 0))
module.addComment1("Tail global read %s"%tc2)
if tailLoopOpt2nd and (globalReadMode2nd == 2):
module.add(self.doTailLoopOpt(kernel, tensorParameters2nd))
else:
# Keep per-tensor tail branching for tc2 when tc1 uses tailLoopOpt.
if kernel["KRingShift"] and kernel["BufferLoad"] and tc2 in ("A", "B"):
labelNoKRS = Label(self.labels.getNameInc(f"KRS_tail_noop_{tc2}"), "")
labelDoneKRS = Label(self.labels.getNameInc(f"KRS_tail_done_{tc2}"), "")
labelNoKRS.comment = f"KRS: tail no-KRS path for {tc2} (sgprKRingShift==0)"
labelDoneKRS.comment = f"KRS: tail KRS branch join for {tc2}"
module.add(SCmpEQU32(src0=sgpr("KRingShift"), src1=0, comment="KRS: sgprKRingShift==0 ?"))
module.add(SCBranchSCC1(labelName=labelNoKRS.getLabelName(), comment="KRS: take no-KRS tail loads"))
module.add(self.globalReadDo(kernel, globalReadMode2nd, tensorParameters2nd, krTailForceDisable=False))
module.add(SBranch(labelName=labelDoneKRS.getLabelName(), comment="KRS: skip no-KRS tail loads"))
module.add(labelNoKRS)
module.add(self.globalReadDo(kernel, globalReadMode2nd, tensorParameters2nd, krTailForceDisable=True))
module.add(labelDoneKRS)
else:
module.add(self.globalReadDo(kernel, globalReadMode2nd, tensorParameters2nd))
else:
module.add(self.globalReadDo(kernel, globalReadMode1st, tensorParameters1st))
module.addComment1("Update M0 for DTLDS")
# skip wait for DTL if global load 1st is DTL
skip2ndWaitForDtl = kernel["DirectToLds%s"%tc1]
moduleTmp = self.directToLdsM0Update(kernel, 1, tensorParameters2nd, skip2ndWaitForDtl)
module.add(replaceHolder(moduleTmp, 0))
module.addComment1("Tail global read %s"%tc2)
if tailLoopOpt2nd and (globalReadMode2nd == 2):
module.add(self.doTailLoopOpt(kernel, tensorParameters2nd))
else:
module.add(self.globalReadDo(kernel, globalReadMode2nd, tensorParameters2nd))
# KRS: If both tail global-read blocks (A/B) are eligible for KRS, do ONE runtime branch and
# share ONE set of labels for A/B. When sgprKRingShift==0, force both tail blocks down the
# original "load-only" path (no KRS_TAIL_OFFSET_* at all).
# skip wait for DTL if global load 1st is DTL
skip2ndWaitForDtl = kernel["DirectToLds%s"%tc1]
krsTailBranchable1 = kernel["KRingShift"] and kernel["BufferLoad"] and tc1 in ("A", "B")
krsTailBranchable2 = kernel["KRingShift"] and kernel["BufferLoad"] and tc2 in ("A", "B") \
and not (tailLoopOpt2nd and (globalReadMode2nd == 2))
if krsTailBranchable1 and krsTailBranchable2:
labelNoKRS = Label(self.labels.getNameInc("KRS_tail_noop_AB"), "")
labelDoneKRS = Label(self.labels.getNameInc("KRS_tail_done_AB"), "")
labelNoKRS.comment = "KRS: tail no-KRS path for A/B (sgprKRingShift==0)"
labelDoneKRS.comment = "KRS: tail KRS branch join for A/B"

module.add(SCmpEQU32(src0=sgpr("KRingShift"), src1=0, comment="KRS: sgprKRingShift==0 ?"))
module.add(SCBranchSCC1(labelName=labelNoKRS.getLabelName(), comment="KRS: take no-KRS tail loads (A+B)"))

# KRS-enabled path: A then B
module.add(self.globalReadDo(kernel, globalReadMode1st, tensorParameters1st, krTailForceDisable=False))
module.addComment1("Update M0 for DTLDS")
moduleTmp = self.directToLdsM0Update(kernel, 1, tensorParameters2nd, skip2ndWaitForDtl)
module.add(replaceHolder(moduleTmp, 0))
module.addComment1("Tail global read %s"%tc2)
module.add(self.globalReadDo(kernel, globalReadMode2nd, tensorParameters2nd, krTailForceDisable=False))
module.add(SBranch(labelName=labelDoneKRS.getLabelName(), comment="KRS: skip no-KRS tail loads (A+B)"))

# no-KRS path: A then B (load-only)
module.add(labelNoKRS)
module.add(self.globalReadDo(kernel, globalReadMode1st, tensorParameters1st, krTailForceDisable=True))
module.addComment1("Update M0 for DTLDS")
moduleTmp = self.directToLdsM0Update(kernel, 1, tensorParameters2nd, skip2ndWaitForDtl)
module.add(replaceHolder(moduleTmp, 0))
module.addComment1("Tail global read %s"%tc2)
module.add(self.globalReadDo(kernel, globalReadMode2nd, tensorParameters2nd, krTailForceDisable=True))
module.add(labelDoneKRS)
else:
# Fallback: keep per-tensor tail branching.
if krsTailBranchable1:
labelNoKRS = Label(self.labels.getNameInc(f"KRS_tail_noop_{tc1}"), "")
labelDoneKRS = Label(self.labels.getNameInc(f"KRS_tail_done_{tc1}"), "")
labelNoKRS.comment = f"KRS: tail no-KRS path for {tc1} (sgprKRingShift==0)"
labelDoneKRS.comment = f"KRS: tail KRS branch join for {tc1}"
module.add(SCmpEQU32(src0=sgpr("KRingShift"), src1=0, comment="KRS: sgprKRingShift==0 ?"))
module.add(SCBranchSCC1(labelName=labelNoKRS.getLabelName(), comment="KRS: take no-KRS tail loads"))
module.add(self.globalReadDo(kernel, globalReadMode1st, tensorParameters1st, krTailForceDisable=False))
module.add(SBranch(labelName=labelDoneKRS.getLabelName(), comment="KRS: skip no-KRS tail loads"))
module.add(labelNoKRS)
module.add(self.globalReadDo(kernel, globalReadMode1st, tensorParameters1st, krTailForceDisable=True))
module.add(labelDoneKRS)
else:
module.add(self.globalReadDo(kernel, globalReadMode1st, tensorParameters1st))

module.addComment1("Update M0 for DTLDS")
moduleTmp = self.directToLdsM0Update(kernel, 1, tensorParameters2nd, skip2ndWaitForDtl)
module.add(replaceHolder(moduleTmp, 0))
module.addComment1("Tail global read %s"%tc2)
if tailLoopOpt2nd and (globalReadMode2nd == 2):
module.add(self.doTailLoopOpt(kernel, tensorParameters2nd))
else:
if kernel["KRingShift"] and kernel["BufferLoad"] and tc2 in ("A", "B"):
labelNoKRS = Label(self.labels.getNameInc(f"KRS_tail_noop_{tc2}"), "")
labelDoneKRS = Label(self.labels.getNameInc(f"KRS_tail_done_{tc2}"), "")
labelNoKRS.comment = f"KRS: tail no-KRS path for {tc2} (sgprKRingShift==0)"
labelDoneKRS.comment = f"KRS: tail KRS branch join for {tc2}"
module.add(SCmpEQU32(src0=sgpr("KRingShift"), src1=0, comment="KRS: sgprKRingShift==0 ?"))
module.add(SCBranchSCC1(labelName=labelNoKRS.getLabelName(), comment="KRS: take no-KRS tail loads"))
module.add(self.globalReadDo(kernel, globalReadMode2nd, tensorParameters2nd, krTailForceDisable=False))
module.add(SBranch(labelName=labelDoneKRS.getLabelName(), comment="KRS: skip no-KRS tail loads"))
module.add(labelNoKRS)
module.add(self.globalReadDo(kernel, globalReadMode2nd, tensorParameters2nd, krTailForceDisable=True))
module.add(labelDoneKRS)
else:
module.add(self.globalReadDo(kernel, globalReadMode2nd, tensorParameters2nd))

doA = False
doB = False
Expand Down
Loading
Loading