@@ -90,12 +90,11 @@ struct RockGridwiseGemmToBlockwisePass
9090// This function processes a tile of gemm input into an LDS (or register)
9191// buffer so that it can be fed to a blockwise_gemm_accel op
9292static void loadAndStoreGemmInputTile (
93- Location loc, Value in, Value kIter , Value tid,
93+ PatternRewriter &rewriter, Location loc, Value in, Value kIter , Value tid,
9494 rock::layout::GridCoordinates gridCoords, Value destLDS, Value destRegs,
9595 GemmLoadTileType loadType, StringRef nonKDimName, uint32_t blockSize,
96- Type elementTypeA, Type elementTypeALoad, Type elementTypeB,
97- Type elementTypeBLoad, int64_t G, int64_t M, int64_t N,
98- PatternRewriter &rewriter,
96+ Type elementTypeA, Type elementTypeB, Type elementType,
97+ Type elementLoadType, int64_t G, int64_t M, int64_t N,
9998 const RockAccelTuningParamAttrInterface &gemmTuningParams,
10099 const GemmFeaturesAttr &featuresAttr,
101100 const LDSLayoutConfigDim &ldsLayoutCfg) {
@@ -114,7 +113,7 @@ static void loadAndStoreGemmInputTile(
114113 BlockwiseLoadTileOp::create (
115114 rewriter, loc, in, destLDS, destRegs, loadTypeAttr, isA,
116115 TypeAttr::get (elementTypeA), TypeAttr::get (elementTypeB),
117- TypeAttr::get (elementTypeALoad ), TypeAttr::get (elementTypeBLoad ),
116+ TypeAttr::get (elementType ), TypeAttr::get (elementLoadType ),
118117 rotateWithKAttr, swapThreadIterSubDimsAttr, ldsLayoutDxKAttr,
119118 ValueRange{kIter , gridCoords.g_block , gridCoords.m_block ,
120119 gridCoords.n_block , tid},
@@ -2195,10 +2194,10 @@ struct GridwiseAttentionAccelRewritePattern
21952194 createLDSByteBuffer (rewriter, loc, ldsByteBufferQSize, elemTypeQ);
21962195 }
21972196 loadAndStoreGemmInputTile (
2198- loc, inQ, /* kiter=*/ zero, tid, gridCoordsGemm0LoadQ, ldsByteBufferQ ,
2199- preAccelRegBuffersQ, loadTypeQ, " n" , blockSize, elemTypeK ,
2200- elemTypeKLoad , elemTypeQ, elemTypeQLoad, gemm0G, gemm0M, gemm0N ,
2201- rewriter , gemm0TuningParams, featuresAttr, ldsLayoutCfgNG0);
2197+ rewriter, loc, inQ, /* kiter=*/ zero, tid, gridCoordsGemm0LoadQ,
2198+ ldsByteBufferQ, preAccelRegBuffersQ, loadTypeQ, " n" , blockSize,
2199+ elemTypeK , elemTypeQ, elemTypeQ, elemTypeQLoad, gemm0G, gemm0M,
2200+ gemm0N , gemm0TuningParams, featuresAttr, ldsLayoutCfgNG0);
22022201 }
22032202
22042203 bool dynamicMLoop = splitKV != 1 || isCausal || isKVCache;
@@ -2260,21 +2259,20 @@ struct GridwiseAttentionAccelRewritePattern
22602259 TypedValue<MemRefType> ldsTileBufferQ;
22612260 if (gemm0K != gemm0KPerBlock) {
22622261 loadAndStoreGemmInputTile (
2263- loc, inQ, kLoopIV , tid, gridCoordsGemm0, ldsByteBufferQ,
2262+ rewriter, loc, inQ, kLoopIV , tid, gridCoordsGemm0, ldsByteBufferQ,
22642263 preAccelRegBuffersQ, GemmLoadTileType::DoubleBuffer, " n" ,
2265- blockSize, elemTypeK, elemTypeKLoad, elemTypeQ, elemTypeQLoad,
2266- gemm0G, gemm0M, gemm0N, rewriter, gemm0TuningParams, featuresAttr,
2267- ldsLayoutCfgNG0);
2264+ blockSize, elemTypeK, elemTypeQ, elemTypeQ, elemTypeQLoad, gemm0G,
2265+ gemm0M, gemm0N, gemm0TuningParams, featuresAttr, ldsLayoutCfgNG0);
22682266 ldsTileBufferQ =
22692267 viewBufferAs (rewriter, ldsByteBufferQ,
22702268 vectorTypeOrSelf (elemTypeQ, gemm0kpack));
22712269 }
22722270
22732271 loadAndStoreGemmInputTile (
2274- loc, inK, kLoopIV , tid, gridCoordsGemm0, ldsByteBufferK,
2272+ rewriter, loc, inK, kLoopIV , tid, gridCoordsGemm0, ldsByteBufferK,
22752273 preAccelRegBufferK, GemmLoadTileType::Default, " m" , blockSize,
2276- elemTypeK, elemTypeKLoad, elemTypeQ, elemTypeQLoad , gemm0G, gemm0M,
2277- gemm0N, rewriter, gemm0TuningParams, featuresAttr, ldsLayoutCfgMG0);
2274+ elemTypeK, elemTypeQ, elemTypeK, elemTypeKLoad , gemm0G, gemm0M,
2275+ gemm0N, gemm0TuningParams, featuresAttr, ldsLayoutCfgMG0);
22782276 TypedValue<MemRefType> ldsTileBufferK = viewBufferAs (
22792277 rewriter, ldsByteBufferK, vectorTypeOrSelf (elemTypeK, gemm0kpack));
22802278 // LDS barrier.
@@ -2521,12 +2519,11 @@ struct GridwiseAttentionAccelRewritePattern
25212519 }
25222520
25232521 loadAndStoreGemmInputTile (
2524- loc, inV,
2522+ rewriter, loc, inV,
25252523 /* kIter=*/ mLoopIV , tid, gridCoordsGemm1, ldsByteBufferV,
25262524 preAccelRegBufferV, GemmLoadTileType::Default, " m" , blockSize,
2527- elemTypeV, elemTypeVLoad, elemTypeV, elemTypeVLoad, gemm0G,
2528- gemm1M, gemm1N, rewriter, gemm1TuningParams, featuresAttr,
2529- ldsLayoutCfgMG1);
2525+ elemTypeV, elemTypeV, elemTypeV, elemTypeVLoad, gemm0G, gemm1M,
2526+ gemm1N, gemm1TuningParams, featuresAttr, ldsLayoutCfgMG1);
25302527 TypedValue<MemRefType> ldsTileBufferV =
25312528 viewBufferAs (rewriter, ldsByteBufferV,
25322529 vectorTypeOrSelf (elemTypeV, gemm1kpack));
@@ -2956,15 +2953,15 @@ struct GridwiseGemmAccelRewritePattern
29562953
29572954 // Load from global memory to LDS
29582955 loadAndStoreGemmInputTile (
2959- loc, matB, /* kiter=*/ iv, tid, gridCoords, ldsByteBufferB,
2960- arrayBForLoad, loadType, " n" , blockSize, elementTypeA,
2961- elementTypeALoad, elementTypeB, elementTypeBLoad, G, M, N, b ,
2962- op. getParamsAttr (), featuresAttr, ldsLayoutConfigB);
2956+ b, loc, matB, /* kiter=*/ iv, tid, gridCoords, ldsByteBufferB,
2957+ arrayBForLoad, loadType, " n" , blockSize, elementTypeA, elementTypeB,
2958+ elementTypeB, elementTypeBLoad, G, M, N, op. getParamsAttr () ,
2959+ featuresAttr, ldsLayoutConfigB);
29632960 loadAndStoreGemmInputTile (
2964- loc, matA, /* kiter=*/ iv, tid, gridCoords, ldsByteBufferA,
2965- arrayAForLoad, loadType, " m" , blockSize, elementTypeA,
2966- elementTypeALoad, elementTypeB, elementTypeBLoad, G, M, N, b ,
2967- op. getParamsAttr (), featuresAttr, ldsLayoutConfigA);
2961+ b, loc, matA, /* kiter=*/ iv, tid, gridCoords, ldsByteBufferA,
2962+ arrayAForLoad, loadType, " m" , blockSize, elementTypeA, elementTypeB,
2963+ elementTypeA, elementTypeALoad, G, M, N, op. getParamsAttr () ,
2964+ featuresAttr, ldsLayoutConfigA);
29682965
29692966 // Emit blockwise GEMM. This will load data from LDS (or registers) and
29702967 // compute the MMA at the same time
0 commit comments