Skip to content

Commit 73be116

Browse files
[LLVMGPU] Pass to decompose horizontally fused GEMMs before layout configuration. (#19924)
Currently this is done using a pass that uses a recognizer. It might be more ergonomic to just add an operation for the horizontally fused gemm operation. Signed-off-by: MaheshRavishankar <[email protected]> Signed-off-by: MaheshRavishankar <[email protected]>
1 parent 5767be3 commit 73be116

File tree

11 files changed

+443
-16
lines changed

11 files changed

+443
-16
lines changed

compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ iree_compiler_cc_library(
5050
name = "CommonGPUPasses",
5151
srcs = [
5252
"AMDGPUDistributeContract.cpp",
53+
"DecomposeHorizontallyFusedGemms.cpp",
5354
"ExpandGPUOps.cpp",
5455
"GPUApplyTilingLevel.cpp",
5556
"GPUCheckResourceUsage.cpp",
@@ -107,6 +108,7 @@ iree_compiler_cc_library(
107108
"//compiler/src/iree/compiler/Codegen/Utils",
108109
"//compiler/src/iree/compiler/Codegen/Utils:VectorOpUtils",
109110
"//compiler/src/iree/compiler/Dialect/HAL/IR",
111+
"//compiler/src/iree/compiler/Dialect/LinalgExt/Utils",
110112
"//compiler/src/iree/compiler/Utils",
111113
"@llvm-project//llvm:Support",
112114
"@llvm-project//mlir:AMDGPUDialect",

compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ iree_cc_library(
4848
"Passes.h"
4949
SRCS
5050
"AMDGPUDistributeContract.cpp"
51+
"DecomposeHorizontallyFusedGemms.cpp"
5152
"ExpandGPUOps.cpp"
5253
"GPUApplyTilingLevel.cpp"
5354
"GPUCheckResourceUsage.cpp"
@@ -141,6 +142,7 @@ iree_cc_library(
141142
iree::compiler::Codegen::Utils
142143
iree::compiler::Codegen::Utils::VectorOpUtils
143144
iree::compiler::Dialect::HAL::IR
145+
iree::compiler::Dialect::LinalgExt::Utils
144146
iree::compiler::Utils
145147
PUBLIC
146148
)
Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
// Copyright 2025 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#include "iree/compiler/Codegen/Common/GPU/Passes.h"
8+
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
9+
#include "iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h"
10+
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
11+
#include "mlir/Analysis/SliceAnalysis.h"
12+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
13+
#include "mlir/Interfaces/FunctionInterfaces.h"
14+
#include "mlir/Pass/Pass.h"
15+
16+
namespace mlir::iree_compiler {
17+
18+
#define GEN_PASS_DEF_DECOMPOSEHORIZONTALLYFUSEDGEMMSPASS
19+
#include "iree/compiler/Codegen/Common/GPU/Passes.h.inc"
20+
21+
namespace {
22+
23+
/// Pass that finds every linalg op recognized as a horizontally fused
/// contraction in the function it runs on and replaces it with one
/// `linalg.generic` per result. The base class is generated from the pass
/// definition via `GEN_PASS_DEF_DECOMPOSEHORIZONTALLYFUSEDGEMMSPASS`.
struct DecomposeHorizontallyFusedGemmsPass final
24+
: impl::DecomposeHorizontallyFusedGemmsPassBase<
25+
DecomposeHorizontallyFusedGemmsPass> {
26+
// Entry point invoked by the pass manager; defined out of line below.
void runOnOperation() override;
27+
};
28+
} // namespace
29+
30+
//===---------------------------------------------------------------------===//
31+
// Decompose horizontally fused gemm operations
32+
// TODO: Eventually drop this if we end up creating an operation for the
33+
// horizontally fused contractions.
34+
//===---------------------------------------------------------------------===//
35+
36+
/// Collects what is needed to recompute the value yielded for result
/// `resultNumber` of `linalgOp`.
///
/// `usedOperations` receives the inclusive backward slice of that yielded
/// value, restricted to operations whose block is the body block of
/// `linalgOp`. `usedInputs` receives the block-argument numbers of the DPS
/// inputs read by those operations, in insertion order.
///
/// Returns failure if any operation in the slice reads the block argument of
/// a DPS init other than the one that corresponds to `resultNumber`.
static LogicalResult captureUsedOperationsAndBlockArguements(
    linalg::LinalgOp linalgOp, SetVector<int64_t> &usedInputs,
    SetVector<Operation *> &usedOperations, int64_t resultNumber) {
  Block *body = linalgOp.getBlock();
  const int64_t numInputs = linalgOp.getNumDpsInputs();

  // Only follow definitions that live directly in the body of `linalgOp`.
  BackwardSliceOptions sliceOptions;
  sliceOptions.inclusive = true;
  sliceOptions.filter = [body](Operation *candidate) -> bool {
    return candidate->getBlock() == body;
  };

  Operation *terminator = body->getTerminator();
  Value yielded = cast<linalg::YieldOp>(terminator).getOperand(resultNumber);
  getBackwardSlice(yielded, &usedOperations, sliceOptions);

  // Classify every body block argument consumed by the slice.
  for (Operation *slicedOp : usedOperations) {
    for (Value operand : slicedOp->getOperands()) {
      auto bodyArg = dyn_cast<BlockArgument>(operand);
      if (!bodyArg || bodyArg.getOwner() != body) {
        continue;
      }

      const int64_t argNumber = bodyArg.getArgNumber();
      if (argNumber < numInputs) {
        // A DPS input: remember it so it is carried over to the new op.
        usedInputs.insert(argNumber);
        continue;
      }

      // A DPS init: only the init paired with `resultNumber` may be read.
      if (argNumber - numInputs != resultNumber) {
        return failure();
      }
    }
  }

  return success();
}
74+
75+
// Since the `promotedOperands` changes that needs to be modified
76+
// and transfered over to the decomposed ops.
77+
static IREE::GPU::LoweringConfigAttr
78+
getModifiedLoweringConfigForDecomposedGemmOp(
79+
RewriterBase &rewriter, IREE::GPU::LoweringConfigAttr origAttr,
80+
ArrayRef<unsigned> keptOperands) {
81+
std::optional<SmallVector<int64_t>> promotedOperandsList =
82+
IREE::GPU::getPromotedOperandList(origAttr);
83+
if (!promotedOperandsList) {
84+
return origAttr;
85+
}
86+
87+
llvm::SmallDenseSet<int64_t> promotedOperandsSet(
88+
promotedOperandsList->begin(), promotedOperandsList->end());
89+
SmallVector<int64_t> newPromotedOperands;
90+
for (auto [index, origOperandNum] : llvm::enumerate(keptOperands)) {
91+
if (promotedOperandsSet.contains(origOperandNum)) {
92+
newPromotedOperands.push_back(index);
93+
}
94+
}
95+
return setPromotedOperandsList(rewriter.getContext(), origAttr,
96+
newPromotedOperands);
97+
}
98+
99+
/// Replaces the horizontally fused contraction `linalgOp` with one
/// `linalg.generic` per result and erases `linalgOp`.
///
/// For each result, the new op takes only the DPS inputs that the computation
/// of that result reads, plus the matching DPS init. Loop dimensions that are
/// no longer referenced by any remaining indexing map are dropped. The body is
/// built by cloning the backward slice of the yielded value.
///
/// Returns failure if the computation of some result reads the init of a
/// different result. Ops created for earlier results are not rolled back in
/// that case.
static LogicalResult
decomposeHorizontallyFusedGemmOperations(RewriterBase &rewriter,
                                         linalg::LinalgOp linalgOp) {
  assert(IREE::LinalgExt::isaHorizontallyFusedContraction(linalgOp) &&
         "expected op that is a horizontally fused contraction");

  OpBuilder::InsertionGuard g(rewriter);

  // Loop-invariant properties of the fused op. The `get*Array()` accessors
  // build a fresh vector on every call, so query them once up front.
  Block *oldBody = linalgOp.getBlock();
  auto yieldOp = cast<linalg::YieldOp>(oldBody->getTerminator());
  const int64_t numDpsInputs = linalgOp.getNumDpsInputs();
  const int64_t numLoops = linalgOp.getNumLoops();
  SmallVector<AffineMap> origIndexingMaps = linalgOp.getIndexingMapsArray();
  SmallVector<utils::IteratorType> origIteratorTypes =
      linalgOp.getIteratorTypesArray();

  // The lowering config is only transferred when it carries a promoted
  // operands list.
  // NOTE(review): a lowering config without a promoted operands list is not
  // propagated at all, matching the pre-existing behavior -- confirm that
  // this is intended.
  IREE::GPU::LoweringConfigAttr loweringConfigAttr =
      getLoweringConfig<IREE::GPU::LoweringConfigAttr>(linalgOp);
  const bool hasPromotedOperands =
      loweringConfigAttr &&
      getPromotedOperandList(loweringConfigAttr).has_value();

  // Create num_results linalg.generics, each producing a single result (and
  // relying on canonicalizations to simplify).
  for (int64_t resultNumber : llvm::seq<int64_t>(linalgOp->getNumResults())) {
    rewriter.setInsertionPoint(linalgOp);

    Value result = yieldOp.getOperand(resultNumber);

    // Get all operations required to produce this result.
    SetVector<Operation *> usedOperations;
    SetVector<int64_t> usedInputs;
    if (failed(captureUsedOperationsAndBlockArguements(
            linalgOp, usedInputs, usedOperations, resultNumber))) {
      return failure();
    }

    // Operands of the new op: the used inputs followed by the single init.
    SmallVector<OpOperand *> inputs = llvm::map_to_vector(
        usedInputs, [&](int64_t x) { return linalgOp.getDpsInputOperand(x); });
    OpOperand *init = linalgOp.getDpsInitOperand(resultNumber);

    SmallVector<AffineMap> indexingMaps = llvm::map_to_vector(
        usedInputs, [&](int64_t x) { return origIndexingMaps[x]; });
    indexingMaps.push_back(linalgOp.getIndexingMapMatchingResult(
        linalgOp->getOpResult(resultNumber)));
    llvm::SmallBitVector unusedDims = getUnusedDimsBitVector(indexingMaps);
    indexingMaps = compressUnusedDims(indexingMaps);

    // Keep the iterator type of every loop dimension that is still in use.
    SmallVector<utils::IteratorType> iteratorTypes;
    for (int64_t dim : llvm::seq<int64_t>(numLoops)) {
      if (!unusedDims.test(dim)) {
        iteratorTypes.push_back(origIteratorTypes[dim]);
      }
    }

    // Old block-argument numbers in the order of the new block arguments:
    // the used inputs first, then the init that matches this result.
    SmallVector<int64_t> oldArgNumbers(usedInputs.begin(), usedInputs.end());
    oldArgNumbers.push_back(numDpsInputs + resultNumber);

    SmallVector<Value> inputVals = llvm::map_to_vector(
        inputs, [](OpOperand *operand) { return operand->get(); });
    SmallVector<Value> initVals = {init->get()};
    auto newOp = rewriter.create<linalg::GenericOp>(
        linalgOp.getLoc(), TypeRange{init->get().getType()}, inputVals,
        initVals, indexingMaps, iteratorTypes,
        [&](OpBuilder &b, Location loc, ValueRange blockArgs) {
          IRMapping regionMapping;
          for (auto [oldBlockArgNum, newBlockArg] :
               llvm::zip_equal(oldArgNumbers, blockArgs)) {
            regionMapping.map(oldBody->getArgument(oldBlockArgNum),
                              newBlockArg);
          }

          // Clone the slice; the mapping is updated with each cloned result.
          for (Operation *usedOperation : usedOperations) {
            b.clone(*usedOperation, regionMapping);
          }

          b.create<linalg::YieldOp>(loc, regionMapping.lookup(result));
        });

    // If any dims are unused after decomposition, propagating the lowering
    // config isn't well defined. So propagate it only when no dim is unused.
    if (hasPromotedOperands && unusedDims.none()) {
      // Original operand numbers of the operands kept on the new op, in the
      // new operand order.
      SmallVector<unsigned> operandNums =
          llvm::map_to_vector(inputs, [](OpOperand *operand) {
            return operand->getOperandNumber();
          });
      operandNums.push_back(init->getOperandNumber());
      IREE::GPU::LoweringConfigAttr newGPUAttr =
          getModifiedLoweringConfigForDecomposedGemmOp(
              rewriter, loweringConfigAttr, operandNums);
      setLoweringConfig(newOp, newGPUAttr);
    }

    rewriter.replaceAllUsesWith(linalgOp->getResult(resultNumber),
                                newOp.getResult(0));
  }

  rewriter.eraseOp(linalgOp);
  return success();
}
198+
199+
void DecomposeHorizontallyFusedGemmsPass::runOnOperation() {
200+
auto funcOp = getOperation();
201+
IRRewriter rewriter(&getContext());
202+
SmallVector<linalg::LinalgOp> horizontallyFusedOps;
203+
funcOp.walk([&](linalg::LinalgOp linalgOp) {
204+
if (IREE::LinalgExt::isaHorizontallyFusedContraction(linalgOp)) {
205+
horizontallyFusedOps.push_back(linalgOp);
206+
}
207+
});
208+
209+
for (auto linalgOp : llvm::make_early_inc_range(horizontallyFusedOps)) {
210+
if (failed(decomposeHorizontallyFusedGemmOperations(rewriter, linalgOp))) {
211+
return signalPassFailure();
212+
}
213+
}
214+
}
215+
216+
} // namespace mlir::iree_compiler

compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,18 @@ include "mlir/Pass/PassBase.td"
1313
// Common Passes used for GPU-like backends (keep alphabetical)
1414
//===---------------------------------------------------------------------===//
1515

16+
// Pass anchored on ops implementing `mlir::FunctionOpInterface`, registered
// under the `iree-codegen-gpu-decompose-horizontally-fused-gemms` flag. It
// declares the IREE GPU and Linalg dialects as dependent dialects.
def DecomposeHorizontallyFusedGemmsPass :
17+
InterfacePass<"iree-codegen-gpu-decompose-horizontally-fused-gemms",
18+
"mlir::FunctionOpInterface"> {
19+
let summary =
20+
"Decomposes a horizontally fused GEMM back into its constituent GEMMs";
21+
let dependentDialects = [
22+
"::mlir::iree_compiler::IREE::GPU::IREEGPUDialect",
23+
"::mlir::linalg::LinalgDialect",
24+
];
25+
}
26+
27+
1628
def GPUCheckResourceUsagePass :
1729
InterfacePass<"iree-codegen-gpu-check-resource-usage", "mlir::FunctionOpInterface"> {
1830
let summary = "Checks GPU specific resource usage constraints like shared memory limits";

compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ iree_lit_test_suite(
1818
name = "lit",
1919
srcs = enforce_glob(
2020
[
21+
"decompose_horizontally_fused_gemms.mlir",
2122
"gpu_apply_derived_thread_config.mlir",
2223
"gpu_apply_tiling_level.mlir",
2324
"gpu_check_resource_usage.mlir",

compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ iree_lit_test_suite(
1414
NAME
1515
lit
1616
SRCS
17+
"decompose_horizontally_fused_gemms.mlir"
1718
"gpu_apply_derived_thread_config.mlir"
1819
"gpu_apply_tiling_level.mlir"
1920
"gpu_check_resource_usage.mlir"

0 commit comments

Comments
 (0)