Skip to content

Commit 7f329ac

Browse files
authored
Use explicit elementType in blockwiseLoadTile (#2078)
* Use explicit elementType in blockwiseLoadTile
1 parent 58c991b commit 7f329ac

File tree

6 files changed

+45
-46
lines changed

6 files changed

+45
-46
lines changed

mlir/include/mlir/Dialect/Rock/IR/RockOps.td

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1411,12 +1411,12 @@ def Rock_BlockwiseLoadTileOp
14111411
Arg<Optional<MemRefOf<NativeMemoryOpTypes>>,
14121412
"destination registers">:$destRegisters,
14131413
Rock_GemmLoadTileTypeAttr:$loadType, UnitAttr:$isA,
1414-
TypeAttr:$elementTypeA, TypeAttr:$elementTypeB,
1415-
TypeAttr:$elementTypeALoad, TypeAttr:$elementTypeBLoad,
1416-
UnitAttr:$rotateWithK, UnitAttr:$swapThreadIterSubDims,
1417-
UnitAttr:$LDSLayoutDxK, Variadic<Index>:$sourceIndices, I64Attr:$G,
1418-
I64Attr:$M, I64Attr:$N, OptionalAttr<Rock_GemmFeaturesAttr>:$features,
1419-
I32Attr:$blockSize, RockAccelTuningParamAttrInterface:$params)> {
1414+
TypeAttr:$elementTypeA, TypeAttr:$elementTypeB, TypeAttr:$elementType,
1415+
TypeAttr:$elementLoadType, UnitAttr:$rotateWithK,
1416+
UnitAttr:$swapThreadIterSubDims, UnitAttr:$LDSLayoutDxK,
1417+
Variadic<Index>:$sourceIndices, I64Attr:$G, I64Attr:$M, I64Attr:$N,
1418+
OptionalAttr<Rock_GemmFeaturesAttr>:$features, I32Attr:$blockSize,
1419+
RockAccelTuningParamAttrInterface:$params)> {
14201420
let summary =
14211421
"Blockwise load tile from global memory to LDS and/or registers";
14221422
let description = [{
@@ -1429,6 +1429,9 @@ def Rock_BlockwiseLoadTileOp
14291429
- DoubleBuffer: Creates three stages, (1) load from memory, (2) write to LDS, (3) load to registers.
14301430

14311431
`isA` determines if we are loading an A matrix or B matrix. `G`, `M` and `N` are the GEMM sizes.
1432+
`elementTypeA` and `elementTypeB` are used to construct AccelEmitter. They are data types for the Matrix A & B of the GEMMs.
1433+
`elementLoadType` is the element type of the global buffer this BlockwiseLoadTileOp is trying to load.
1434+
`elementType` is the elementType in the registers. Note that it may differ from `elementLoadType` because of input fusions.
14321435
}];
14331436
let assemblyFormat = [{
14341437
$source (`[` $sourceIndices^ `]`)? (`LDS` `->` $destLDS^)? (`->` $destRegisters^)? attr-dict

mlir/lib/Dialect/Rock/IR/RockDialect.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2226,7 +2226,7 @@ LogicalResult BlockwiseLoadTileOp::verify() {
22262226
}
22272227

22282228
SmallVector<mlir::Type> BlockwiseLoadTileOp::getTypesForFeature() {
2229-
return {getElementTypeA()};
2229+
return {getElementType()};
22302230
}
22312231

22322232
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,8 @@ class LoweringBlockwiseLoadTileOp final
163163

164164
Type elementTypeA = op.getElementTypeA();
165165
Type elementTypeB = op.getElementTypeB();
166-
Type elementTypeLoad =
167-
isA ? op.getElementTypeALoad() : op.getElementTypeBLoad();
168-
Type elementType = isA ? elementTypeA : elementTypeB;
166+
Type elementTypeLoad = op.getElementLoadType();
167+
Type elementType = op.getElementType();
169168

170169
auto accelEmitterPtr = accel::AccelEmitter::select(
171170
features, elementTypeA, elementTypeB, arch, tuningParams);

mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,11 @@ struct RockGridwiseGemmToBlockwisePass
9090
// This function will process a tile of gemm input into LDS (or register)
9191
// buffer in a way it could be fed to blockwise_gemm_accel op
9292
static void loadAndStoreGemmInputTile(
93-
Location loc, Value in, Value kIter, Value tid,
93+
PatternRewriter &rewriter, Location loc, Value in, Value kIter, Value tid,
9494
rock::layout::GridCoordinates gridCoords, Value destLDS, Value destRegs,
9595
GemmLoadTileType loadType, StringRef nonKDimName, uint32_t blockSize,
96-
Type elementTypeA, Type elementTypeALoad, Type elementTypeB,
97-
Type elementTypeBLoad, int64_t G, int64_t M, int64_t N,
98-
PatternRewriter &rewriter,
96+
Type elementTypeA, Type elementTypeB, Type elementType,
97+
Type elementLoadType, int64_t G, int64_t M, int64_t N,
9998
const RockAccelTuningParamAttrInterface &gemmTuningParams,
10099
const GemmFeaturesAttr &featuresAttr,
101100
const LDSLayoutConfigDim &ldsLayoutCfg) {
@@ -114,7 +113,7 @@ static void loadAndStoreGemmInputTile(
114113
BlockwiseLoadTileOp::create(
115114
rewriter, loc, in, destLDS, destRegs, loadTypeAttr, isA,
116115
TypeAttr::get(elementTypeA), TypeAttr::get(elementTypeB),
117-
TypeAttr::get(elementTypeALoad), TypeAttr::get(elementTypeBLoad),
116+
TypeAttr::get(elementType), TypeAttr::get(elementLoadType),
118117
rotateWithKAttr, swapThreadIterSubDimsAttr, ldsLayoutDxKAttr,
119118
ValueRange{kIter, gridCoords.g_block, gridCoords.m_block,
120119
gridCoords.n_block, tid},
@@ -2195,10 +2194,10 @@ struct GridwiseAttentionAccelRewritePattern
21952194
createLDSByteBuffer(rewriter, loc, ldsByteBufferQSize, elemTypeQ);
21962195
}
21972196
loadAndStoreGemmInputTile(
2198-
loc, inQ, /*kiter=*/zero, tid, gridCoordsGemm0LoadQ, ldsByteBufferQ,
2199-
preAccelRegBuffersQ, loadTypeQ, "n", blockSize, elemTypeK,
2200-
elemTypeKLoad, elemTypeQ, elemTypeQLoad, gemm0G, gemm0M, gemm0N,
2201-
rewriter, gemm0TuningParams, featuresAttr, ldsLayoutCfgNG0);
2197+
rewriter, loc, inQ, /*kiter=*/zero, tid, gridCoordsGemm0LoadQ,
2198+
ldsByteBufferQ, preAccelRegBuffersQ, loadTypeQ, "n", blockSize,
2199+
elemTypeK, elemTypeQ, elemTypeQ, elemTypeQLoad, gemm0G, gemm0M,
2200+
gemm0N, gemm0TuningParams, featuresAttr, ldsLayoutCfgNG0);
22022201
}
22032202

22042203
bool dynamicMLoop = splitKV != 1 || isCausal || isKVCache;
@@ -2260,21 +2259,20 @@ struct GridwiseAttentionAccelRewritePattern
22602259
TypedValue<MemRefType> ldsTileBufferQ;
22612260
if (gemm0K != gemm0KPerBlock) {
22622261
loadAndStoreGemmInputTile(
2263-
loc, inQ, kLoopIV, tid, gridCoordsGemm0, ldsByteBufferQ,
2262+
rewriter, loc, inQ, kLoopIV, tid, gridCoordsGemm0, ldsByteBufferQ,
22642263
preAccelRegBuffersQ, GemmLoadTileType::DoubleBuffer, "n",
2265-
blockSize, elemTypeK, elemTypeKLoad, elemTypeQ, elemTypeQLoad,
2266-
gemm0G, gemm0M, gemm0N, rewriter, gemm0TuningParams, featuresAttr,
2267-
ldsLayoutCfgNG0);
2264+
blockSize, elemTypeK, elemTypeQ, elemTypeQ, elemTypeQLoad, gemm0G,
2265+
gemm0M, gemm0N, gemm0TuningParams, featuresAttr, ldsLayoutCfgNG0);
22682266
ldsTileBufferQ =
22692267
viewBufferAs(rewriter, ldsByteBufferQ,
22702268
vectorTypeOrSelf(elemTypeQ, gemm0kpack));
22712269
}
22722270

22732271
loadAndStoreGemmInputTile(
2274-
loc, inK, kLoopIV, tid, gridCoordsGemm0, ldsByteBufferK,
2272+
rewriter, loc, inK, kLoopIV, tid, gridCoordsGemm0, ldsByteBufferK,
22752273
preAccelRegBufferK, GemmLoadTileType::Default, "m", blockSize,
2276-
elemTypeK, elemTypeKLoad, elemTypeQ, elemTypeQLoad, gemm0G, gemm0M,
2277-
gemm0N, rewriter, gemm0TuningParams, featuresAttr, ldsLayoutCfgMG0);
2274+
elemTypeK, elemTypeQ, elemTypeK, elemTypeKLoad, gemm0G, gemm0M,
2275+
gemm0N, gemm0TuningParams, featuresAttr, ldsLayoutCfgMG0);
22782276
TypedValue<MemRefType> ldsTileBufferK = viewBufferAs(
22792277
rewriter, ldsByteBufferK, vectorTypeOrSelf(elemTypeK, gemm0kpack));
22802278
// LDS barrier.
@@ -2521,12 +2519,11 @@ struct GridwiseAttentionAccelRewritePattern
25212519
}
25222520

