Skip to content

Commit 7c2d8ce

Browse files
[7.0][BACKPORT] Find first gemm index after fusing linalg.generic ops (#1923)
Find first gemm index after fusing linalg.generic ops (#1922) Fusions can generate multiple linalg.generic ops and firstGemmIdx might not match between block inputs and linalg.generic op after linalg.generic ops have been fused. This pass traces the fused linalg.generic op back to the firstGemmIdx of the block inputs. --------- Co-authored-by: Umang Yadav <[email protected]>
1 parent 3f3975a commit 7c2d8ce

File tree

13 files changed

+298
-15
lines changed

13 files changed

+298
-15
lines changed

mlir/include/mlir/Conversion/RocMLIRPasses.td

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,8 @@ def GPUToMIGraphXPass : Pass<"gpu-to-migraphx", "::mlir::func::FuncOp"> {
9494
Pass that converts func operations with gpu.launch to MIGraphX operation.
9595
}];
9696

97-
let dependentDialects = [
98-
"migraphx::MIGraphXDialect",
99-
"func::FuncDialect",
100-
"gpu::GPUDialect"
101-
];
97+
let dependentDialects = ["migraphx::MIGraphXDialect", "func::FuncDialect",
98+
"gpu::GPUDialect"];
10299
}
103100

104101
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,16 @@ def RockGemmGemmWrapperInterface : OpInterface<"RockGemmGemmWrapperInterface"> {
243243
/*methodBody=*/"",
244244
/*defaultImplementation=*/ ""
245245
>,
246+
InterfaceMethod<
247+
/*desc=*/[{
248+
Return the region corresponding to the fusion between first and second GEMM.
249+
}],
250+
/*retType=*/"Region&",
251+
/*methodName=*/"getPreSecondGemmRegion",
252+
/*args=*/(ins),
253+
/*methodBody=*/"",
254+
/*defaultImplementation=*/ ""
255+
>,
246256

247257
// TODO: more methods here as needed
248258
];

mlir/include/mlir/Dialect/Rock/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ namespace rock {
4747
#define GEN_PASS_DECL_ROCKSHUFFLEGEMMFORREDUCTIONS
4848
#define GEN_PASS_DECL_ROCKGEMMLINALGSPLITKNORMALIZATIONPASS
4949
#define GEN_PASS_DECL_ROCKSORTDIMENSIONSMEMORYLAYOUTPASS
50+
#define GEN_PASS_DECL_ROCKFINDFIRSTGEMMINDEXPASS
5051

5152
#define GEN_PASS_REGISTRATION
5253
#include "mlir/Dialect/Rock/Passes.h.inc"

mlir/include/mlir/Dialect/Rock/Passes.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,4 +178,13 @@ def RockSortDimensionsMemoryLayoutPass : Pass<"rock-sort-dimensions-memory-layou
178178
let dependentDialects = ["rock::RockDialect", "func::FuncDialect", "arith::ArithDialect", "linalg::LinalgDialect"];
179179
}
180180

181+
def RockFindFirstGemmIndexPass
182+
: Pass<"rock-find-first-gemm-index", "::mlir::func::FuncOp"> {
183+
let summary = "Fusions input arguments and linalg.generic input arguments "
184+
"don't have to match. This pass sets the first gemm index of "
185+
"the linalg.generic correctly.";
186+
let dependentDialects = ["rock::RockDialect", "func::FuncDialect",
187+
"linalg::LinalgDialect"];
188+
}
189+
181190
#endif // MLIR_DIALECT_ROCK_PASSES

mlir/lib/Dialect/Rock/IR/RockDialect.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2125,6 +2125,10 @@ void GemmElementwiseGemmOp::setFirstGemmIndex(uint32_t index) {
21252125
setFirstGemmIdx(index);
21262126
}
21272127

