Skip to content

Commit 6a72b2b

Browse files
authored
Refactor BlockwiseGemmAccelOp to take registers as well (#1926)
Refactor BlockwiseGemmAccelOp to take registers
1 parent 6b99232 commit 6a72b2b

14 files changed

+301
-257
lines changed

mlir/include/mlir/Dialect/Rock/IR/RockOps.td

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1367,14 +1367,16 @@ def Rock_BlockwiseGemmAccelOp
13671367
Arguments<(ins MemRefOf<LdsBufferTypes>:$matrixA,
13681368
MemRefOf<LdsBufferTypes>:$matrixB, I32Attr:$inMPerThread,
13691369
I32Attr:$inNPerThread, UnitAttr:$rotateMWithK, UnitAttr:$rotateNWithK,
1370-
MemRefOf<AccelArgTypes>:$bufferA, MemRefOf<AccelArgTypes>:$bufferB,
1371-
MemRefOf<AccelResTypes>:$matrixC,
1370+
UnitAttr:$loadAfromLDS, UnitAttr:$loadBfromLDS,
1371+
UnitAttr:$splitKAcrossThreadsFirstA,
1372+
UnitAttr:$splitKAcrossThreadsFirstB, MemRefOf<AccelArgTypes>:$bufferA,
1373+
MemRefOf<AccelArgTypes>:$bufferB, MemRefOf<AccelResTypes>:$matrixC,
13721374
OptionalAttr<Rock_GemmFeaturesAttr>:$features, I32Attr:$blockSize,
13731375
RockAccelTuningParamAttrInterface:$params)> {
13741376
let summary = "Blockwise GEMM accelerated version";
13751377
let description = [{
1376-
The `rock.block_gemm_v2` op does GEMM at workgroup (block) level.
1377-
- Matrix A and Matrix B shall reside on LDS (naive tensor).
1378+
The `rock.blockwise_gemm_accel` op does GEMM at workgroup (block) level.
1379+
- Matrix A and Matrix B shall reside on LDS or registers (depending on loadAfromLDS and loadBfromLDS).
13781380
- Matrix C shall be vectors.
13791381

13801382
The elements of matrices A and B should be vectors of length kpack, or

mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp

Lines changed: 54 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,8 @@ struct BlockwiseGemmAccelRewritePattern
406406
int64_t kpackPerBlock = tuningParams.getKpackPerBlock();
407407
int64_t mPerWave = tuningParams.getMPerWave();
408408
int64_t nPerWave = tuningParams.getNPerWave();
409+
bool loadAFromLDS = adaptor.getLoadAfromLDS();
410+
bool loadBFromLDS = adaptor.getLoadBfromLDS();
409411

410412
Type bufferElemTypeA =
411413
cast<MemRefType>(adaptor.getMatrixA().getType()).getElementType();
@@ -445,6 +447,9 @@ struct BlockwiseGemmAccelRewritePattern
445447
<< "nRepeat: " << nRepeats << "\n"
446448
<< "kBasePerThread: " << kBasePerThread << "\n"
447449
<< "kpackPerBlock: " << kpackPerBlock << "\n"
450+
<< "loadAFromLDS: " << loadAFromLDS << "\n"
451+
<< "loadBFromLDS: " << loadBFromLDS << "\n"
452+
<< "rotateMWithK: " << op.getRotateMWithK() << "\n"
448453
<< "bufferA type: " << adaptor.getBufferA().getType() << "\n"
449454
<< "bufferB type: " << adaptor.getBufferB().getType() << "\n");
450455

@@ -463,44 +468,75 @@ struct BlockwiseGemmAccelRewritePattern
463468
// considered a temporary hack until we have a proper way of "searching"
464469
// through different schedules (either heuristically or automatically)
465470

466-
Value wrappedLDSBufferForLoadA = accelEmitterPtr->wrapLDSBufferForLoad(
467-
b, loc, op.getMatrixA(), op.getBlockSize(), op.getInMPerThread(), "m",
468-
op.getRotateMWithK());
469-
Value wrappedLDSBufferForLoadB = accelEmitterPtr->wrapLDSBufferForLoad(
470-
b, loc, op.getMatrixB(), op.getBlockSize(), op.getInNPerThread(), "n",
471-
op.getRotateNWithK());
471+
Value wrappedLDSBufferForLoadA, wrappedLDSBufferForLoadB;
472+
if (loadAFromLDS) {
473+
wrappedLDSBufferForLoadA = accelEmitterPtr->wrapLDSBufferForLoad(
474+
b, loc, op.getMatrixA(), op.getBlockSize(), op.getInMPerThread(), "m",
475+
op.getRotateMWithK(), op.getSplitKAcrossThreadsFirstA());
476+
}
477+
if (loadBFromLDS) {
478+
wrappedLDSBufferForLoadB = accelEmitterPtr->wrapLDSBufferForLoad(
479+
b, loc, op.getMatrixB(), op.getBlockSize(), op.getInNPerThread(), "n",
480+
op.getRotateNWithK(), op.getSplitKAcrossThreadsFirstA());
481+
}
472482

473483
auto mLoop = b.create<affine::AffineForOp>(loc, 0, mRepeats);
474484
{
475485
OpBuilder::InsertionGuard guard(b);
476486
b.setInsertionPointToStart(mLoop.getBody());
477487
Value i = mLoop.getInductionVar();
478488

479-
// regsA = read A from LDS
480-
b.create<ThreadwiseReadIntoOp>(
481-
loc, wrappedLDSBufferForLoadA, op.getBufferA(), b.getArrayAttr({}),
482-
ValueRange{tid, i}, /*forceUnroll=*/true, /*useIndexDiffs=*/true);
489+
Value bufferA = adaptor.getBufferA();
490+
if (loadAFromLDS) {
491+
// regsA = read A from LDS
492+
b.create<ThreadwiseReadIntoOp>(
493+
loc, wrappedLDSBufferForLoadA, bufferA, b.getArrayAttr({}),
494+
ValueRange{tid, i}, /*forceUnroll=*/true, /*useIndexDiffs=*/true);
495+
} else {
496+
if (cast<ShapedType>(bufferA.getType()).getRank() == 1) {
497+
BottomUpTMBuilder regsBuilder(b, {"mk"}, {mRepeats * kBasePerThread},
498+
loc);
499+
regsBuilder.unmerge({"iidx", "k"}, {0, 1}, "mk",
500+
{mRepeats, kBasePerThread});
501+
bufferA =
502+
rock::transform(b, bufferA, b.getArrayAttr({regsBuilder.get()}));
503+
}
504+
bufferA = rock::createSliceOfFirstDim(b, loc, bufferA, i);
505+
}
506+
Value viewA =
507+
accelEmitterPtr->generateThreadwiseViewBufferA(b, loc, bufferA);
483508

484509
auto nLoop = b.create<affine::AffineForOp>(loc, 0, nRepeats);
485510
{
486511
OpBuilder::InsertionGuard guard(b);
487512
b.setInsertionPointToStart(nLoop.getBody());
488513
Value j = nLoop.getInductionVar();
489514

490-
// regsB = read B from LDS
491-
b.create<ThreadwiseReadIntoOp>(
492-
loc, wrappedLDSBufferForLoadB, op.getBufferB(), b.getArrayAttr({}),
493-
ValueRange{tid, j}, /*forceUnroll=*/true, /*useIndexDiffs=*/true);
515+
Value bufferB = adaptor.getBufferB();
516+
if (loadBFromLDS) {
517+
// regsB = read B from LDS
518+
b.create<ThreadwiseReadIntoOp>(
519+
loc, wrappedLDSBufferForLoadB, bufferB, b.getArrayAttr({}),
520+
ValueRange{tid, j}, /*forceUnroll=*/true, /*useIndexDiffs=*/true);
521+
} else {
522+
if (cast<ShapedType>(bufferB.getType()).getRank() == 1) {
523+
BottomUpTMBuilder regsBBuilder(b, {"nk"},
524+
{nRepeats * kBasePerThread}, loc);
525+
regsBBuilder.unmerge({"jidx", "k"}, {0, 1}, "nk",
526+
{nRepeats, kBasePerThread});
527+
bufferB = rock::transform(b, bufferB,
528+
b.getArrayAttr({regsBBuilder.get()}));
529+
}
530+
bufferB = rock::createSliceOfFirstDim(b, loc, bufferB, j);
531+
}
532+
Value viewB =
533+
accelEmitterPtr->generateThreadwiseViewBufferB(b, loc, bufferB);
494534

495535
// regsC += regsA * regsB
496536
auto kLoop = b.create<affine::AffineForOp>(loc, 0, kBasePerThread);
497537
{
498538
OpBuilder::InsertionGuard guard(b);
499539
b.setInsertionPointToStart(kLoop.getBody());
500-
Value viewA = accelEmitterPtr->generateThreadwiseViewBufferA(
501-
b, loc, adaptor.getBufferA());
502-
Value viewB = accelEmitterPtr->generateThreadwiseViewBufferB(
503-
b, loc, adaptor.getBufferB());
504540
Value viewC = accelEmitterPtr->generateThreadwiseViewBufferC(
505541
b, loc, adaptor.getMatrixC());
506542
Value k = kLoop.getInductionVar();

0 commit comments

Comments
 (0)