25232521
loadAndStoreGemmInputTile(
2524-
loc, inV,
2522+
rewriter, loc, inV,
25252523
/*kIter=*/mLoopIV, tid, gridCoordsGemm1, ldsByteBufferV,
25262524
preAccelRegBufferV, GemmLoadTileType::Default, "m", blockSize,
2527-
elemTypeV, elemTypeVLoad, elemTypeV, elemTypeVLoad, gemm0G,
2528-
gemm1M, gemm1N, rewriter, gemm1TuningParams, featuresAttr,
2529-
ldsLayoutCfgMG1);
2525+
elemTypeV, elemTypeV, elemTypeV, elemTypeVLoad, gemm0G, gemm1M,
2526+
gemm1N, gemm1TuningParams, featuresAttr, ldsLayoutCfgMG1);
25302527
TypedValue<MemRefType> ldsTileBufferV =
25312528
viewBufferAs(rewriter, ldsByteBufferV,
25322529
vectorTypeOrSelf(elemTypeV, gemm1kpack));
@@ -2956,15 +2953,15 @@ struct GridwiseGemmAccelRewritePattern
29562953

29572954
// Load from global memory to LDS
29582955
loadAndStoreGemmInputTile(
2959-
loc, matB, /*kiter=*/iv, tid, gridCoords, ldsByteBufferB,
2960-
arrayBForLoad, loadType, "n", blockSize, elementTypeA,
2961-
elementTypeALoad, elementTypeB, elementTypeBLoad, G, M, N, b,
2962-
op.getParamsAttr(), featuresAttr, ldsLayoutConfigB);
2956+
b, loc, matB, /*kiter=*/iv, tid, gridCoords, ldsByteBufferB,
2957+
arrayBForLoad, loadType, "n", blockSize, elementTypeA, elementTypeB,
2958+
elementTypeB, elementTypeBLoad, G, M, N, op.getParamsAttr(),
2959+
featuresAttr, ldsLayoutConfigB);
29632960
loadAndStoreGemmInputTile(
2964-
loc, matA, /*kiter=*/iv, tid, gridCoords, ldsByteBufferA,
2965-
arrayAForLoad, loadType, "m", blockSize, elementTypeA,
2966-
elementTypeALoad, elementTypeB, elementTypeBLoad, G, M, N, b,
2967-
op.getParamsAttr(), featuresAttr, ldsLayoutConfigA);
2961+
b, loc, matA, /*kiter=*/iv, tid, gridCoords, ldsByteBufferA,
2962+
arrayAForLoad, loadType, "m", blockSize, elementTypeA, elementTypeB,
2963+
elementTypeA, elementTypeALoad, G, M, N, op.getParamsAttr(),
2964+
featuresAttr, ldsLayoutConfigA);
29682965