2128+
Region &GemmElementwiseGemmOp::getPreSecondGemmRegion() {
2129+
return getPreSecondGemmBody();
2130+
}
2131+
21282132
GemmGemmSize GemmElementwiseGemmOp::getGemmGemmSize() {
21292133
ShapedType typeA = getA().getType(), typeB = getB().getType(),
21302134
typeC = getC().getType();
@@ -2297,6 +2301,10 @@ void ConvElementwiseGemmOp::setFirstGemmIndex(uint32_t index) {
22972301
setFirstGemmIdx(index);
22982302
}
22992303

2304+
Region &ConvElementwiseGemmOp::getPreSecondGemmRegion() {
2305+
return getPreSecondGemmBody();
2306+
}
2307+
23002308
GemmGemmSize ConvElementwiseGemmOp::getGemmGemmSize() {
23012309
auto strideVal = extractFromIntegerArrayAttr<int64_t>(getStrides());
23022310
auto dilationVal = extractFromIntegerArrayAttr<int64_t>(getDilations());
@@ -2372,6 +2380,8 @@ uint32_t AttentionOp::getFirstGemmIndex() { return getFirstGemmIdx(); }
23722380

23732381
void AttentionOp::setFirstGemmIndex(uint32_t index) { setFirstGemmIdx(index); }
23742382

2383+
Region &AttentionOp::getPreSecondGemmRegion() { return getPreSoftmaxBody(); }
2384+
23752385
GemmGemmSize AttentionOp::getGemmGemmSize() {
23762386
ShapedType typeA = getQueries().getType(), typeB = getKeys().getType(),
23772387
typeC = getValues().getType();

mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ void rock::buildBufferizePipeline(OpPassManager &pm,
139139
// Sort dimensions according to the underlying memory layout strides
140140
if (!noRock) {
141141
auto &funcPm4 = pm.nest<func::FuncOp>();
142+
funcPm4.addPass(createRockFindFirstGemmIndexPass());
142143
funcPm4.addPass(createRockSortDimensionsMemoryLayoutPass());
143144
}
144145
}

mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ add_rocmlir_dialect_library(MLIRRockTransforms
2929
ShuffleGemmForReductions.cpp
3030
GemmLinalgSplitkNormalizationPass.cpp
3131
SortDimensionsMemoryLayout.cpp
32+
FindFirstGemmIndex.cpp
3233

3334
ADDITIONAL_HEADER_DIRS
3435
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Rock
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
//===- FindFirstGemmIndex.cpp -----------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Fusions can generate multiple linalg.generic ops and firstGemmIdx might not
10+
// match between block inputs and linalg.generic op after linalg.generic ops
11+
// have been fused. This pass traces the fused linalg.generic op back to the
12+
// firstGemmIdx of the block inputs.
13+
//
14+
//===----------------------------------------------------------------------===//
15+
16+
#include "mlir/Dialect/Func/IR/FuncOps.h"
17+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
18+
#include "mlir/Dialect/Rock/IR/Rock.h"
19+
#include "mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.h"
20+
#include "mlir/Dialect/Rock/Passes.h"
21+
#include "mlir/Dialect/Rock/utility/loweringUtils.h"
22+
#include "mlir/Dialect/Rock/utility/transformMapUtils.h"
23+
#include "mlir/IR/BuiltinAttributes.h"
24+
#include "mlir/IR/Value.h"
25+
#include "mlir/Support/LogicalResult.h"
26+
#include "llvm/ADT/STLExtras.h"
27+
#include <cstdint>
28+
29+
namespace mlir {
30+
namespace rock {
31+
#define GEN_PASS_DEF_ROCKFINDFIRSTGEMMINDEXPASS
32+
#include "mlir/Dialect/Rock/Passes.h.inc"
33+
} // namespace rock
34+
} // namespace mlir
35+
36+
#define DEBUG_TYPE "rock-find-first-gemm-index"
37+
38+
using namespace mlir;
39+
40+
namespace {
41+
struct RockFindFirstGemmIndexPass
42+
: public rock::impl::RockFindFirstGemmIndexPassBase<
43+
RockFindFirstGemmIndexPass> {
44+
void runOnOperation() override;
45+
};
46+
} // end anonymous namespace
47+
48+
static bool canTraceToFirstGemmArg(Value input, Value firstGemmArg) {
49+
// trace back to block arguments (through ViewLike ops)
50+
FailureOr<Value> maybeBlockArg = rock::findBlockArgument(input);
51+
if (failed(maybeBlockArg))
52+
return false;
53+
54+
// If the input is a block argument, check if it matches the first gemm
55+
// argument.
56+
return maybeBlockArg.value() == firstGemmArg;
57+
}
58+
59+
static LogicalResult
60+
reassignFirstGemmIndex(func::FuncOp &func,
61+
rock::RockGemmGemmWrapperInterface gemmGemmOp) {
62+
// Get the first gemm index from the gemmGemmOp.
63+
uint32_t firstGemmIndex = gemmGemmOp.getFirstGemmIndex();
64+
65+
// get linalg.generic ops in the preSecondGemmRegion
66+
SmallVector<linalg::GenericOp> genOps;
67+
gemmGemmOp.getPreSecondGemmRegion().walk(
68+
[&genOps](linalg::GenericOp genOp) { genOps.push_back(genOp); });
69+
70+
// no fusion, nothing to do
71+
if (genOps.empty())
72+
return success();
73+
74+
if (genOps.size() != 1)
75+
return gemmGemmOp.emitError(
76+
"More than one linalg.generic operation found, expected only one.");
77+
78+
linalg::GenericOp genOp = genOps[0];
79+
assert(firstGemmIndex <
80+
gemmGemmOp.getPreSecondGemmRegion().getNumArguments());
81+
Value firstGemmArg =
82+
gemmGemmOp.getPreSecondGemmRegion().getArgument(firstGemmIndex);
83+
84+
// try to trace args of linalg.generic op back to the first gemm argument
85+
int64_t newFirstGemmIndex = -1;
86+
for (auto [index, input] : llvm::enumerate(genOp.getInputs())) {
87+
if (canTraceToFirstGemmArg(input, firstGemmArg)) {
88+
// If the input can be traced back to the first gemm argument, we found
89+
// the new index.
90+
newFirstGemmIndex = index;
91+
break;
92+
}
93+
}
94+
if (newFirstGemmIndex == -1) {
95+
return gemmGemmOp.emitError(
96+
"Could not find a matching input for the first gemm index.");
97+
}
98+
99+
// Set the new first gemm index in the gemmGemmOp.
100+
gemmGemmOp.setFirstGemmIndex(static_cast<uint32_t>(newFirstGemmIndex));
101+
return success();
102+
}
103+
104+
void RockFindFirstGemmIndexPass::runOnOperation() {
105+
auto func = getOperation();
106+
// Only run this pass on GPU kernel functions.
107+
if (!func->hasAttr("kernel"))
108+
return;
109+
110+
// find gemm+gemm like operations with fusion
111+
SmallVector<rock::RockGemmGemmWrapperInterface> gemmGemmOps;
112+
func.walk([&gemmGemmOps](rock::RockGemmGemmWrapperInterface gemmGemmOp) {
113+
gemmGemmOps.push_back(gemmGemmOp);
114+
});
115+
116+
// no gemm+gemm like operations found, nothing to do
117+
if (gemmGemmOps.empty())
118+
return;
119+
120+
if (gemmGemmOps.size() != 1) {
121+
func.emitError(
122+
"More than one gemm+gemm like operation found, expected only one.");
123+
return signalPassFailure();
124+
}
125+
126+
if (failed(reassignFirstGemmIndex(func, gemmGemmOps[0]))) {
127+
return signalPassFailure();
128+
}
129+
}

mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252

5353
#include "GridLayoutEmitter.h"
5454
#include "mlir/Dialect/Rock/IR/AccelEmitter.h"
55+
#include "llvm/ADT/STLExtras.h"
5556
#include "llvm/Support/Debug.h"
5657
#include "llvm/Support/FormatVariadic.h"
5758
#include <optional>
@@ -1586,29 +1587,53 @@ struct GridwiseAttentionAccelRewritePattern
15861587
ArrayAttr linalgGridSubTileMaps = gemm0OutViews.gridSubTile;
15871588
ArrayAttr gemmOutToLinalgMaps =
15881589
invertTransforms(rewriter, loc, linalgToGemmOutMaps);
1590+
1591+
if (!gemmOutToLinalgMaps) {
1592+
res = rewriter.notifyMatchFailure(
1593+
genOp, "we can't invert linalg input to gemmOutput maps");
1594+
return;
1595+
}
1596+
15891597
if (!gemmOutToLinalgMaps.empty()) {
15901598
linalgGridSubTileMaps = prependUpperViews(
15911599
rewriter, linalgGridSubTileMaps, gemmOutToLinalgMaps);
15921600
}
15931601

1594-
for (auto [idx, otherInput] :
1595-
llvm::enumerate(op.getPreSoftmaxElemWiseInputs())) {
1596-
if (idx >= op.getFirstGemmIdx())
1597-
idx++;
1602+
for (auto [idx, genOpInput] : llvm::enumerate(genOp.getInputs())) {
1603+
if (idx == op.getFirstGemmIdx())
1604+
continue;
1605+
1606+
Value otherInput;
1607+
ArrayAttr linalgToOtherInputMaps;
1608+
std::tie(otherInput, linalgToOtherInputMaps, std::ignore) =
1609+
untransform(rewriter, genOpInput);
1610+
15981611
MemRefType otherInputBufType = cast<MemRefType>(otherInput.getType());
15991612
MemRefType tileBufType = MemRefType::get(
16001613
srcBufType.getShape(), otherInputBufType.getElementType(),
16011614
AffineMap{}, privateMemoryAddressSpace);
16021615
auto tileBuffer = rewriter.create<rock::GpuAllocOp>(loc, tileBufType);
1603-
auto genOpInput = genOp.getInputs()[idx];
1604-
ArrayAttr linalgToOtherInputMaps;
1605-
std::tie(std::ignore, linalgToOtherInputMaps, std::ignore) =
1606-
untransform(rewriter, genOpInput);
1616+
16071617
ArrayAttr gemmOutToOtherInputMaps = linalgGridSubTileMaps;
16081618
if (!linalgToOtherInputMaps.empty()) {
16091619
gemmOutToOtherInputMaps = prependUpperViews(
16101620
rewriter, linalgGridSubTileMaps, linalgToOtherInputMaps);
16111621
}
1622+
// If other input is a block argument of the attention op fusion
1623+
if (auto blockArg = dyn_cast<BlockArgument>(otherInput)) {
1624+
// trace it back to block input
1625+
Block &block = op.getPreSoftmaxBody().getBlocks().front();
1626+
if (blockArg.getOwner() == &block) {
1627+
int64_t blockArgNum = blockArg.getArgNumber();
1628+
assert(blockArgNum != op.getFirstGemmIdx());
1629+
1630+
// if the gemm index is smaller, we need to substract one from the
1631+
// index
1632+
if (blockArgNum > op.getFirstGemmIdx())
1633+
--blockArgNum;
1634+
otherInput = op.getPreSoftmaxElemWiseInputs()[blockArgNum];
1635+
}
1636+
}
16121637
rewriter.create<ThreadwiseReadIntoOp>(
16131638
loc, otherInput, tileBuffer, gemmOutToOtherInputMaps,
16141639
ValueRange{gridCoords.g_block, gridCoords.m_block,

mlir/lib/Dialect/Rock/utility/transformMapUtils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1710,7 +1710,7 @@ ArrayAttr mlir::rock::invertTransforms(OpBuilder &b, Location loc,
17101710
auto trMap = cast<TransformMapAttr>(tr);
17111711
TransformMapAttr invertedTrMap = invertTransformMap(b, trMap, loc);
17121712
if (!invertedTrMap)
1713-
return {};
1713+
return nullptr;
17141714
invertedTrs.push_back(invertedTrMap);
17151715
}
17161716
return b.getArrayAttr(invertedTrs);

0 commit comments

Comments
 (0)