diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/Tensile/Logic/asm_full/gfx950/Equality/gfx950_Cijk_Alik_Bljk_BBS_BH_BiasSB_HAS_SAV_UserArgs.yaml b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/Tensile/Logic/asm_full/gfx950/Equality/gfx950_Cijk_Alik_Bljk_BBS_BH_BiasSB_HAS_SAV_UserArgs.yaml index 61a6561b26c2..8855d2acf9a1 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/Tensile/Logic/asm_full/gfx950/Equality/gfx950_Cijk_Alik_Bljk_BBS_BH_BiasSB_HAS_SAV_UserArgs.yaml +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/Tensile/Logic/asm_full/gfx950/Equality/gfx950_Cijk_Alik_Bljk_BBS_BH_BiasSB_HAS_SAV_UserArgs.yaml @@ -347588,6 +347588,7 @@ AssertSummationElementMultiple: 1 AssignedDerivedParameters: true AssignedProblemIndependentDerivedParameters: true + BAddrInterleave: true BaseName: Cijk_Alik_Bljk_BBS_BH_BiasSB_HAS_SAV_UserArgs_MT128x192x128_MI1n-ZRL8mxl6uWAXsQRkqoULpPWXdDeqIc8yirhD35fpI= BufferLoad: true BufferStore: true @@ -347633,6 +347634,7 @@ Kernel: true KernelLanguage: Assembly KernelNameMin: Cijk_Alik_Bljk_BBS_BH_Bias_HA_S_SAV_UserArgs_MT128x192x128_MI16x16x1_SN_LDSB1_AFC0_AG0_AFEM1_AFEM1_ASEM1_CLR0_CADS0_DTLA0_DTLB0_DTVA0_DTVB0_EPS0_ELFLR0_EMLLn1_FDSI0_GRPM1_GRVWA8_GRVWB8_GSU0_GSUAMB_GLS0_ISA950_IU1_K1_LDSTI0_LBSPPA2048_LBSPPB256_LBSPPM0_LPA16_LPB16_LPM0_LRVW8_LWPMn1_MIAV0_MIWT8_3_MO40_NTn1_NTA0_NTB0_NTC1_NTD4_NTM0_NEPBS0_NLCA1_NLCB1_ONLL1_PGR2_PLR1_PKA1_SIA3_SS1_SPO0_SRVW0_SSO0_SVW8_SK3_SKFTR0_SKXCCM0_SGRO0_TLDS2_ULSGRO0_USL1_UIOFGRO0_USFGRO0_VSn1_VWA8_VWB1_WSGRA0_WSGRB0_WS64_WG16_16_1 + KRingShift: true LDSTrInst: false LSCA: 128 LSCB: 128 @@ -347745,7 +347747,7 @@ SourceSwap: 1 SpaceFillingAlgo: [] StaggerU: 16 - StaggerUMapping: 0 + StaggerUMapping: 1 StaggerUStride: 256 StorePriorityOpt: 0 StoreRemapVectorWidth: 0 diff --git a/projects/hipblaslt/tensilelite/Tensile/Common/GlobalParameters.py b/projects/hipblaslt/tensilelite/Tensile/Common/GlobalParameters.py index 989bb17279db..a305ec342907 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Common/GlobalParameters.py +++ b/projects/hipblaslt/tensilelite/Tensile/Common/GlobalParameters.py @@ -363,12 +363,22 @@ {"DirectToVgprA": [False]}, {"DirectToVgprB": [False]}, {"DirectToVgprSparseMetadata": [False]}, + # Restricted address remap features (default off unless explicitly enabled in the solution config): + {"BAddrInterleave": [False]}, + {"KRingShift": [False]}, {"DirectToLds": [0]}, {"UseSgprForGRO": [-1]}, {"UseInstOffsetForGRO": [0]}, {"AssertSummationElementMultiple": [1]}, {"AssertFree0ElementMultiple": [1]}, {"AssertFree1ElementMultiple": [1]}, + # Address-interleave restriction (default disabled): + # When >0, the solution requires tiles1=(SizeJ/MT1) to have lowbit(tiles1)>1 (i.e. G>1), + {"AssertFree1DivByMT1LowbitGT1": [0]}, + # KRingShift wrap restriction (default disabled): + # Encodes a runtime predicate that ensures (k + KRingShift) does not wrap in main loop + # (wrap is allowed only in tail loop where codegen applies the correction). 
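A minimal sketch of the condition this default-off predicate encodes (the SizeK, DepthU and shift values are illustrative; the full per-workgroup form lives in ContractionProblemPredicates.hpp):

# Illustrative only: KRS wrap is correctable only in the tail loop, so the per-WG shift
# must fit inside the tail remainder of the summation size.
def krs_wrap_only_in_tail(size_k, depth_u, shift_elems):
    tail_size = size_k % depth_u          # iterations handled by the tail loop
    return shift_elems <= tail_size       # otherwise the wrap would land in the main loop

assert krs_wrap_only_in_tail(1000, 128, 40) is True    # tail of 104 elements absorbs a shift of 40
assert krs_wrap_only_in_tail(1024, 128, 40) is False   # no tail, any nonzero shift would wrap early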
+ {"AssertKRingShiftTailWrapOnly": [0]}, {"AssertAIGreaterThanEqual": [-1]}, {"AssertAILessThanEqual": [-1]}, {"StaggerU": [32]}, # recommend [0,32] diff --git a/projects/hipblaslt/tensilelite/Tensile/Common/ValidParameters.py b/projects/hipblaslt/tensilelite/Tensile/Common/ValidParameters.py index 37bf791cfaf7..a7f91fd21896 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Common/ValidParameters.py +++ b/projects/hipblaslt/tensilelite/Tensile/Common/ValidParameters.py @@ -325,6 +325,14 @@ def makeValidMatrixInstructions(): "DirectToVgprA": [False, True], "DirectToVgprB": [False, True], "DirectToVgprSparseMetadata": [False, True], + # B address interleave (restricted): non-contiguous tile columns for TN/NN-like B (TLUB == False), + # with runtime G chosen as the largest power-of-two factor of (N/MT1), capped by LVCB. + # Requires SizeJ % MT1 == 0 at runtime; otherwise falls back to original mapping. + "BAddrInterleave": [False, True], + # K ring-shift (restricted): apply a per-WG shift along the summation (K) dimension so that + # the B-side base K address for each workgroup is cacheline-aligned/congruent, while preserving + # correctness via tail-loop ring wrap. Intended for TN/NN-like B (TLUB == False). + "KRingShift": [False, True], # Attempt to load directly from global memory into LDS. # Assembly only # Requires BufferLoad, assembler support for lds modifier on buffer @@ -434,6 +442,16 @@ def makeValidMatrixInstructions(): # - See above AssertFree0ElementMultiple "Load optimizations" # 1 indicates no assertion (since all sizes are multiples of 1) "AssertFree1ElementMultiple": [1, 2, 4, 8, 16], + # Address-interleave restriction: + # If >0, require tiles1=(Free1Size / MT1) to have lowbit(tiles1)>1 (i.e. G>1). + # This matches the kernel's initBInterleaveG logic: + # - require Free1Size % MT1 == 0 + # - compute lowbit(tiles1) + # - enable only if min(lowbit, LVCB) > 1 + "AssertFree1DivByMT1LowbitGT1": -1, + # KRingShift wrap restriction (packed integer; see Solution.py for encoding): + # If >0, require any (k + KRingShift) wrap to occur only in tail loop (no main-loop wrap). + "AssertKRingShiftTailWrapOnly": -1, # Assertions that require arithmetic intensity to be specified value. # Arithmetic intensity measures the ratio of computation to memory bandwidth required for a problem. # These predicates can be used to adjust solution selection compute-bound or memory-bound problems. diff --git a/projects/hipblaslt/tensilelite/Tensile/Contractions.py b/projects/hipblaslt/tensilelite/Tensile/Contractions.py index 97ac83f69d07..be56cdb086ec 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Contractions.py +++ b/projects/hipblaslt/tensilelite/Tensile/Contractions.py @@ -442,6 +442,16 @@ def FromOriginalKeyPair(cls, pair): if key == "AssertAILessThanEqual": return cls("AILessThanEqual", value=value) if value > 0 else None + # Address-interleave restriction: + # Require tiles1 = Free1Size / MT1 to be a power-of-two (and divisible). + if key == "AssertFree1DivByMT1LowbitGT1": + return cls("Free1SizeDivByValueLowbitGT1", index=0, value=value) if value > 0 else None + + # KRingShift wrap restriction (packed value; see Solution.py): + # Require that any (k + KRingShift) wrap occurs only in tail loop (no main-loop wrap). 
+ if key == "AssertKRingShiftTailWrapOnly": + return cls("KRingShiftTailWrapOnly", index=-1, value=value) if value > 0 else None + if key.endswith('Multiple'): if value == 1: return None diff --git a/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py b/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py index 1b47b6fb0ab2..75ab57d5c668 100644 --- a/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py +++ b/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py @@ -34,7 +34,7 @@ DSLoadU8, DSStore2B32, DSStore2B64, DSStoreB128, DSStoreB16, DSStoreB256, \ DSStoreB32, DSStoreB64, DSStoreB8, DSStoreInstruction, FlatLoadB128, FlatLoadB32, \ FlatLoadB64, FlatStoreB128, FlatStoreB32, FlatStoreB64, Instruction, MacroInstruction, \ - MFMAInstruction, SBarrier, SBranch, SCBranchSCC0, SCBranchSCC1, SCBranchVCCNZ, SCmpLeU32, \ + MFMAInstruction, SBarrier, SBranch, SCBranchSCC0, SCBranchSCC1, SCBranchVCCNZ, SCmpEQU32, SCmpLeU32, \ SMFMAInstruction, SNop, SSetPrior, SSetRegIMM32B32, SSubU32, SWaitCnt, SWaitAlu, \ SLongBranchPositive, VFmaMixF32, VMadMixF32, VMovB32, VAndB32, VCmpEQU32, VCndMaskB32, VMovB64 from rocisa.register import RegisterPool @@ -3387,18 +3387,102 @@ def kernelBody( self, kernel, tensorParametersA, tensorParametersB ): module.addComment1("Tail global read %s"%tc1) if tailLoopOpt1st and (globalReadMode1st == 2): module.add(self.doTailLoopOpt(kernel, tensorParameters1st)) + module.addComment1("Update M0 for DTLDS") + moduleTmp = self.directToLdsM0Update(kernel, 1, tensorParameters2nd, True) + module.add(replaceHolder(moduleTmp, 0)) + module.addComment1("Tail global read %s"%tc2) + if tailLoopOpt2nd and (globalReadMode2nd == 2): + module.add(self.doTailLoopOpt(kernel, tensorParameters2nd)) + else: + # Keep per-tensor tail branching for tc2 when tc1 uses tailLoopOpt. 
+ if kernel["KRingShift"] and kernel["BufferLoad"] and tc2 in ("A", "B"): + labelNoKRS = Label(self.labels.getNameInc(f"KRS_tail_noop_{tc2}"), "") + labelDoneKRS = Label(self.labels.getNameInc(f"KRS_tail_done_{tc2}"), "") + labelNoKRS.comment = f"KRS: tail no-KRS path for {tc2} (sgprKRingShift==0)" + labelDoneKRS.comment = f"KRS: tail KRS branch join for {tc2}" + module.add(SCmpEQU32(src0=sgpr("KRingShift"), src1=0, comment="KRS: sgprKRingShift==0 ?")) + module.add(SCBranchSCC1(labelName=labelNoKRS.getLabelName(), comment="KRS: take no-KRS tail loads")) + module.add(self.globalReadDo(kernel, globalReadMode2nd, tensorParameters2nd, krTailForceDisable=False)) + module.add(SBranch(labelName=labelDoneKRS.getLabelName(), comment="KRS: skip no-KRS tail loads")) + module.add(labelNoKRS) + module.add(self.globalReadDo(kernel, globalReadMode2nd, tensorParameters2nd, krTailForceDisable=True)) + module.add(labelDoneKRS) + else: + module.add(self.globalReadDo(kernel, globalReadMode2nd, tensorParameters2nd)) else: - module.add(self.globalReadDo(kernel, globalReadMode1st, tensorParameters1st)) - module.addComment1("Update M0 for DTLDS") - # skip wait for DTL if global load 1st is DTL - skip2ndWaitForDtl = kernel["DirectToLds%s"%tc1] - moduleTmp = self.directToLdsM0Update(kernel, 1, tensorParameters2nd, skip2ndWaitForDtl) - module.add(replaceHolder(moduleTmp, 0)) - module.addComment1("Tail global read %s"%tc2) - if tailLoopOpt2nd and (globalReadMode2nd == 2): - module.add(self.doTailLoopOpt(kernel, tensorParameters2nd)) - else: - module.add(self.globalReadDo(kernel, globalReadMode2nd, tensorParameters2nd)) + # KRS: If both tail global-read blocks (A/B) are eligible for KRS, do ONE runtime branch and + # share ONE set of labels for A/B. When sgprKRingShift==0, force both tail blocks down the + # original "load-only" path (no KRS_TAIL_OFFSET_* at all). 
+ # skip wait for DTL if global load 1st is DTL + skip2ndWaitForDtl = kernel["DirectToLds%s"%tc1] + krsTailBranchable1 = kernel["KRingShift"] and kernel["BufferLoad"] and tc1 in ("A", "B") + krsTailBranchable2 = kernel["KRingShift"] and kernel["BufferLoad"] and tc2 in ("A", "B") \ + and not (tailLoopOpt2nd and (globalReadMode2nd == 2)) + if krsTailBranchable1 and krsTailBranchable2: + labelNoKRS = Label(self.labels.getNameInc("KRS_tail_noop_AB"), "") + labelDoneKRS = Label(self.labels.getNameInc("KRS_tail_done_AB"), "") + labelNoKRS.comment = "KRS: tail no-KRS path for A/B (sgprKRingShift==0)" + labelDoneKRS.comment = "KRS: tail KRS branch join for A/B" + + module.add(SCmpEQU32(src0=sgpr("KRingShift"), src1=0, comment="KRS: sgprKRingShift==0 ?")) + module.add(SCBranchSCC1(labelName=labelNoKRS.getLabelName(), comment="KRS: take no-KRS tail loads (A+B)")) + + # KRS-enabled path: A then B + module.add(self.globalReadDo(kernel, globalReadMode1st, tensorParameters1st, krTailForceDisable=False)) + module.addComment1("Update M0 for DTLDS") + moduleTmp = self.directToLdsM0Update(kernel, 1, tensorParameters2nd, skip2ndWaitForDtl) + module.add(replaceHolder(moduleTmp, 0)) + module.addComment1("Tail global read %s"%tc2) + module.add(self.globalReadDo(kernel, globalReadMode2nd, tensorParameters2nd, krTailForceDisable=False)) + module.add(SBranch(labelName=labelDoneKRS.getLabelName(), comment="KRS: skip no-KRS tail loads (A+B)")) + + # no-KRS path: A then B (load-only) + module.add(labelNoKRS) + module.add(self.globalReadDo(kernel, globalReadMode1st, tensorParameters1st, krTailForceDisable=True)) + module.addComment1("Update M0 for DTLDS") + moduleTmp = self.directToLdsM0Update(kernel, 1, tensorParameters2nd, skip2ndWaitForDtl) + module.add(replaceHolder(moduleTmp, 0)) + module.addComment1("Tail global read %s"%tc2) + module.add(self.globalReadDo(kernel, globalReadMode2nd, tensorParameters2nd, krTailForceDisable=True)) + module.add(labelDoneKRS) + else: + # Fallback: keep per-tensor tail branching. 
+ if krsTailBranchable1: + labelNoKRS = Label(self.labels.getNameInc(f"KRS_tail_noop_{tc1}"), "") + labelDoneKRS = Label(self.labels.getNameInc(f"KRS_tail_done_{tc1}"), "") + labelNoKRS.comment = f"KRS: tail no-KRS path for {tc1} (sgprKRingShift==0)" + labelDoneKRS.comment = f"KRS: tail KRS branch join for {tc1}" + module.add(SCmpEQU32(src0=sgpr("KRingShift"), src1=0, comment="KRS: sgprKRingShift==0 ?")) + module.add(SCBranchSCC1(labelName=labelNoKRS.getLabelName(), comment="KRS: take no-KRS tail loads")) + module.add(self.globalReadDo(kernel, globalReadMode1st, tensorParameters1st, krTailForceDisable=False)) + module.add(SBranch(labelName=labelDoneKRS.getLabelName(), comment="KRS: skip no-KRS tail loads")) + module.add(labelNoKRS) + module.add(self.globalReadDo(kernel, globalReadMode1st, tensorParameters1st, krTailForceDisable=True)) + module.add(labelDoneKRS) + else: + module.add(self.globalReadDo(kernel, globalReadMode1st, tensorParameters1st)) + + module.addComment1("Update M0 for DTLDS") + moduleTmp = self.directToLdsM0Update(kernel, 1, tensorParameters2nd, skip2ndWaitForDtl) + module.add(replaceHolder(moduleTmp, 0)) + module.addComment1("Tail global read %s"%tc2) + if tailLoopOpt2nd and (globalReadMode2nd == 2): + module.add(self.doTailLoopOpt(kernel, tensorParameters2nd)) + else: + if kernel["KRingShift"] and kernel["BufferLoad"] and tc2 in ("A", "B"): + labelNoKRS = Label(self.labels.getNameInc(f"KRS_tail_noop_{tc2}"), "") + labelDoneKRS = Label(self.labels.getNameInc(f"KRS_tail_done_{tc2}"), "") + labelNoKRS.comment = f"KRS: tail no-KRS path for {tc2} (sgprKRingShift==0)" + labelDoneKRS.comment = f"KRS: tail KRS branch join for {tc2}" + module.add(SCmpEQU32(src0=sgpr("KRingShift"), src1=0, comment="KRS: sgprKRingShift==0 ?")) + module.add(SCBranchSCC1(labelName=labelNoKRS.getLabelName(), comment="KRS: take no-KRS tail loads")) + module.add(self.globalReadDo(kernel, globalReadMode2nd, tensorParameters2nd, krTailForceDisable=False)) + module.add(SBranch(labelName=labelDoneKRS.getLabelName(), comment="KRS: skip no-KRS tail loads")) + module.add(labelNoKRS) + module.add(self.globalReadDo(kernel, globalReadMode2nd, tensorParameters2nd, krTailForceDisable=True)) + module.add(labelDoneKRS) + else: + module.add(self.globalReadDo(kernel, globalReadMode2nd, tensorParameters2nd)) doA = False doB = False diff --git a/projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py b/projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py index 0ed7316dd8e7..192b2a6f09cf 100644 --- a/projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py +++ b/projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py @@ -1339,6 +1339,9 @@ def macroAndSet(self, kernel, tPA, tPB) -> Module: src0=destLo, src1=pendingOffset, \ comment="accumulate final pendingOffset")) + if kernel["KRingShift"] and tc in ("A", "B"): + macro.add(VAddU32(dst=vgpr("Addr+0", isMacro=True), src0="v[\\vgprAddr+0]", src1=sgpr("KRingShift"), + comment="KRS: KRingShift addr += shift")) if tP != None and kernel["BufferLoad"] and self.states.srdShiftLeft[tc]: macro.add(VAddU32(dst=vgpr("Addr+0", isMacro=True), \ @@ -1780,6 +1783,24 @@ def defineAndResources(self, kernel, tPA, tPB, tPM): if kernel["ProblemType"]["SupportUserArgs"]: moduleRegInit.add(SMovB32(dst=sgpr("ArgType"),src=sgpr(sgprArgType))) + # B address interleave (restricted) - compute runtime G once and reuse later. 
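A host-side sketch of the value initBInterleaveG later materializes in sgprBInterleaveG (lowbit via two's complement, capped by LVCB; the sizes in the example are illustrative):

# Illustrative only: the runtime G the kernel computes with scalar ops (SSub/SAnd/SCmp/SCSelect).
def b_interleave_g(size_j, mt1, lvcb):
    tiles1 = size_j // mt1              # host-side predicate guarantees size_j % mt1 == 0
    lowbit = tiles1 & -tiles1           # largest power-of-two factor of tiles1 (0 - tiles1, then AND)
    return min(lowbit, lvcb)            # G = min(lowbit(tiles1), LVCB)

assert b_interleave_g(3072, 192, 16) == 16   # tiles1=16 -> lowbit=16, capped at LVCB=16
assert b_interleave_g(2496, 192, 16) == 1    # tiles1=13 -> lowbit=1, G==1 (predicate rejects this case)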
+ if kernel["BAddrInterleave"]: + moduleRegInit.addComment1("Interleave: define SGPR and init runtime G once") + sgprG = self.defineSgprIdx("BInterleaveG", 1) + if "BInterleaveG" not in self.states.nonPostLoopSgpr: + self.states.nonPostLoopSgpr.append("BInterleaveG") + moduleRegInit.add(RegSet("s", "sgprBInterleaveG", sgprG)) + moduleRegInit.addModuleAsFlatItems(self.initBInterleaveG(kernel)) + + # K ring-shift (restricted) - compute per-WG shift once and reuse later. + if kernel["KRingShift"]: + moduleRegInit.addComment1("KRS: KRingShift define SGPR and init per-WG shift once") + sgprShift = self.defineSgprIdx("KRingShift", 1) + # Keep this SGPR live into post-loop (tail + store) - prevent endSummation from undefining it. + if "KRingShift" not in self.states.nonPostLoopSgpr: + self.states.nonPostLoopSgpr.append("KRingShift") + moduleRegInit.add(RegSet("s", "sgprKRingShift", sgprShift)) + self.sgprPool.checkIn(sgprPackedArgs) if kernel["StorePriorityOpt"]: @@ -2467,6 +2488,142 @@ def setTo(self, dstBase: Union[int, str], srcBase: Union[int, str]) -> Module: module.add(SMovB32(dst=sgpr(dstBase), src=sgpr(srcBase), comment="Set to %s"%(srcBase))) return module + def initBInterleaveG(self, kernel) -> Module: + """ + Compute interleave runtime G once and keep in fixed SGPRs for reuse. + - sgprBInterleaveG : G (power-of-two, capped by LVCB, or 1 when disabled) + """ + module = Module("initBInterleaveG") + module.addComment1("Interleave: init G once (reuse across groB/loadSRD/storeSRD/storeStride)") + + # Host-side predicate guarantees BAddrInterleave solutions only run when the computed G>1, + # so we don't need to initialize BInterleaveG to 1 here. + + mt1 = kernel["MacroTile1"] + lvcb = kernel["LVCB"] + + with self.allocTmpSgpr(6) as tS: + tmpDiv = tS.idx + tmpDivRes = ContinuousRegister(tmpDiv, 2) + sTiles = tmpDiv + 2 + sRem = tmpDiv + 3 + sNeg = tmpDiv + 4 + sLow = tmpDiv + 5 + + # Host-side predicate guarantees SizeJ % MT1 == 0 for BAddrInterleave solutions, + # so we only need the quotient tiles1 = SizeJ / MT1 (no remainder/guard). + module.add(scalarStaticDivideAndRemainder(qReg=sTiles, rReg=sRem, dReg="SizeJ", divisor=mt1, tmpSgprRes=tmpDivRes, doRemainder=0)) + + module.add(SSubU32(dst=sgpr(sNeg), src0=0, src1=sgpr(sTiles), comment="-tiles1")) + module.add(SAndB32(dst=sgpr(sLow), src0=sgpr(sTiles), src1=sgpr(sNeg), comment="lowbit(tiles1)")) + module.add(SCmpLeU32(src0=sgpr(sLow), src1=lvcb, comment="lowbit <= LVCB ?")) + module.add(SCSelectB32(dst=sgpr("BInterleaveG"), src0=sgpr(sLow), src1=lvcb, comment="G = min(lowbit,LVCB)")) + return module + + def initKRingShift(self, kernel) -> Module: + """ + Compute per-workgroup K ring-shift amount (in elements) for cacheline congruence. + shift = (-baseOffsetElems) mod cacheLineElements + where: + cacheLineElements = vL1DCacheLineBytes / bpe + + Stored in: + sgprKRingShift (elements) + + NOTE: + Here baseOffsetElems is derived from sgprSrdB (the per-WG B SRD base, i.e. AddressB + tileStart), + so it already reflects the WG-dependent starting address. Do not add an extra WG term again. + """ + module = Module("initKRingShift") + module.addComment1("KRS: KRingShift init per-WG K shift (elements) for cacheline congruence") + + # vL1DCacheLineBytes is provided by rocisa archCaps (see rocisa/include/hardware_caps.hpp). 
+ cacheLineBytes = int(self.states.archCaps["vL1DCacheLineBytes"]) + if cacheLineBytes <= 0: + module.addComment0("KRingShift: no arch cacheline info; keep shift=0") + module.add(SMovB32(dst=sgpr("KRingShift"), src=0, comment="KRS: disabled (shift=0)")) + return module + + # bpe of B elements (bytes/element) + # NOTE: Use ProblemType datatype rather than a state attribute (bpeB isn't a stable StateValues field). + bpe = int(kernel["ProblemType"]["DataTypeB"].numBytes()) + if bpe <= 0 or (cacheLineBytes % bpe) != 0: + module.addComment0("KRingShift: cacheLineBytes%bpe!=0; keep shift=0") + module.add(SMovB32(dst=sgpr("KRingShift"), src=0, comment="KRS: disabled (shift=0)")) + return module + + cacheLineElements = cacheLineBytes // bpe + # Require power-of-two for cheap modulo. + if cacheLineElements & (cacheLineElements - 1): + module.addComment0("KRingShift: cacheLineElements not pow2; keep shift=0") + module.add(SMovB32(dst=sgpr("KRingShift"), src=0, comment="KRS: disabled (shift=0)")) + return module + + mask = cacheLineElements - 1 + + labelDone = Label(self.labels.getNameInc("KRingShift_done"), "") + labelDone.comment = "KRS: KRingShift done" + with self.allocTmpSgpr(2) as tS: + sTmp = tS.idx + sBase = tS.idx + 1 + + # Include current B tile base address (SRD base) misalignment in the shift. + # We want: + # (baseOffsetElems + shift) % cacheLineElements == 0 + # => shift = -baseOffsetElems mod cacheLineElements + baseMaskBytes = cacheLineBytes - 1 + # NOTE: AddressB is pre-padded (see "pre-pad to make room for possible pointer shift"), + # so SrdB.base includes that subtraction. Add the pre-pad back to recover the true base. + prePadBytes = int(self.states.srdShiftLeft["B"]) * int(kernel["ProblemType"]["DataTypeB"].numBytes()) + if prePadBytes: + module.add(SAddU32(dst=sgpr(sBase), src0=sgpr("SrdB+0"), src1=prePadBytes, comment="KRS: unpad B base (lo)")) + else: + module.add(SMovB32(dst=sgpr(sBase), src=sgpr("SrdB+0"), comment="KRS: B base (lo)")) + module.add(SAndB32(dst=sgpr(sBase), src0=sgpr(sBase), src1=baseMaskBytes, + comment=f"KRS: baseBytes = (SrdB.base + prePad) & {baseMaskBytes}")) + if bpe > 0 and (bpe & (bpe - 1)) == 0: + module.add(SLShiftRightB32(dst=sgpr(sBase), src=sgpr(sBase), shiftHex=hex(log2(bpe)), + comment="KRS: baseOffsetElems = baseBytes >> log2(bpe)")) + else: + # Should not happen for supported datatypes, but keep behavior safe. + module.addComment0("KRingShift: bpe not pow2; keep baseOffsetElems=0") + module.add(SMovB32(dst=sgpr(sBase), src=0, comment="KRS: baseOffsetElems = 0")) + + # tmp = baseOffsetElems & mask + module.add(SAndB32(dst=sgpr(sTmp), src0=sgpr(sBase), src1=mask, comment="KRS: baseOffsetElems mod cacheLineElements")) + + # shift = (-tmp) & mask + module.add(SSubU32(dst=sgpr("KRingShift"), src0=0, src1=sgpr(sTmp), comment="KRS: shift = -tmp")) + module.add(SAndB32(dst=sgpr("KRingShift"), src0=sgpr("KRingShift"), src1=mask, comment="KRS: shift %= cacheLineElements")) + + # If sgprKRingShift is not aligned to GRVW(A/B), disable KRS (set shift=0). + # Requested behavior: + # if (KRingShift % GRVWA != 0) or (KRingShift % GRVWB != 0) then KRingShift = 0 + grvwA = int(kernel["GlobalReadVectorWidthA"]) + grvwB = int(kernel["GlobalReadVectorWidthB"]) + maskA = (grvwA - 1) if grvwA > 1 else 0 + maskB = (grvwB - 1) if grvwB > 1 else 0 + if maskA or maskB: + # Fast path requires power-of-two GRVWs for cheap modulo via AND-mask. 
+ if ((grvwA & (grvwA - 1)) != 0) or ((grvwB & (grvwB - 1)) != 0): + module.addComment0("KRingShift: GRVW not pow2; disable shift=0") + module.add(SMovB32(dst=sgpr("KRingShift"), src=0, comment="KRS: disabled (shift=0)")) + else: + # sTmp := (shift & (grvwA-1)) | (shift & (grvwB-1)) + if maskA: + module.add(SAndB32(dst=sgpr(sTmp), src0=sgpr("KRingShift"), src1=maskA, comment=f"KRS: shift % GRVWA (mask=0x{maskA:x})")) + else: + module.add(SMovB32(dst=sgpr(sTmp), src=0, comment="KRS: GRVWA==1 => aligned")) + if maskB: + module.add(SAndB32(dst=sgpr(sBase), src0=sgpr("KRingShift"), src1=maskB, comment=f"KRS: shift % GRVWB (mask=0x{maskB:x})")) + module.add(SOrB32(dst=sgpr(sTmp), src0=sgpr(sTmp), src1=sgpr(sBase), comment="KRS: (shift%GRVWA) | (shift%GRVWB)")) + module.add(SCmpEQU32(src0=sgpr(sTmp), src1=0, comment="KRS: (shift%GRVWA==0 && shift%GRVWB==0) ?")) + module.add(SCSelectB32(dst=sgpr("KRingShift"), src0=sgpr("KRingShift"), src1=0, + comment="KRS: if misaligned, disable shift=0")) + + module.add(labelDone) + return module + ############################################################################## # Global Read Addresses: Tile Offsets A/B ############################################################################## @@ -2476,6 +2633,7 @@ def graTileOffsets(self, kernel, tP, margin=-1): tP["vgprPackedOffsets"] = None tP["vgprTileOffsetsCheckOut"] = False tP["numVgprTileOffsets"] = 0 + skipGroTileOffsetsLoop = False if kernel["_UseSgprForGRO"]: # Let the vgprTileOffsets checkin handle tReg later since these are same vgpr tP["vgprTileOffsets"] = tP["gpr"]["tReg"] @@ -2503,7 +2661,39 @@ def graTileOffsets(self, kernel, tP, margin=-1): module.add(VLShiftLeftB32(dst=vgpr(v), shiftHex=hex(log2(margin)), src=vgpr(v), comment="gro%s%s_%u *= %d"%(tP["tensorChar"], tP["tileChar"], 0, margin))) else: if not tP["isSwizzled"]: - module.add(VMovB32(dst=vgpr(v), src=vgpr(tP["gpr"]["tReg"]), comment="gro%s%s_%u"%(tP["tensorChar"], tP["tileChar"], 0) )) + # B address interleave (restricted): non-contiguous tile columns with runtime G based on (SizeJ/MT1). + useBInterleave = bool(kernel["BAddrInterleave"]) + useBInterleave = useBInterleave and tP["isB"] and (not tP["tlu"]) + + if useBInterleave: + # Host-side predicate guarantees computed G>1 for BAddrInterleave solutions. + # G = min(LVCB, lowbit(SizeJ/MT1)), lowbit(x)=x&-x (largest power-of-two divisor). + # Addressing: + # baseCol = (wg1/G)*(MT1*G) + (wg1%G) + # groB1J(r,l) = G*(r + l*LSPB) + # + # This partitions columns across workgroups without changing host launch (requires SizeJ multiple of MT1). 
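A small sketch of the per-load column offsets the vector code below produces (LSPB, nrt and G values are illustrative):

# Illustrative only: the interleaved B tile offsets groB1J(r, l) = G * (r + l * LSPB),
# built below as groB1J_0 = r*G followed by repeated adds of step = G*LSPB.
def gro_b1j(r, g, lspb, nrt):
    offsets = [r * g]                            # VMulLOU32: groB1J_0 = r*G
    for _ in range(1, nrt):
        offsets.append(offsets[-1] + g * lspb)   # VAddCOU32: += step (G*LSPB)
    return offsets

# Example: thread row r=3, G=4, LSPB=16, nrt=3 loads per thread
assert gro_b1j(3, 4, 16, 3) == [12, 76, 140]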
+ lspb = kernel[tP["lsp"]] + + # Vector: groB1J_0 = r*G, step = G*LSPB + gV = self.vgprPool.checkOut(1) + stepV = self.vgprPool.checkOut(1) + module.add(VMovB32(dst=vgpr(gV), src=sgpr("BInterleaveG"), comment="G")) + if (lspb & (lspb - 1)) == 0: + module.add(VLShiftLeftB32(dst=vgpr(stepV), shiftHex=hex(log2(lspb)), src=vgpr(gV), comment="step=G*LSPB")) + else: + module.add(VMulLOU32(dst=vgpr(stepV), src0=vgpr(gV), src1=hex(lspb), comment="step=G*LSPB")) + + module.add(VMulLOU32(dst=vgpr(v), src0=vgpr(tP["gpr"]["tReg"]), src1=vgpr(gV), comment="groB1J_0 = r*G")) + for l in range(1, tP["nrt"]): + module.add(VAddCOU32(dst=vgpr(v+l), dst1=VCC(), src0=vgpr(stepV), src1=vgpr(v+l-1), + comment="groB1J_%u += step(G*LSPB)" % l)) + self.vgprPool.checkIn(stepV) + self.vgprPool.checkIn(gV) + skipGroTileOffsetsLoop = True + + else: + module.add(VMovB32(dst=vgpr(v), src=vgpr(tP["gpr"]["tReg"]), comment="gro%s%s_%u"%(tP["tensorChar"], tP["tileChar"], 0) )) else: lsu = kernel["LocalSplitU"] # localSplitU if tP["isA"]: @@ -2561,23 +2751,25 @@ def graTileOffsets(self, kernel, tP, margin=-1): comment="swzBlkVWOffset = swzBlkWvGSize - laneSize * (VW - 1)")) module.add(VMovB32(dst=vgpr(swzBlkVWSizeVgpr), src=sgpr(swzBlkVWSizeSgpr)) ) - for l in range(1, tP["nrt"]): - strideValue = stride - if strideInterleave and (l & strideMask) != 0: - strideValue = 1 - if not tP["isSwizzled"]: - module.add(VAddCOU32(dst=vgpr(v+l), dst1=VCC(), src0=strideValue, \ - src1=vgpr(v+l-1), comment="gro%s%s_%u += %s"%(tP["tensorChar"], tP["tileChar"], l, strideIdx) )) - # swizzle - else: - # VW > 1 - if (strideInterleave and (l & strideMask) != 0): - module.add(VAddCOU32(dst=vgpr(v+l), dst1=VCC(), src0=laneSize, \ - src1=vgpr(v+l-1), comment="SWZ-%s: gro%s%s_%u"%(tc, tP["tensorChar"], tP["tileChar"], l) )) - # VW == 1 + # If groB1J_* was emitted above (interleave path), skip this generic loop to avoid duplicate emissions. 
+ if not skipGroTileOffsetsLoop: + for l in range(1, tP["nrt"]): + strideValue = stride + if strideInterleave and (l & strideMask) != 0: + strideValue = 1 + if not tP["isSwizzled"]: + module.add(VAddCOU32(dst=vgpr(v+l), dst1=VCC(), src0=strideValue, \ + src1=vgpr(v+l-1), comment="gro%s%s_%u += %s"%(tP["tensorChar"], tP["tileChar"], l, strideIdx) )) + # swizzle else: - module.add(VAddCOU32(dst=vgpr(v+l), dst1=VCC(), src0=vgpr(swzBlkVWSizeVgpr), \ - src1=vgpr(v+l-1), comment="SWZ-%s: gro%s%s_%u"%(tc, tP["tensorChar"], tP["tileChar"], l) )) + # VW > 1 + if (strideInterleave and (l & strideMask) != 0): + module.add(VAddCOU32(dst=vgpr(v+l), dst1=VCC(), src0=laneSize, \ + src1=vgpr(v+l-1), comment="SWZ-%s: gro%s%s_%u"%(tc, tP["tensorChar"], tP["tileChar"], l) )) + # VW == 1 + else: + module.add(VAddCOU32(dst=vgpr(v+l), dst1=VCC(), src0=vgpr(swzBlkVWSizeVgpr), \ + src1=vgpr(v+l-1), comment="SWZ-%s: gro%s%s_%u"%(tc, tP["tensorChar"], tP["tileChar"], l) )) if tP["isSwizzled"]: self.vgprPool.checkIn(swzBlkVWSizeVgpr) @@ -2896,6 +3088,9 @@ def graFinalOffsets(self, kernel, tP): graIdx = 0 swapPerpPara = (((tP["isA"] or tP["isB"]) and kernel["DirectToVgpr%s"%tc]) and (not tP["tlu"]) and tP["nrp"] > 1) + if kernel["KRingShift"] and tc == "A": + module.addModuleAsFlatItems(self.initKRingShift(kernel)) + # both UseSgprForGRO and DTVA/B are enabled if ((tP["isA"] or tP["isB"]) and kernel["DirectToVgpr%s"%tc]) and kernel["_UseSgprForGRO"]: if tP["tlu"]: @@ -3417,6 +3612,50 @@ def computeLoadSrd(self, kernel, tP, tc, indices, bpe): # This is guaranteed to fit in 32-bit since the WG*MT is a number of elements in some unsigned direction: module.addModuleAsFlatItems(self.s_mul_u64_u32(sgpr(tileStart+0), sgpr(tileStart+1), sgpr(tP["wg"]), kernel[tP["mt"]], comment="WorkGroup[01] * MT")) + + # Interleave (restricted): for B (tlu==False), overwrite wg1*MT1 with baseCol: + # baseCol = (wg1/G)*(MT1*G) + (wg1%G), G=min(lowbit(SizeJ/MT1), LVCB) + # Host-side predicate guarantees computed G>1 for BAddrInterleave solutions. 
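A sketch of the column mapping the scalar code below implements for the B SRD base, checked as a permutation of the output columns (N, MT1, LVCB and G are illustrative):

# Illustrative only: baseCol = (wg1//G)*(MT1*G) + (wg1%G); local column c then maps to baseCol + G*c.
def interleaved_col(wg1, c, mt1, g):
    base_col = (wg1 // g) * (mt1 * g) + (wg1 % g)
    return base_col + g * c

# Example: N=1536, MT1=192 -> tiles1=8; with LVCB=4 the runtime G is min(8, 4) = 4.
n, mt1, g = 1536, 192, 4
cols = sorted(interleaved_col(wg1, c, mt1, g) for wg1 in range(n // mt1) for c in range(mt1))
assert cols == list(range(n))   # every output column is covered exactly once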
+ if kernel["BAddrInterleave"] and tP["isB"] and (not tP["tlu"]) and (tP["wg"] == "WorkGroup1"): + labelCase16 = Label(self.labels.getNameInc("BInterleave_loadSrd_case16"), "") + labelCase8 = Label(self.labels.getNameInc("BInterleave_loadSrd_case8"), "") + labelCase4 = Label(self.labels.getNameInc("BInterleave_loadSrd_case4"), "") + labelCase2 = Label(self.labels.getNameInc("BInterleave_loadSrd_case2"), "") + labelAfterShift = Label(self.labels.getNameInc("BInterleave_loadSrd_afterShift"), "") + + # mask = G-1 + module.add(SSubU32(dst=sgpr(stmp+0), src0=sgpr("BInterleaveG"), src1=1, comment="mask=G-1")) + module.add(SAndB32(dst=sgpr(tileStart+0), src0=sgpr("WorkGroup1"), src1=sgpr(stmp+0), comment="phase = wg1&(G-1)")) + + # super = wg1 >> log2(G) (G in {2,4,8,16}) + module.add(SMovB32(dst=sgpr(stmp+1), src=sgpr("WorkGroup1"), comment="super = wg1")) + module.add(SCmpEQU32(src0=sgpr("BInterleaveG"), src1=16)) + module.add(SCBranchSCC1(labelName=labelCase16.getLabelName())) + module.add(SCmpEQU32(src0=sgpr("BInterleaveG"), src1=8)) + module.add(SCBranchSCC1(labelName=labelCase8.getLabelName())) + module.add(SCmpEQU32(src0=sgpr("BInterleaveG"), src1=4)) + module.add(SCBranchSCC1(labelName=labelCase4.getLabelName())) + module.add(SCmpEQU32(src0=sgpr("BInterleaveG"), src1=2)) + module.add(SCBranchSCC1(labelName=labelCase2.getLabelName())) + module.add(SBranch(labelName=labelAfterShift.getLabelName())) + module.add(labelCase16) + module.add(SLShiftRightB32(dst=sgpr(stmp+1), src=sgpr(stmp+1), shiftHex=hex(4))) + module.add(SBranch(labelName=labelAfterShift.getLabelName())) + module.add(labelCase8) + module.add(SLShiftRightB32(dst=sgpr(stmp+1), src=sgpr(stmp+1), shiftHex=hex(3))) + module.add(SBranch(labelName=labelAfterShift.getLabelName())) + module.add(labelCase4) + module.add(SLShiftRightB32(dst=sgpr(stmp+1), src=sgpr(stmp+1), shiftHex=hex(2))) + module.add(SBranch(labelName=labelAfterShift.getLabelName())) + module.add(labelCase2) + module.add(SLShiftRightB32(dst=sgpr(stmp+1), src=sgpr(stmp+1), shiftHex=hex(1))) + module.add(labelAfterShift) + + # baseCol = phase + super*(MT1*G) + module.add(SMulI32(dst=sgpr(stmp+0), src0=sgpr("BInterleaveG"), src1="MT1", comment="MT1*G")) + module.add(SMulI32(dst=sgpr(stmp+0), src0=sgpr(stmp+0), src1=sgpr(stmp+1), comment="super*(MT1*G)")) + module.add(SAddU32(dst=sgpr(tileStart+0), src0=sgpr(tileStart+0), src1=sgpr(stmp+0), comment="baseCol")) + module.add(SMovB32(dst=sgpr(tileStart+1), src=0)) strideF = self.strideRef(tc, tP['tileIdx']) if not self.isConstUnitStride(strideF): module.addModuleAsFlatItems(self.s_mul_u64_u32(sgpr(tileStart), sgpr(tileStart+1), sgpr(tileStart+0), \ @@ -4647,6 +4886,28 @@ def removeStagger(self, kernel, tP): imod.add(SSubU32(dst=sgpr(tmp), src0=sgpr(tmp), src1=sgpr("WrapU%s"%tc), comment="S - WrapU")) imod.add(SSubBU32(dst=sgpr(tmp+1), src0=sgpr(tmp+1), src1=sgpr("WrapU%s+1"%(tc)), comment="S - WrapU")) + # KRingShift: fold the "rebias by mainLoopBytes" into the existing (S - WrapU) increment. + # This removes the need for a separate SRD -= mainLoopBytes / limit += mainLoopBytes block. + # tmp:tmp+1 holds the 64-bit SRD byte increment (S - WrapU). + if kernel["KRingShift"] and tc in ("A", "B") and kernel["BufferLoad"]: + depthU = int(kernel["DepthU"]) + sizeK = "SizesSum+%u" % self.states.unrollIdx + bpe = int(kernel["ProblemType"]["DataType%s"%tc].numBytes()) + maskDU = (~(depthU - 1)) & 0xFFFFFFFF + # Use tmp+2 as scratch for mainLoopBytes (safe: tmpIncSparse is recomputed later if needed). 
+ imod.add(SAndB32(dst=sgpr(tmp+2), src0=sgpr(sizeK), src1=hex(maskDU), comment="KRS: mainLoopElems")) + if bpe > 0 and (bpe & (bpe - 1)) == 0: + imod.add(SLShiftLeftB32(dst=sgpr(tmp+2), src=sgpr(tmp+2), shiftHex=hex(log2(bpe)), comment="KRS: mainLoopBytes")) + else: + imod.add(SMulI32(dst=sgpr(tmp+2), src0=sgpr(tmp+2), src1=bpe, comment="KRS: mainLoopBytes")) + # If KRS is disabled at runtime (sgprKRingShift==0), do NOT apply any mainLoopBytes rebias. + # Keep the sequence branchless by cselect'ing the delta to 0. + imod.add(SCmpEQU32(src0=sgpr("KRingShift"), src1=0, comment="KRS: sgprKRingShift==0 ?")) + imod.add(SCSelectB32(dst=sgpr(tmp+2), src0=0, src1=sgpr(tmp+2), comment="KRS: mainLoopBytesDelta (0 if disabled)")) + # incBytes -= mainLoopBytes (64-bit subtract with borrow into hi) + imod.add(SSubU32(dst=sgpr(tmp), src0=sgpr(tmp), src1=sgpr(tmp+2), comment="KRS: incBytes -= mainLoopBytes (lo)")) + imod.add(SSubBU32(dst=sgpr(tmp+1), src0=sgpr(tmp+1), src1=0, comment="KRS: incBytes -= mainLoopBytes (hi)")) + imod.add(self.incrementSrd(tP, sgpr(tmp), sgpr(tmp+1))) if kernel["ProblemType"]["Sparse"] and \ @@ -4985,6 +5246,26 @@ def tailLoopGlobalRead(self, kernel, tPA, tPB, doA, doB): (lastRegTag in tagList): imod.add(self.undefineSgpr(regTag)) + # KRS: release symbol aliases for the temporary KRS registers. + # These are defined via RegSet/TextBlock (not via self.sgprs pool), so we only UNDEF the names here. + if kernel["KRingShift"] and kernel["BufferLoad"]: + imod.addComment1("KRS: release KRS temp symbol aliases") + for n in ( + "sgprKrsNumChunk", + "sgprKrsKRingShiftBytes", + "sgprKrsTailStartChunk", + "sgprKrsOobS", + "sgprKrsMainLoopBytes", + "sgprKrsMaskValid", + "sgprKrsMaskHead", + "vgprKrsChunk", + "vgprKrsOobV", + "vgprKrsScratch", + ): + imod.add(ValueSet(name=n, value="UNDEF", format=-1)) + # KRS: also release the real sgprKRingShift here (proper pool + ValueSet handling). + imod.add(self.undefineSgpr("KRingShift")) + loadALabel = Label(label="LoadA", comment="") loadBLabel = Label(label="LoadB", comment="") mergeALabel = Label(label="MergeA", comment="") @@ -8558,7 +8839,7 @@ def directToLdsM0Update(self, kernel, mode, tP, skipWait = False): # Global Read: Do It A/B ############################################################################## def globalReadDo(self, kernel, mode, tP, unrollLoopIdx=-1, g2lBufIdx=0, \ - doTailOpt = 0, optParams = None): + doTailOpt = 0, optParams = None, krTailForceDisable=False): tc = tP["tensorChar"] problemType = self.states.kernel["ProblemType"] imod = StructuredModule("globalReadDo%s_%u"%(tc,mode)) @@ -8629,6 +8910,80 @@ def globalReadBody(tP): isLds = True if kernel["DirectToLds%s"%tc] else False isTr = (tc == "A" or tc == "B") and kernel["enableGLTr%s"%tc] + # KRingShift: in tail loop, patch each vgprGlobalReadOffset{A,B}+i just-in-time right before + # its corresponding buffer_load. This allows interleaving apply/load and avoids a big apply-only block. + krTailJIT = (not krTailForceDisable and self.states.inTailLoop and kernel["KRingShift"] and kernel["BufferLoad"] + and tc in ("A", "B") and not kernel["_UseSgprForGRO"]) + if krTailJIT: + # Must be even-aligned since macros use b64 SGPR pairs. 
+ krTmpS = self.sgprPool.checkOutAligned(10, 2, f"krTailJITTmpS{tc}", preventOverflow=False) + krTmpV = self.vgprPool.checkOutAligned(4, 2, f"krTailJITTmpV{tc}", self.states.preventVgprOverflowDuringNewTile) + imod.header.addComment0(f"KRS: tail JIT setup for {tc} offsets (setup once; apply before each load)") + # Emit symbolic register aliases for readability (outside macro bodies). + # Define symbolic registers in the standard Tensile style (like sgprKRingShift). + imod.header.add(RegSet("s", "sgprKrsNumChunk", krTmpS + 0)) + imod.header.add(RegSet("s", "sgprKrsKRingShiftBytes", krTmpS + 1)) + imod.header.add(RegSet("s", "sgprKrsTailStartChunk", krTmpS + 2)) + imod.header.add(RegSet("s", "sgprKrsOobS", krTmpS + 3)) + imod.header.add(RegSet("s", "sgprKrsMainLoopBytes", krTmpS + 4)) + imod.header.add(RegSet("s", "sgprKrsMaskValid", krTmpS + 6)) + imod.header.add(RegSet("s", "sgprKrsMaskHead", krTmpS + 8)) + + imod.header.add(RegSet("v", "vgprKrsChunk", krTmpV + 0)) + imod.header.add(RegSet("v", "vgprKrsOobV", krTmpV + 1)) + imod.header.add(RegSet("v", "vgprKrsScratch", krTmpV + 3)) + + # KRS: inline SETUP (no macro) for this tc. + imod.header.addComment0(f"KRS: inline tail-offset setup for {tc} offsets") + + depthU = int(kernel["DepthU"]) + maskDU = (~(depthU - 1)) & 0xFFFFFFFF + bpeBytes = int(kernel["ProblemType"][f"DataType{tc}"].numBytes()) + bpeShift = int(log2(bpeBytes)) + chunkElems = int(kernel[f"GlobalReadVectorWidth{tc}"]) + if (chunkElems & (chunkElems - 1)) != 0: + raise RuntimeError(f"KRS: GlobalReadVectorWidth{tc} must be power-of-two, got {chunkElems}") + chunkElemShift = int(log2(chunkElems)) + ceilBias = chunkElems - 1 + if ceilBias != 0: + imod.header.add(SAddU32(dst=sgpr("KrsNumChunk", 1, False), src0=sgpr("LoopCounterL"), src1=ceilBias, + comment=f"KRS: numChunk = ceil(LoopCounterL/{chunkElems}) ; bias=+{ceilBias} elems")) + imod.header.add(SLShiftRightB32(dst=sgpr("KrsNumChunk", 1, False), shiftHex=hex(chunkElemShift), src=sgpr("KrsNumChunk", 1, False), + comment=f"KRS: numChunk >>= {chunkElemShift} (chunkElems={chunkElems})")) + else: + imod.header.add(SMovB32(dst=sgpr("KrsNumChunk", 1, False), src=sgpr("LoopCounterL"), + comment="KRS: numChunk = LoopCounterL (chunkElems==1)")) + imod.header.add(VAndB32(dst=vgpr("KrsChunk", 1, False), src0=0x0F, src1=vgpr("Serial"), + comment="KRS: chunk = vgprSerial & 0x0F")) + imod.header.add(SAddU32(dst=sgpr("KrsOobS", 1, False), src0=sgpr(f"Srd{tc}+2"), src1=1, + comment=f"KRS: oobS = Srd{tc}.limit+1 (OOB sentinel)")) + imod.header.add(VMovB32(dst=vgpr("KrsOobV", 1, False), src=sgpr("KrsOobS", 1, False), + comment="KRS: oobV = oobS")) + # tailStartChunk = ceil(KRingShift / chunkElems) (for non-divisible KRingShift) + if ceilBias != 0: + imod.header.add(SAddU32(dst=sgpr("KrsTailStartChunk", 1, False), src0=sgpr("KRingShift"), src1=ceilBias, + comment=f"KRS: tailStartChunk = ceil(KRingShift/{chunkElems}) ; bias=+{ceilBias} elems")) + imod.header.add(SLShiftRightB32(dst=sgpr("KrsTailStartChunk", 1, False), shiftHex=hex(chunkElemShift), src=sgpr("KrsTailStartChunk", 1, False), + comment=f"KRS: tailStartChunk >>= {chunkElemShift} (chunkElems={chunkElems})")) + else: + imod.header.add(SMovB32(dst=sgpr("KrsTailStartChunk", 1, False), src=sgpr("KRingShift"), + comment="KRS: tailStartChunk = KRingShift (chunkElems==1)")) + imod.header.add(SAndB32(dst=sgpr("KrsMainLoopBytes", 1, False), src0=sgpr(f"SizesSum+{self.states.unrollIdx}"), src1=f"0x{maskDU:08x}", + comment="KRS: mainLoopElems = SizeK & ~(DepthU-1)")) + 
imod.header.add(SLShiftLeftB32(dst=sgpr("KrsMainLoopBytes", 1, False), shiftHex=hex(bpeShift), src=sgpr("KrsMainLoopBytes", 1, False), + comment="KRS: mainLoopBytes = mainLoopElems << log2(bpe)")) + imod.header.add(SLShiftLeftB32(dst=sgpr("KrsKRingShiftBytes", 1, False), shiftHex=hex(bpeShift), src=sgpr("KRingShift"), + comment=f"KRS: KRingShiftBytes = KRingShift * bpe (bpeBytes={bpeBytes})")) + imod.header.add(VCmpLtU32(dst=sgpr("KrsMaskValid", 2, False), src0=vgpr("KrsChunk", 1, False), src1=sgpr("KrsNumChunk", 1, False), + comment="KRS: maskValid = (chunk < numChunk)")) + imod.header.add(VCmpLtU32(dst=sgpr("KrsMaskHead", 2, False), src0=vgpr("KrsChunk", 1, False), src1=sgpr("KrsTailStartChunk", 1, False), + comment="KRS: maskLT = (chunk < tailStartChunk)")) + imod.header.add(SAndB64(dst=sgpr("KrsMaskHead", 2, False), src0=sgpr("KrsMaskHead", 2, False), src1=sgpr("KrsMaskValid", 2, False), + comment="KRS: maskHead = maskLT & maskValid")) + else: + krTmpS = None + krTmpV = None + directToLdsLoads = 0 instOffset = 0 prevLdsOffset = 0 @@ -8743,6 +9098,23 @@ def globalReadBody(tP): self.globalread_gpr_record.b.offset.append(soffset) useBuffer = not isTr + + # KRS: just-in-time patch for this offset register before issuing the load. + if krTailJIT: + loadModule.addComment0(f"KRS: inline tail-offset apply before load for {tc} offsets") + + # offsetVgpr is "GlobalReadOffset{tc}+{graIdx}" in this path. + loadModule.add(VSubU32(dst=vgpr(offsetVgpr), src0=vgpr(offsetVgpr), src1=sgpr("KrsKRingShiftBytes", 1, False), + comment="KRS: offset -= KRingShiftBytes")) + loadModule.add(SMovB64(dst=VCC(), src=sgpr("KrsMaskValid", 2, False), comment="KRS: vcc = maskValid")) + loadModule.add(VCndMaskB32(dst=vgpr(offsetVgpr), src0=vgpr("KrsOobV", 1, False), src1=vgpr(offsetVgpr), src2=VCC(), + comment="KRS: set OOB lanes to sentinel")) + loadModule.add(SMovB64(dst=VCC(), src=sgpr("KrsMaskHead", 2, False), comment="KRS: vcc = maskHead")) + loadModule.add(VAddU32(dst=vgpr("KrsScratch", 1, False), src0=vgpr(offsetVgpr), src1=sgpr("KrsMainLoopBytes", 1, False), + comment="KRS: offsetPlusMainLoopBytes = offset + mainLoopBytes")) + loadModule.add(VCndMaskB32(dst=vgpr(offsetVgpr), src0=vgpr("KrsScratch", 1, False), src1=vgpr(offsetVgpr), src2=VCC(), + comment="KRS: head keep offset; tail apply +mainLoopBytes")) + loadModule.add( self.chooseGlobalRead(useBuffer, \ bpl, destVgpr=destVgpr, \ addr0=vgpr(offsetVgpr), addr1=sgpr("Srd%s"%tc, 2 if isTr else 4), \ @@ -8770,6 +9142,11 @@ def globalReadBody(tP): glc=isGlc, slc=isSlc, nt=isNT, lds=isLds, \ hi16=(kernel["ProblemType"]["DataType"].isHalf() or kernel["ProblemType"]["DataType"].isBFloat16()) and loopCnt%2==1, \ comment="G -> Reg %u_%u_%u_%u"%(para, sPara, perp, sPerp ))) + # Release JIT temp regs after emitting all loads for this tensor. + if krTailJIT: + self.vgprPool.checkIn(krTmpV) + self.sgprPool.checkIn(krTmpS) + if kernel["ProblemType"]["Sparse"] and kernel["DirectToVgprSparseMetadata"]: if tP["is_sparse"]: @@ -10158,19 +10535,62 @@ def computeStoreSrdStart(self, kernel, srdTcList: list, sgprBpeList = [], useSiz else: useSize = [False for _ in srdTcList] - with self.allocTmpSgpr(3) as tmpSgprInfo: + with self.allocTmpSgpr(9) as tmpSgprInfo: tmpS0 = tmpSgprInfo.idx tmpS1 = tmpS0+1 wgMT1 = tmpS0+2 - - # Compute and save wg1*MT1 - the element offset that is top of the macro-tile in output space + sRem = tmpS0+3 + sNeg = tmpS0+4 + sLow = tmpS0+5 + sG = tmpS0+6 + sMask = tmpS0+7 + sTiles = tmpS0+8 + + # Compute and save the element offset for Index1 in output space. 
+ # Default is wg1*MT1. Interleave (if enabled) replaces this with baseCol: + # baseCol = (wg1/G)*(MT1*G) + (wg1%G), G=min(lowbit(SizeJ/MT1), LVCB) assert kernel["BufferStore"] module.addSpaceLine() - module.add(SMulI32( - dst=sgpr(wgMT1), \ - src0="MT1", \ - src1=sgpr("WorkGroup1"), \ - comment="<- wg1*MT1")) + module.add(SMulI32(dst=sgpr(wgMT1), src0="MT1", src1=sgpr("WorkGroup1"), comment="<- wg1*MT1")) + + if kernel["BAddrInterleave"]: + labelCase16 = Label(self.labels.getNameInc("BInterleave_storeSrd_case16"), "") + labelCase8 = Label(self.labels.getNameInc("BInterleave_storeSrd_case8"), "") + labelCase4 = Label(self.labels.getNameInc("BInterleave_storeSrd_case4"), "") + labelCase2 = Label(self.labels.getNameInc("BInterleave_storeSrd_case2"), "") + labelAfterShift = Label(self.labels.getNameInc("BInterleave_storeSrd_afterShift"), "") + + module.add(SSubU32(dst=sgpr(sMask), src0=sgpr("BInterleaveG"), src1=1, comment="mask=G-1")) + module.add(SAndB32(dst=sgpr(tmpS1), src0=sgpr("WorkGroup1"), src1=sgpr(sMask), comment="phase = wg1 & (G-1)")) + + # super = wg1 >> log2(G) (G in {2,4,8,16}) + module.add(SMovB32(dst=sgpr(tmpS0), src=sgpr("WorkGroup1"), comment="super = wg1")) + module.add(SCmpEQU32(src0=sgpr("BInterleaveG"), src1=16)) + module.add(SCBranchSCC1(labelName=labelCase16.getLabelName())) + module.add(SCmpEQU32(src0=sgpr("BInterleaveG"), src1=8)) + module.add(SCBranchSCC1(labelName=labelCase8.getLabelName())) + module.add(SCmpEQU32(src0=sgpr("BInterleaveG"), src1=4)) + module.add(SCBranchSCC1(labelName=labelCase4.getLabelName())) + module.add(SCmpEQU32(src0=sgpr("BInterleaveG"), src1=2)) + module.add(SCBranchSCC1(labelName=labelCase2.getLabelName())) + module.add(SBranch(labelName=labelAfterShift.getLabelName())) + module.add(labelCase16) + module.add(SLShiftRightB32(dst=sgpr(tmpS0), src=sgpr(tmpS0), shiftHex=hex(4), comment="super = wg1>>4")) + module.add(SBranch(labelName=labelAfterShift.getLabelName())) + module.add(labelCase8) + module.add(SLShiftRightB32(dst=sgpr(tmpS0), src=sgpr(tmpS0), shiftHex=hex(3), comment="super = wg1>>3")) + module.add(SBranch(labelName=labelAfterShift.getLabelName())) + module.add(labelCase4) + module.add(SLShiftRightB32(dst=sgpr(tmpS0), src=sgpr(tmpS0), shiftHex=hex(2), comment="super = wg1>>2")) + module.add(SBranch(labelName=labelAfterShift.getLabelName())) + module.add(labelCase2) + module.add(SLShiftRightB32(dst=sgpr(tmpS0), src=sgpr(tmpS0), shiftHex=hex(1), comment="super = wg1>>1")) + module.add(labelAfterShift) + + # wgMT1 = baseCol = phase + super*(MT1*G) + module.add(SMulI32(dst=sgpr(wgMT1), src0=sgpr("BInterleaveG"), src1="MT1", comment="MT1*G")) + module.add(SMulI32(dst=sgpr(wgMT1), src0=sgpr(wgMT1), src1=sgpr(tmpS0), comment="super*(MT1*G)")) + module.add(SAddU32(dst=sgpr(wgMT1), src0=sgpr(wgMT1), src1=sgpr(tmpS1), comment="baseCol = super*(MT1*G)+phase")) # Overall strategy is to set the SRD to the top-left of the macro-tile. # TT offsets are from this base (and include the column) @@ -10276,6 +10696,34 @@ def computeStoreSrdStart(self, kernel, srdTcList: list, sgprBpeList = [], useSiz def computeStoreVgprs(self, kernel, divisor=None, tid0Scale=None, tid1Scale=None): module = Module("computeStoreVgprs") module.addComment0("computeStoreVgprs") + + # Interleave (restricted): scale StrideC1J/StrideD1J by runtime G before store vgpr math, + # while SRD base was already computed using unscaled strides. 
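A short sketch of why scaling the row stride by G is sufficient here (the stride and G values are illustrative):

# Illustrative only: with the SRD already based at column baseCol (unscaled stride), the store for
# local column c must land at global column baseCol + G*c, i.e. at offset c * (G * StrideC1J)
# from the SRD base, which is exactly what the scaled stride produces.
def store_offset_from_srd(base_col, c, g, stride_c1):
    return (base_col + g * c) * stride_c1 - base_col * stride_c1

g, stride_c1 = 4, 2048
assert all(store_offset_from_srd(769, c, g, stride_c1) == c * (g * stride_c1) for c in range(8))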
+ if kernel["BAddrInterleave"] and kernel["EnableMatrixInstruction"] and kernel["BufferStore"]: + packedC1 = kernel["PackedC1IndicesX"] + strideC1 = "StrideC%s" % (self.states.indexChars[packedC1[0]]) + strideD1 = "StrideD%s" % (self.states.indexChars[packedC1[0]]) + + # IMPORTANT: do NOT overwrite the original stride SGPRs in-place. + # This kernel uses a persistent loop; the next iteration re-enters at label_PersistentLoopStart + # and expects the original StrideC/D values loaded from kernargs. If we multiply in-place, + # the stride would compound each persistent iteration. + # + # Instead, compute scaled strides into temp SGPRs, then re-alias sgprStrideC/D for the + # remainder of the store path to refer to the scaled temporaries. + tmpStrideC1 = self.sgprPool.checkOut(1, preventOverflow=False) + tmpStrideD1 = self.sgprPool.checkOut(1, preventOverflow=False) + module.add(SMulI32(dst=sgpr(tmpStrideC1), src0=sgpr(strideC1), src1=sgpr("BInterleaveG"), comment="StrideC1*G (temp)")) + module.add(SMulI32(dst=sgpr(tmpStrideD1), src0=sgpr(strideD1), src1=sgpr("BInterleaveG"), comment="StrideD1*G (temp)")) + + # Re-alias the stride symbols for the remainder of the emitted store code. + # Note: this is an assembler-time alias; it does not mutate runtime SGPR contents. + module.add(ValueSet(name=f"sgpr{strideC1}", value=tmpStrideC1, format=-1)) + module.add(ValueSet(name=f"sgpr{strideD1}", value=tmpStrideD1, format=-1)) + + # BAddrInterleave: no further uses after this point; free the sgpr alias for cleaner asm/debug. + module.add(ValueSet(name="sgprBInterleaveG", value="UNDEF", format=-1)) + component = Component.ComputeStoreVgprs.find(self) if component: if kernel["EnableMatrixInstruction"]: diff --git a/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py b/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py index c0ae3946f81e..6f59bb4cfdb5 100644 --- a/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py +++ b/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py @@ -957,6 +957,29 @@ def assignDerivedParameters( Solution.assignProblemIndependentDerivedParameters(state, printRejectionReason, isaInfoMap) + # KRingShift currently only supported for TN (A transposed, B not transposed). + # Disallow enabling KRingShift on NN/NT/TT until those paths are validated. + if state["KRingShift"]: + ta = int(state["ProblemType"]["TransposeA"]) + tb = int(state["ProblemType"]["TransposeB"]) + if not (ta == 1 and tb == 0): + reject(state, printRejectionReason, f"KRingShift requires TN (TransposeA=1, TransposeB=0); got TransposeA={ta}, TransposeB={tb}") + return + + # KRingShift is defined to operate only in conjunction with BAddrInterleave (BInterleaveG). + # If BAddrInterleave is not enabled, do not allow KRingShift to be enabled. + if state["KRingShift"] and (not state["BAddrInterleave"]): + reject(state, printRejectionReason, "KRingShift requires BAddrInterleave (BInterleaveG)") + return + + # BAddrInterleave runtime restriction (host-side predicate, not codegen): + # Match the kernel's initBInterleaveG enable conditions: + # - require tiles1 = SizeJ / MT1 to be an integer (SizeJ % MT1 == 0) + # - require lowbit(tiles1) > 1 so that G=min(lowbit(tiles1), LVCB) is > 1 (enabled) + # Note: if lowbit(tiles1) == 1, then G==1 and the kernel disables BAddrInterleave. 
+ if state["BAddrInterleave"]: + state["AssertFree1DivByMT1LowbitGT1"] = state["MacroTile1"] + if state["UseDirect32XEmulation"] == True: # Turn off Direct32X for the following kernels: # Cijk_Ailk_Bjlk_S_MX_B_Bias_HA_S_SAV_UserArgs_MT16x16x512_MI16x16x1 @@ -2493,6 +2516,30 @@ def calSwizzlePackK(state, tc): state["LVCB"] = roundupRatio(state["LSCB"] , state["GlobalReadVectorWidthB"]) state["LVPB"] = roundupRatio(state["LSPB"] , state["GlobalReadVectorWidthB"]) + # KRingShift wrap handling exists only in the tail loop. + # If (k + KRingShift) would wrap inside the main loop, the kernel will be incorrect (no main-loop wrap fix). + # Enforce a host-side runtime predicate which guarantees any KRS wrap happens only in tail. + # + # NOTE: This must be encoded after LVCB/LSCB are computed above. + if state["KRingShift"]: + # Pack predicate args (see ContractionProblemPredicates.hpp::KRingShiftTailWrapOnly): + # [63:48]=cacheLineBytes, [47:32]=depthU, [31:16]=mt1, [15:8]=lvcb, [7:0]=bpeB + cacheLineBytes = int(isaInfoMap[isa].archCaps.get("vL1DCacheLineBytes", 0)) + depthU = int(state["DepthU"]) + mt1 = int(state["MacroTile1"]) + bpeB = int(state["ProblemType"]["DataTypeB"].numBytes()) + lvcb = int(state["LVCB"]) + + if (0 < cacheLineBytes < (1<<16)) and (0 < depthU < (1<<16)) and (0 < mt1 < (1<<16)) and (0 < lvcb < (1<<8)) and (0 < bpeB < (1<<8)): + state["AssertKRingShiftTailWrapOnly"] = (cacheLineBytes << 48) | (depthU << 32) | (mt1 << 16) | (lvcb << 8) | bpeB + else: + reject(state, printRejectionReason, + f"KRingShift requires encodable AssertKRingShiftTailWrapOnly predicate " + f"(cacheLineBytes={cacheLineBytes}, depthU={depthU}, mt1={mt1}, " + f"lvcb={lvcb}, lscb={state.get('LSCB', None)}, grvwB={state.get('GlobalReadVectorWidthB', None)}, " + f"bpeB={bpeB})") + return + if state["ProblemType"]["Sparse"] and not state["DirectToVgprSparseMetadata"]: state["LVCMetadata"] = roundupRatio(state["LSCMetadata"] , state["GlobalReadVectorWidthMetadata"]) state["LVPMetadata"] = roundupRatio(state["LSPMetadata"] , state["GlobalReadVectorWidthMetadata"]) diff --git a/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/kringshift.yaml b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/kringshift.yaml new file mode 100644 index 000000000000..be6e5c05f097 --- /dev/null +++ b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/kringshift.yaml @@ -0,0 +1,198 @@ +TestParameters: + # Restrict to gfx950 for now (feature was developed/validated there). + marks: [skip-gfx900, skip-gfx906, skip-gfx908, skip-gfx90a, skip-gfx940, skip-gfx941, skip-gfx942, skip-gfx1010, skip-gfx1011, skip-gfx1012, skip-gfx1030, skip-gfx1100, skip-gfx1101, skip-gfx1102, skip-gfx1200, skip-gfx1201] + +GlobalParameters: + MinimumRequiredVersion: 5.0.0 + SleepPercent: 50 + NumElementsToValidate: 128 + PrintSolutionRejectionReason: True + DataInitTypeBeta: 0 + DataInitTypeAlpha: 1 + CSVExportWinner: 1 + CSVMergeSameProblemID: 1 + Device: 0 + +BenchmarkProblems: + ######################################## + # NOTE: + # KRingShift is restricted to TN in Solution.py (host-side rejection for NN/NT/TT). + # Keep this test file TN-only. 
+ ######################################## + + ######################################## + # BF16 TN - KRingShift coverage + ######################################## + - + - # ProblemType + OperationType: GEMM + DataType: b + DestDataType: b + ComputeDataType: s + HighPrecisionAccumulate: True + TransposeA: 1 + TransposeB: 0 + UseBeta: True + Batched: True + UseBias: 0 + Activation: False + UseScaleAlphaVec: 0 + - # BenchmarkProblemSizeGroup - single focused config + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - KRingShift: [True] + # KRingShift requires BAddrInterleave (BInterleaveG). + - BAddrInterleave: [True] + - MatrixInstruction: + - [16, 16, 32, 1, 1, 8, 3, 1, 4] + - DepthU: [128] + - StreamK: [3] + - ScheduleGlobalRead: [1] + - ScheduleLocalWrite: [1] + - GlobalSplitU: [0] + - PrefetchGlobalRead: [2] + - PrefetchLocalRead: [1] + - BufferLoad: [True] + - BufferStore: [True] + - DirectToLds: [0] + - GlobalReadVectorWidthA: [8] + - GlobalReadVectorWidthB: [8] + - LocalReadVectorWidth: [8] + - VectorWidthA: [8] + - VectorWidthB: [1] + - ClusterLocalRead: [0] + - SourceSwap: [0, 1] + - WorkGroup: + - [16, 16, 1] + - WavefrontSize: [64] + - ScheduleIterAlg: [3] + - TransposeLDS: [2] + - 1LDSBuffer: [1] + - StoreVectorWidth: [8] + - StoreSyncOpt: [0] + - StaggerU: [16] + - StaggerUStride: [256] + - StaggerUMapping: [0] + - UseSgprForGRO: [0] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + - Range: [[2048], [3072], [1], [1880, 32, 2176]] + - Range: [[128], [192], [1], [128, 32, 512]] + + ######################################## + # FP16 TN - KRingShift coverage + ######################################## + - + - # ProblemType + OperationType: GEMM + DataType: h + DestDataType: h + ComputeDataType: s + HighPrecisionAccumulate: True + TransposeA: 1 + TransposeB: 0 + UseBeta: True + Batched: True + UseBias: 0 + Activation: False + UseScaleAlphaVec: 0 + - # BenchmarkProblemSizeGroup - single focused config + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - KRingShift: [True] + - BAddrInterleave: [True] + - MatrixInstruction: + - [16, 16, 32, 1, 1, 8, 3, 1, 4] + - DepthU: [128] + - StreamK: [3] + - ScheduleGlobalRead: [1] + - ScheduleLocalWrite: [1] + - GlobalSplitU: [0] + - PrefetchGlobalRead: [2] + - PrefetchLocalRead: [1] + - BufferLoad: [True] + - BufferStore: [True] + - DirectToLds: [0] + - GlobalReadVectorWidthA: [8] + - GlobalReadVectorWidthB: [8] + - LocalReadVectorWidth: [8] + - VectorWidthA: [8] + - VectorWidthB: [1] + - ClusterLocalRead: [0] + - SourceSwap: [0, 1] + - WorkGroup: + - [16, 16, 1] + - WavefrontSize: [64] + - ScheduleIterAlg: [3] + - TransposeLDS: [2] + - 1LDSBuffer: [1] + - StoreVectorWidth: [8] + - StoreSyncOpt: [0] + - StaggerU: [16] + - StaggerUStride: [256] + - StaggerUMapping: [0] + - UseSgprForGRO: [0] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + - Range: [[2048], [3072], [1], [1880, 32, 2176]] + - Range: [[128], [192], [1], [128, 32, 512]] + + ######################################## + # FP8 (F8->F16) TN - KRingShift coverage + ######################################## + - + - # ProblemType + OperationType: GEMM + DataType: f8 + DestDataType: h + ComputeDataType: s + HighPrecisionAccumulate: True + TransposeA: 1 + TransposeB: 0 + UseBeta: True + Batched: True + UseBias: 0 + Activation: False + UseScaleAlphaVec: 0 + - # BenchmarkProblemSizeGroup - single focused config + 
InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - KRingShift: [True] + - BAddrInterleave: [True] + - MatrixInstruction: + - [16, 16, 128, 1, 1, 4, 4, 2, 2] + - DepthU: [128] + - StreamK: [3] + - ScheduleIterAlg: [3] + - PrefetchGlobalRead: [2] + - PrefetchLocalRead: [1] + - BufferLoad: [True] + - BufferStore: [True] + - DirectToLds: [0] + - GlobalReadVectorWidthA: [16] + - GlobalReadVectorWidthB: [16] + - LocalReadVectorWidth: [8] + - WorkGroup: + - [16, 16, 1] + - WavefrontSize: [64] + - TransposeLDS: [2] + - 1LDSBuffer: [1] + - SourceSwap: [1] + - StaggerU: [16] + - StaggerUStride: [256] + - StaggerUMapping: [0] + - UseSgprForGRO: [0] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + - Range: [[2048], [3072], [1], [1880, 32, 2176]] + - Range: [[128], [192], [1], [128, 32, 512]] + diff --git a/projects/hipblaslt/tensilelite/include/Tensile/ContractionProblemPredicates.hpp b/projects/hipblaslt/tensilelite/include/Tensile/ContractionProblemPredicates.hpp index 6c789aac3e32..9252f6bd3735 100755 --- a/projects/hipblaslt/tensilelite/include/Tensile/ContractionProblemPredicates.hpp +++ b/projects/hipblaslt/tensilelite/include/Tensile/ContractionProblemPredicates.hpp @@ -133,6 +133,186 @@ namespace TensileLite } }; + // Address-interleave restriction: + // Require tiles1 (=Free1Size/value) to have lowbit(tiles1) > 1 (i.e. even), + // and require Free1Size % value == 0. + // This matches the kernel's initBInterleaveG enable condition that G=min(lowbit, LVCB) must be > 1. + struct Free1SizeDivByValueLowbitGT1 + : public Predicate_CRTP + { + enum + { + HasIndex = true, + HasValue = true + }; + size_t index; + size_t value; + + Free1SizeDivByValueLowbitGT1() = default; + Free1SizeDivByValueLowbitGT1(size_t index, size_t value) + : index(index) + , value(value) + { + } + + static std::string Type() + { + return "Free1SizeDivByValueLowbitGT1"; + } + + virtual bool operator()(ContractionProblemGemm const& problem) const override + { + if(value == 0) + return false; + size_t freeSize = (!problem.transposeC01() ? problem.freeSizeB(index) + : problem.freeSizeA(index)); + if(freeSize % value != 0) + return false; + size_t tiles1 = freeSize / value; + if(tiles1 == 0) + return false; + size_t lowbit = tiles1 & (~tiles1 + 1); // tiles1 & -tiles1 + return lowbit > 1; + } + + virtual bool debugEval(ContractionProblemGemm const& problem, + std::ostream& stream) const override + { + size_t freeSize = (!problem.transposeC01() ? problem.freeSizeB(index) + : problem.freeSizeA(index)); + bool okDiv = (value != 0) && (freeSize % value == 0); + size_t tiles1 = okDiv ? (freeSize / value) : 0; + size_t lowbit = tiles1 ? (tiles1 & (~tiles1 + 1)) : 0; + bool ok = okDiv && (lowbit > 1); + return debugEvalCmp(problem, + stream, + "free1", + freeSize, + "%", + "value", + value, + "lowbit", + lowbit, + ">", + "1", + size_t(1)) && ok; + } + }; + + // KRingShift wrap restriction: + // Require that any (k + KRingShift) wrap occurs only in tail loop (no main-loop wrap fix). 
diff --git a/projects/hipblaslt/tensilelite/include/Tensile/Serialization/ContractionPredicates.hpp b/projects/hipblaslt/tensilelite/include/Tensile/Serialization/ContractionPredicates.hpp
index 7cd3df440376..6701fb2ec27b 100644
--- a/projects/hipblaslt/tensilelite/include/Tensile/Serialization/ContractionPredicates.hpp
+++ b/projects/hipblaslt/tensilelite/include/Tensile/Serialization/ContractionPredicates.hpp
@@ -58,6 +58,8 @@ namespace TensileLite
                 SubclassMap rv(
                     {Base::template Pair(),
                      Base::template Pair(),
+                     Base::template Pair<Predicates::Contraction::Free1SizeDivByValueLowbitGT1>(),
+                     Base::template Pair<Predicates::Contraction::KRingShiftTailWrapOnly>(),
                      Base::template Pair(),
                      Base::template Pair(),
                      Base::template Pair(),
@@ -156,6 +158,18 @@ namespace TensileLite
        {
        };
 
+        template <typename IO>
+        struct MappingTraits<Predicates::Contraction::Free1SizeDivByValueLowbitGT1, IO>
+            : public AutoMappingTraits<Predicates::Contraction::Free1SizeDivByValueLowbitGT1, IO>
+        {
+        };
+
+        template <typename IO>
+        struct MappingTraits<Predicates::Contraction::KRingShiftTailWrapOnly, IO>
+            : public AutoMappingTraits<Predicates::Contraction::KRingShiftTailWrapOnly, IO>
+        {
+        };
+
         template
         struct MappingTraits
             : public AutoMappingTraits
diff --git a/projects/hipblaslt/tensilelite/rocisa/rocisa/include/container.hpp b/projects/hipblaslt/tensilelite/rocisa/rocisa/include/container.hpp
index 693c925339aa..80bd15c870a6 100644
--- a/projects/hipblaslt/tensilelite/rocisa/rocisa/include/container.hpp
+++ b/projects/hipblaslt/tensilelite/rocisa/rocisa/include/container.hpp
@@ -942,7 +942,8 @@ namespace rocisa
         else
         {
             return minusStr + regType + "[" + macroSlash + regType + "gpr"
-                   + regName->toString() + ":" + regType + "gpr" + regName->toString() + "+"
+                   + regName->toString() + ":" + macroSlash + regType + "gpr"
+                   + regName->toString() + "+"
                    + std::to_string(regNum - 1) + "]" + absStr;
         }
     }
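The container.hpp change above only adds the missing macroSlash before the second bound of a register-range expression. A small standalone sketch of the strings produced before and after the fix, not part of the patch, assuming regType = "v", macroSlash = "\", a hypothetical register name "ValuA", and regNum = 4 (minusStr and absStr omitted):

    #include <iostream>
    #include <string>

    int main()
    {
        const std::string regType = "v", macroSlash = "\\", regName = "ValuA";
        const int         regNum  = 4;

        // Old form: the second register name is not prefixed with macroSlash, so inside an
        // assembler macro only the first bound of the range would be macro-expanded.
        const std::string before = regType + "[" + macroSlash + regType + "gpr" + regName + ":"
                                   + regType + "gpr" + regName + "+" + std::to_string(regNum - 1) + "]";

        // New form: macroSlash is repeated before the second name, matching the first bound.
        const std::string after = regType + "[" + macroSlash + regType + "gpr" + regName + ":"
                                  + macroSlash + regType + "gpr" + regName + "+"
                                  + std::to_string(regNum - 1) + "]";

        std::cout << before << "\n"; // v[\vgprValuA:vgprValuA+3]
        std::cout << after << "\n";  // v[\vgprValuA:\vgprValuA+3]
        return 0;
    }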
diff --git a/projects/hipblaslt/tensilelite/rocisa/rocisa/include/hardware_caps.hpp b/projects/hipblaslt/tensilelite/rocisa/rocisa/include/hardware_caps.hpp
index 79a3c0d3a62b..626e96a28a37 100644
--- a/projects/hipblaslt/tensilelite/rocisa/rocisa/include/hardware_caps.hpp
+++ b/projects/hipblaslt/tensilelite/rocisa/rocisa/include/hardware_caps.hpp
@@ -404,6 +404,19 @@ inline std::map<std::string, int> initArchCaps(const IsaVersion& isaVersion)
     rv["VOP3ByteSel"] = isaVersion[0] == 12;
     rv["HasFP8_OCP"] = isaVersion[0] == 12;
     rv["HasF32XEmulation"] = checkInList(isaVersion, {{9, 5, 0}});
+
+    // Vector L1 Data cache line size (bytes) used for alignment-sensitive optimizations in codegen.
+    // NOTE: This is a *codegen-time* (compile-time) constant selected by target ISA.
+    //
+    // Per project convention:
+    //   - MI100 (gfx908 / ISA 9.0.8)  : 64B
+    //   - MI200 (gfx90a / ISA 9.0.10) : 64B
+    //   - Others                      : 128B
+    int vL1DCacheLineBytes = 128;
+    if(checkInList(isaVersion, {{9, 0, 8}, {9, 0, 10}}))
+        vL1DCacheLineBytes = 64;
+    rv["vL1DCacheLineBytes"] = vL1DCacheLineBytes;
+
     return rv;
 }
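For reference, the element granularity the KRingShift predicate works with follows directly from this cap: cacheLineElems = vL1DCacheLineBytes / bpeB. A minimal standalone sketch, not part of the patch, assuming the usual 2/2/1-byte element sizes for bf16/fp16/fp8 (the vL1DCacheLineBytes function below is a local re-statement of the table above, not the rocisa API):

    #include <cstdio>

    struct Isa { int major, minor, step; };

    // MI100 (9.0.8) and MI200 (9.0.10) use 64B vector L1 lines; everything else is treated as 128B.
    static int vL1DCacheLineBytes(const Isa& isa)
    {
        const bool is64B = isa.major == 9 && isa.minor == 0 && (isa.step == 8 || isa.step == 10);
        return is64B ? 64 : 128;
    }

    int main()
    {
        const Isa   gfx950{9, 5, 0};
        const int   lineBytes = vL1DCacheLineBytes(gfx950); // 128 on gfx950
        const int   bpe[]     = {2, 2, 1};                  // assumed bytes per element: bf16, fp16, fp8
        const char* names[]   = {"bf16", "fp16", "fp8"};
        for(int i = 0; i < 3; ++i)
            std::printf("%-4s cacheLineElems = %d\n", names[i], lineBytes / bpe[i]); // 64, 64, 128
        return 0;
    }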