29692966
// Emit blockwise GEMM. This will load data from LDS (or registers) and
29702967
// compute the MMA at the same time

mlir/test/Dialect/Rock/effects.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,7 @@ func.func @loadtile_doublebuffer(%arg0: memref<1x384x64xf32>, %lds: memref<4096x
552552
// expected-remark @below {{found an instance of 'write' on op operand 1, on resource '<Default>'}}
553553
// expected-remark @below {{found an instance of 'read' on op operand 1, on resource '<Default>'}}
554554
// expected-remark @below {{found an instance of 'write' on op operand 2, on resource '<Default>'}}
555-
rock.blockwise_load_tile %0[%arg1, %c0, %c0, %c0, %c0] LDS -> %lds -> %reg {G = 1 : i64, M = 384 : i64, N = 384 : i64, blockSize = 64 : i32, elementTypeA = f32, elementTypeALoad = f32, elementTypeB = f32, elementTypeBLoad = f32, loadType = #rock<GemmLoadTileType DoubleBuffer>, params = #rock.xdlops_gemm_derived_params<kpackPerBlock = 32, mPerBlock = 32, nPerBlock = 32, kpack = 1, mPerWave = 32, nPerWave = 32, mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 2, outputSwizzle = 2, forceUnroll = true>} : memref<1x64x384xf32> LDS -> memref<4096xi8, #gpu.address_space<workgroup>> -> memref<16xf32, #gpu.address_space<private>>
555+
rock.blockwise_load_tile %0[%arg1, %c0, %c0, %c0, %c0] LDS -> %lds -> %reg {G = 1 : i64, M = 384 : i64, N = 384 : i64, blockSize = 64 : i32, elementTypeA = f32, elementTypeB = f32, elementType = f32, elementLoadType = f32, loadType = #rock<GemmLoadTileType DoubleBuffer>, params = #rock.xdlops_gemm_derived_params<kpackPerBlock = 32, mPerBlock = 32, nPerBlock = 32, kpack = 1, mPerWave = 32, nPerWave = 32, mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 2, outputSwizzle = 2, forceUnroll = true>} : memref<1x64x384xf32> LDS -> memref<4096xi8, #gpu.address_space<workgroup>> -> memref<16xf32, #gpu.address_space<private>>
556556
}
557557
return
558558
}
@@ -566,7 +566,7 @@ func.func @loadtile_default(%arg0: memref<1x384x64xf32>, %lds: memref<4096xi8, #
566566
affine.for %arg1 = 0 to 2 {
567567
// expected-remark @below {{found an instance of 'read' on op operand 0, on resource '<Default>'}}
568568
// expected-remark @below {{found an instance of 'write' on op operand 1, on resource '<Default>'}}
569-
rock.blockwise_load_tile %0[%arg1, %c0, %c0, %c0, %c0] LDS -> %lds {G = 1 : i64, M = 384 : i64, N = 384 : i64, blockSize = 64 : i32, elementTypeA = f32, elementTypeALoad = f32, elementTypeB = f32, elementTypeBLoad = f32, loadType = #rock<GemmLoadTileType Default>, params = #rock.xdlops_gemm_derived_params<kpackPerBlock = 32, mPerBlock = 32, nPerBlock = 32, kpack = 1, mPerWave = 32, nPerWave = 32, mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 1, outputSwizzle = 2, forceUnroll = true>} : memref<1x64x384xf32> LDS -> memref<4096xi8, #gpu.address_space<workgroup>>
569+
rock.blockwise_load_tile %0[%arg1, %c0, %c0, %c0, %c0] LDS -> %lds {G = 1 : i64, M = 384 : i64, N = 384 : i64, blockSize = 64 : i32, elementTypeA = f32, elementTypeB = f32, elementType = f32, elementLoadType = f32, loadType = #rock<GemmLoadTileType Default>, params = #rock.xdlops_gemm_derived_params<kpackPerBlock = 32, mPerBlock = 32, nPerBlock = 32, kpack = 1, mPerWave = 32, nPerWave = 32, mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 1, outputSwizzle = 2, forceUnroll = true>} : memref<1x64x384xf32> LDS -> memref<4096xi8, #gpu.address_space<workgroup>>
570570
}
571571
return
572572
}
@@ -580,7 +580,7 @@ func.func @loadtile_bypasslds(%arg0: memref<1x384x64xf32>, %reg: memref<16xf32,
580580
affine.for %arg1 = 0 to 2 {
581581
// expected-remark @below {{found an instance of 'read' on op operand 0, on resource '<Default>'}}
582582
// expected-remark @below {{found an instance of 'write' on op operand 1, on resource '<Default>'}}
583-
rock.blockwise_load_tile %0[%arg1, %c0, %c0, %c0, %c0] -> %reg {G = 1 : i64, M = 384 : i64, N = 384 : i64, blockSize = 64 : i32, elementTypeA = f32, elementTypeALoad = f32, elementTypeB = f32, elementTypeBLoad = f32, loadType = #rock<GemmLoadTileType BypassLDS>, params = #rock.xdlops_gemm_derived_params<kpackPerBlock = 32, mPerBlock = 32, nPerBlock = 32, kpack = 1, mPerWave = 32, nPerWave = 32, mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 1, outputSwizzle = 2, forceUnroll = true>} : memref<1x64x384xf32> -> memref<16xf32, #gpu.address_space<private>>
583+
rock.blockwise_load_tile %0[%arg1, %c0, %c0, %c0, %c0] -> %reg {G = 1 : i64, M = 384 : i64, N = 384 : i64, blockSize = 64 : i32, elementTypeA = f32, elementTypeB = f32, elementType = f32, elementLoadType = f32, loadType = #rock<GemmLoadTileType BypassLDS>, params = #rock.xdlops_gemm_derived_params<kpackPerBlock = 32, mPerBlock = 32, nPerBlock = 32, kpack = 1, mPerWave = 32, nPerWave = 32, mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 1, outputSwizzle = 2, forceUnroll = true>} : memref<1x64x384xf32> -> memref<16xf32, #gpu.address_space<private>>
584584
}
585585
return
586586
}
@@ -596,7 +596,7 @@ func.func @loadtile_doublebuffer_directtolds(%arg0: memref<1x384x64xf32>, %lds:
596596
// expected-remark @below {{found an instance of 'write' on op operand 1, on resource '<Default>'}}
597597
// expected-remark @below {{found an instance of 'read' on op operand 1, on resource '<Default>'}}
598598
// expected-remark @below {{found an instance of 'write' on op operand 2, on resource '<Default>'}}
599-
rock.blockwise_load_tile %0[%arg1, %c0, %c0, %c0, %c0] LDS -> %lds -> %reg {G = 1 : i64, M = 384 : i64, N = 384 : i64, blockSize = 64 : i32, elementTypeA = f32, elementTypeALoad = f32, elementTypeB = f32, elementTypeBLoad = f32, loadType = #rock<GemmLoadTileType DirectToLDSDoubleBuffer>, params = #rock.xdlops_gemm_derived_params<kpackPerBlock = 32, mPerBlock = 32, nPerBlock = 32, kpack = 1, mPerWave = 32, nPerWave = 32, mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 4, outputSwizzle = 2, forceUnroll = true>} : memref<1x64x384xf32> LDS -> memref<4096xi8, #gpu.address_space<workgroup>> -> memref<16xf32, #gpu.address_space<private>>
599+
rock.blockwise_load_tile %0[%arg1, %c0, %c0, %c0, %c0] LDS -> %lds -> %reg {G = 1 : i64, M = 384 : i64, N = 384 : i64, blockSize = 64 : i32, elementTypeA = f32, elementTypeB = f32, elementType = f32, elementLoadType = f32, loadType = #rock<GemmLoadTileType DirectToLDSDoubleBuffer>, params = #rock.xdlops_gemm_derived_params<kpackPerBlock = 32, mPerBlock = 32, nPerBlock = 32, kpack = 1, mPerWave = 32, nPerWave = 32, mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 4, outputSwizzle = 2, forceUnroll = true>} : memref<1x64x384xf32> LDS -> memref<4096xi8, #gpu.address_space<workgroup>> -> memref<16xf32, #gpu.address_space<private>>
600600
}
601601
return
602602
}
@@ -610,7 +610,7 @@ func.func @loadtile_default_directtolds(%arg0: memref<1x384x64xf32>, %lds: memre
610610
affine.for %arg1 = 0 to 2 {
611611
// expected-remark @below {{found an instance of 'read' on op operand 0, on resource '<Default>'}}
612612
// expected-remark @below {{found an instance of 'write' on op operand 1, on resource '<Default>'}}
613-
rock.blockwise_load_tile %0[%arg1, %c0, %c0, %c0, %c0] LDS -> %lds {G = 1 : i64, M = 384 : i64, N = 384 : i64, blockSize = 64 : i32, elementTypeA = f32, elementTypeALoad = f32, elementTypeB = f32, elementTypeBLoad = f32, loadType = #rock<GemmLoadTileType DirectToLDSDefault>, params = #rock.xdlops_gemm_derived_params<kpackPerBlock = 32, mPerBlock = 32, nPerBlock = 32, kpack = 1, mPerWave = 32, nPerWave = 32, mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 3, outputSwizzle = 2, forceUnroll = true>} : memref<1x64x384xf32> LDS -> memref<4096xi8, #gpu.address_space<workgroup>>
613+
rock.blockwise_load_tile %0[%arg1, %c0, %c0, %c0, %c0] LDS -> %lds {G = 1 : i64, M = 384 : i64, N = 384 : i64, blockSize = 64 : i32, elementTypeA = f32, elementTypeB = f32, elementType = f32, elementLoadType = f32, loadType = #rock<GemmLoadTileType DirectToLDSDefault>, params = #rock.xdlops_gemm_derived_params<kpackPerBlock = 32, mPerBlock = 32, nPerBlock = 32, kpack = 1, mPerWave = 32, nPerWave = 32, mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 3, outputSwizzle = 2, forceUnroll = true>} : memref<1x64x384xf32> LDS -> memref<4096xi8, #gpu.address_space<workgroup>>
614614
}
615615
return
616616
}

0 commit comments

Comments
 (0)