
Commit 9731fed

Pass to block dynamic dimensions of operands of iree_linalg_ext.attention. (#18874)
The use of `IntegerRangeAnalysis` and `IntegerDivisibilityAnalysis` gives range and divisibility information for constants passed to the dispatch. This can be used to infer the range and divisibility of all tensor values in the dispatch. This PR adds an analysis to do that inference. The analysis is then used to expand the dimensions of attention-operation operands that are dynamic but known to be divisible by a compile-time static value. This gets the operations into a form that the AMDGPU backend can compile to target the mfma intrinsics.

Signed-off-by: MaheshRavishankar <[email protected]>
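As a sketch of the rewrite this enables (the shapes, SSA names, and the factor 16 are illustrative, not taken from the PR's tests): an attention operand whose dynamic dimension is only known to be a multiple of 16 gets that factor exposed as a static inner dimension that the backend can tile onto the mfma intrinsics.

```mlir
// Before: %q : tensor<4x?x32xf16>, with dim 1 known to be a multiple of 16.
// %blocked (hypothetical) holds the original dynamic extent divided by 16.
%expanded = tensor.expand_shape %q [[0], [1, 2], [3]]
    output_shape [4, %blocked, 16, 32]
    : tensor<4x?x32xf16> into tensor<4x?x16x32xf16>
```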
1 parent 03c744e commit 9731fed

12 files changed, +752 −1 lines changed

compiler/src/iree/compiler/Codegen/Common/BUILD.bazel

Lines changed: 5 additions & 0 deletions
@@ -86,6 +86,7 @@ iree_compiler_cc_library(
     name = "Common",
     srcs = [
         "AddFastMathFlags.cpp",
+        "BlockDynamicDimensions.cpp",
         "BubbleUpOrdinalOps.cpp",
         "BufferizationAnalysis.cpp",
         "BufferizeCopyOnlyDispatchesPass.cpp",
@@ -137,6 +138,7 @@ iree_compiler_cc_library(
         "RemoveSingleIterationLoop.cpp",
         "ReplaceSlowMinMaxOps.cpp",
         "SplitFullPartialTransferPass.cpp",
+        "TensorDynamicDimAnalysis.cpp",
         "TensorToVectorVectorizePad.cpp",
         "TestExecutablePreprocessing.cpp",
         "TestPartitionableLoopsInterface.cpp",
@@ -155,6 +157,7 @@ iree_compiler_cc_library(
         "ExtractAddressComputation.h",
         "PassUtils.h",
         "Passes.h",
+        "TensorDynamicDimAnalysis.h",
         "TileSizeSelection.h",
         "Transforms.h",
         "UserConfig.h",
@@ -176,6 +179,7 @@ iree_compiler_cc_library(
         "//compiler/src/iree/compiler/Dialect/HAL/IR",
         "//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
         "//compiler/src/iree/compiler/Dialect/LinalgExt/Transforms",
+        "//compiler/src/iree/compiler/Dialect/Util/Analysis",
         "//compiler/src/iree/compiler/Dialect/Util/IR",
         "//compiler/src/iree/compiler/Utils",
         "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect",
@@ -191,6 +195,7 @@ iree_compiler_cc_library(
         "@llvm-project//mlir:BufferizationDialect",
         "@llvm-project//mlir:BufferizationInterfaces",
         "@llvm-project//mlir:BufferizationTransforms",
+        "@llvm-project//mlir:DestinationStyleOpInterface",
         "@llvm-project//mlir:DialectUtils",
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:FuncTransforms",
compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp

Lines changed: 302 additions & 0 deletions
@@ -0,0 +1,302 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.h"
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-codegen-block-dynamic-dimensions"

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_BLOCKDYNAMICDIMENSIONSPASS
#include "iree/compiler/Codegen/Common/Passes.h.inc"

using TensorDivisibilityInfo =
    llvm::SmallDenseMap<unsigned, IREE::Util::ConstantIntDivisibility>;

namespace {

struct RemoveOptimizationBarrier final
    : public OpRewritePattern<IREE::Util::OptimizationBarrierOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(IREE::Util::OptimizationBarrierOp barrierOp,
                                PatternRewriter &rewriter) const override {
    rewriter.replaceOp(barrierOp, barrierOp.getOperands());
    return success();
  }
};

/// This pass materializes information about the dynamic dimensions of
/// `tensor` operands of an operation in the IR. If a dynamic dimension is
/// known to be a multiple of a compile-time constant value, this pass
/// expands the shape of the operands. For example, if a `tensor` operand is
/// of shape `tensor<...x?x...>` and that dimension is known to be a
/// multiple of 16, the operand is expanded to `tensor<...x?x16x...>`, where
/// the size of the new dynamic dimension is 1/16th the size of the original
/// dynamic dimension. This is done in two steps:
/// 1) Replace operands that have such dynamic dimensions with the result of
///    a `tensor.expand_shape`/`tensor.collapse_shape` pair to materialize
///    the new static dimension and immediately fold it away. An
///    optimization barrier is added in between to prevent these operations
///    from being folded.
/// 2) Use patterns that propagate the `tensor.collapse_shape` down to
///    manipulate the operation appropriately. This allows reusing the
///    (fairly complex) logic used to expand the dimensions of operations
///    implemented in the propagation patterns.
/// At the end of the pass the optimization barriers are removed to fold
/// away any un-propagated `tensor.expand_shape`/`tensor.collapse_shape`
/// pairs.
struct BlockDynamicDimensionsPass final
    : impl::BlockDynamicDimensionsPassBase<BlockDynamicDimensionsPass> {
  void runOnOperation() override;
};
} // namespace

/// Retrieve the divisibility information for dynamic dimensions of `v`, if
/// known.
static TensorDivisibilityInfo
getTensorDivisibilityInfo(const TensorDynamicDimAnalysis &dynamicDimAnalysis,
                          Value v) {
  TensorDivisibilityInfo divisibilityInfo;
  auto tensorType = dyn_cast<RankedTensorType>(v.getType());
  if (!tensorType) {
    return divisibilityInfo;
  }

  for (auto [index, dim] : llvm::enumerate(tensorType.getShape())) {
    if (!tensorType.isDynamicDim(index))
      continue;
    std::optional<IREE::Util::ConstantIntDivisibility> dimDivisibility =
        dynamicDimAnalysis.getDivisibilityInfo(v, index);
    if (!dimDivisibility)
      continue;
    divisibilityInfo[index] = std::move(dimDivisibility.value());
  }

  return divisibilityInfo;
}

/// For a value `v`, if a dynamic dimension is known to be a multiple of a
/// compile-time static value, insert
///
/// ```mlir
/// %v_expand = tensor.expand_shape %v
/// %barrier = util.optimization_barrier %v_expand
/// %v_collapse = tensor.collapse_shape %barrier
/// ```
///
/// where the generated `tensor.expand_shape` and `tensor.collapse_shape` are
/// inverses of each other. The `util.optimization_barrier` prevents these
/// from being folded away during reshape propagation. Return the result of
/// the generated `tensor.collapse_shape`.
static std::optional<Value>
blockDynamicDimensionsOfValue(RewriterBase &rewriter,
                              const TensorDivisibilityInfo &divisibilityInfo,
                              Value v) {
  auto tensorType = dyn_cast<RankedTensorType>(v.getType());
  if (!tensorType) {
    return std::nullopt;
  }

  // Check if we know that the operands have divisibility information.
  SmallVector<OpFoldResult> outputShape;
  SmallVector<ReassociationIndices> reassociation;
  Location loc = v.getLoc();

  for (auto [index, dim] : llvm::enumerate(tensorType.getShape())) {
    reassociation.emplace_back(ReassociationIndices{});

    // Check if this dimension needs to be split.
    if (!tensorType.isDynamicDim(index) || !divisibilityInfo.contains(index)) {
      reassociation.back().push_back(outputShape.size());
      outputShape.push_back(rewriter.getIndexAttr(dim));
      continue;
    }

    // Split the dynamic dimension based on the divisibility info.
    IREE::Util::ConstantIntDivisibility currDivisibility =
        divisibilityInfo.lookup(index);
    uint64_t factor = currDivisibility.sdiv();
    AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
    AffineExpr divExpr = s0.floorDiv(factor);
    Value sourceDim = rewriter.create<tensor::DimOp>(loc, v, index).getResult();
    OpFoldResult newDynamicDim = affine::makeComposedFoldedAffineApply(
        rewriter, loc, divExpr, ArrayRef<OpFoldResult>{sourceDim});
    OpFoldResult newStaticDim = rewriter.getIndexAttr(factor);

    reassociation.back().push_back(outputShape.size());
    reassociation.back().push_back(outputShape.size() + 1);

    outputShape.push_back(newDynamicDim);
    outputShape.push_back(newStaticDim);
  }

  auto staticOutputShape =
      llvm::map_to_vector(outputShape, [](OpFoldResult ofr) {
        if (auto staticShapeAttr = dyn_cast<Attribute>(ofr)) {
          return cast<IntegerAttr>(staticShapeAttr).getInt();
        }
        return ShapedType::kDynamic;
      });
  auto outputType = RankedTensorType::get(
      staticOutputShape, tensorType.getElementType(), tensorType.getEncoding());

  Value expandShape = rewriter.create<tensor::ExpandShapeOp>(
      loc, outputType, v, reassociation, outputShape);
  Value barrier =
      rewriter.create<IREE::Util::OptimizationBarrierOp>(loc, expandShape)
          .getResult(0);
  Value collapseShape = rewriter.create<tensor::CollapseShapeOp>(
      loc, tensorType, barrier, reassociation);
  return collapseShape;
}

/// For an operation, replace the operands at the indices specified in
/// `limitToOperandIndices` with the result of a
/// `tensor.expand_shape`/`tensor.collapse_shape` pair to materialize the
/// information about dynamic dimensions that are known to be a multiple of
/// a compile-time static value. For example, given
///
/// ```mlir
/// %1 = <some_op>(..., %0, ...) : ... , tensor<4x?x6xf32>
/// ```
///
/// if the dynamic dimension is known to be a multiple of 16, generate
///
/// ```mlir
/// %expanded = tensor.expand_shape %0 :
///     tensor<4x?x6xf32> into tensor<4x?x16x6xf32>
/// %barrier = util.optimization_barrier %expanded
/// %collapsed = tensor.collapse_shape %barrier
///     : tensor<4x?x16x6xf32> into tensor<4x?x6xf32>
/// %1 = <some_op>(..., %collapsed, ...) : ... , tensor<4x?x6xf32>
/// ```
static LogicalResult blockDynamicDimensions(
    RewriterBase &rewriter, const TensorDynamicDimAnalysis &dynamicDimAnalysis,
    Operation *operation, llvm::SmallDenseSet<int64_t> limitToOperandIndices) {
  OpBuilder::InsertionGuard g(rewriter);

  for (OpOperand &operand : operation->getOpOperands()) {
    if (!limitToOperandIndices.contains(operand.getOperandNumber()))
      continue;
    if (operand.get().getDefiningOp<tensor::CollapseShapeOp>())
      continue;
    TensorDivisibilityInfo operandDivisibilityInfo =
        getTensorDivisibilityInfo(dynamicDimAnalysis, operand.get());
    if (operandDivisibilityInfo.empty())
      continue;
    std::optional<Value> newOperand = blockDynamicDimensionsOfValue(
        rewriter, operandDivisibilityInfo, operand.get());
    if (newOperand) {
      rewriter.modifyOpInPlace(operation,
                               [&]() { operand.set(newOperand.value()); });
    }
  }
  return success();
}

/// Insert `tensor.expand_shape` operations to materialize in the IR
/// information about dynamic dimensions that are known to be a multiple of
/// a compile-time known value, for the operands of the
/// `iree_linalg_ext.attention` operation.
static LogicalResult
blockDynamicDimensions(RewriterBase &rewriter,
                       const TensorDynamicDimAnalysis &dynamicDimAnalysis,
                       IREE::LinalgExt::AttentionOp attentionOp) {
  // Only block the q and k values.
  llvm::SmallDenseSet<int64_t> prunedOperandsList;
  prunedOperandsList.insert(attentionOp.getQueryMutable().getOperandNumber());
  prunedOperandsList.insert(attentionOp.getKeyMutable().getOperandNumber());
  return blockDynamicDimensions(rewriter, dynamicDimAnalysis, attentionOp,
                                prunedOperandsList);
}

void BlockDynamicDimensionsPass::runOnOperation() {
  Operation *operation = getOperation();
  MLIRContext *context = &getContext();
  TensorDynamicDimAnalysis dynamicDimAnalysis(operation);
  if (failed(dynamicDimAnalysis.run())) {
    return signalPassFailure();
  }

  IRRewriter rewriter(context);
  auto walkResult = operation->walk(
      [&](IREE::LinalgExt::AttentionOp attentionOp) -> WalkResult {
        rewriter.setInsertionPoint(attentionOp);
        return blockDynamicDimensions(rewriter, dynamicDimAnalysis,
                                      attentionOp);
      });
  if (walkResult.wasInterrupted()) {
    return signalPassFailure();
  }

  LLVM_DEBUG({
    llvm::dbgs() << "After blocking dimensions:\n";
    operation->print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
    llvm::dbgs() << "\n";
  });

  {
    RewritePatternSet bubbleExpandShapePatterns(context);
    // Add patterns to "push down" the `tensor.collapse_shape` operations
    // (which are the dual of the patterns that "bubble up"
    // `tensor.expand_shape` operations).
    linalg::ControlFusionFn controlFn = [](OpOperand *) { return true; };
    linalg::populateFoldReshapeOpsByExpansionPatterns(bubbleExpandShapePatterns,
                                                      controlFn);
    IREE::LinalgExt::populateFoldReshapeOpsByExpansionPatterns(
        bubbleExpandShapePatterns, controlFn);
    // Add patterns to fold the "bubbled-up" `tensor.expand_shape` operations
    // and "pushed-down" `tensor.collapse_shape` operations with their
    // interface bindings or `tensor.empty` operations.
    populateReshapeToInterfaceTensorPatterns(bubbleExpandShapePatterns);
    tensor::populateFoldTensorEmptyPatterns(bubbleExpandShapePatterns);
    // Add some additional patterns that can simplify the IR and remove dead
    // operations.
    memref::populateResolveRankedShapedTypeResultDimsPatterns(
        bubbleExpandShapePatterns);
    populateRemoveDeadMemAllocPatterns(bubbleExpandShapePatterns);
    if (failed(applyPatternsAndFoldGreedily(
            operation, std::move(bubbleExpandShapePatterns)))) {
      operation->emitOpError(
          "failed in application of bubble up expand shape patterns");
      return signalPassFailure();
    }
  }

  LLVM_DEBUG({
    llvm::dbgs() << "After reshape propagation:\n";
    operation->print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
    llvm::dbgs() << "\n";
  });

  // Delete the optimization barriers and run some further cleanup.
  {
    RewritePatternSet removeBarrierOpsPatterns(context);
    removeBarrierOpsPatterns.insert<RemoveOptimizationBarrier>(context);
    tensor::ExpandShapeOp::getCanonicalizationPatterns(removeBarrierOpsPatterns,
                                                       context);
    tensor::CollapseShapeOp::getCanonicalizationPatterns(
        removeBarrierOpsPatterns, context);
    if (failed(applyPatternsAndFoldGreedily(
            operation, std::move(removeBarrierOpsPatterns)))) {
      operation->emitOpError("failed in cleanup patterns");
      return signalPassFailure();
    }
  }

  return;
}

} // namespace mlir::iree_compiler
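To make the mechanics above concrete, here is a sketch of the IR that `blockDynamicDimensionsOfValue` builds for an operand of type `tensor<4x?x32xf16>` whose dimension 1 the analysis proves divisible by 16 (value names are hypothetical, and the folded form of the `affine.apply` may differ):

```mlir
%c1 = arith.constant 1 : index
%dim = tensor.dim %operand, %c1 : tensor<4x?x32xf16>
// makeComposedFoldedAffineApply with the expression `s0 floordiv 16`.
%blocked = affine.apply affine_map<()[s0] -> (s0 floordiv 16)>()[%dim]
%expanded = tensor.expand_shape %operand [[0], [1, 2], [3]]
    output_shape [4, %blocked, 16, 32]
    : tensor<4x?x32xf16> into tensor<4x?x16x32xf16>
// The barrier keeps the inverse pair from folding before propagation runs.
%barrier = util.optimization_barrier %expanded : tensor<4x?x16x32xf16>
%collapsed = tensor.collapse_shape %barrier [[0], [1, 2], [3]]
    : tensor<4x?x16x32xf16> into tensor<4x?x32xf16>
```

`%collapsed` then replaces the original attention operand; the reshape-propagation patterns pull the `tensor.collapse_shape` into the consumer, and the final cleanup deletes the barrier so any leftover expand/collapse pair folds away.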

compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
@@ -72,11 +72,13 @@ iree_cc_library(
     "ExtractAddressComputation.h"
     "PassUtils.h"
     "Passes.h"
+    "TensorDynamicDimAnalysis.h"
     "TileSizeSelection.h"
     "Transforms.h"
     "UserConfig.h"
   SRCS
     "AddFastMathFlags.cpp"
+    "BlockDynamicDimensions.cpp"
     "BubbleUpOrdinalOps.cpp"
     "BufferizationAnalysis.cpp"
     "BufferizeCopyOnlyDispatchesPass.cpp"
@@ -128,6 +130,7 @@ iree_cc_library(
     "RemoveSingleIterationLoop.cpp"
     "ReplaceSlowMinMaxOps.cpp"
     "SplitFullPartialTransferPass.cpp"
+    "TensorDynamicDimAnalysis.cpp"
     "TensorToVectorVectorizePad.cpp"
     "TestExecutablePreprocessing.cpp"
     "TestPartitionableLoopsInterface.cpp"
@@ -154,6 +157,7 @@ iree_cc_library(
     MLIRArithUtils
     MLIRBufferizationDialect
     MLIRBufferizationTransforms
+    MLIRDestinationStyleOpInterface
     MLIRFuncDialect
     MLIRFuncTransforms
     MLIRFunctionInterfaces
@@ -203,6 +207,7 @@ iree_cc_library(
     iree::compiler::Dialect::HAL::IR
     iree::compiler::Dialect::LinalgExt::IR
     iree::compiler::Dialect::LinalgExt::Transforms
+    iree::compiler::Dialect::Util::Analysis
     iree::compiler::Dialect::Util::IR
     iree::compiler::Utils
   PUBLIC

compiler/src/iree/compiler/Codegen/Common/Passes.td

Lines changed: 6 additions & 0 deletions
@@ -19,6 +19,12 @@ def AddFastMathFlagsPass
          "given a floating-point mode.";
 }
 
+def BlockDynamicDimensionsPass
+    : Pass<"iree-codegen-block-dynamic-dimensions"> {
+  let summary = "Expand dynamic dimensions that are known to be multiples of "
+                "statically known values.";
+}
+
 def BubbleUpOrdinalOpsPass : Pass<"iree-codegen-bubble-up-ordinal-ops", ""> {
   let summary = "Bubbles op ordinal ops to allow for workgroup count computation";
   let description = [{

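Assuming MLIR's usual auto-generated pass registration exposes the definition above to `iree-opt` under its listed name (the pipeline nesting shown here is an assumption, since the pass is not anchored to a specific op), a lit test for the pass could be driven along these lines:

```mlir
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-block-dynamic-dimensions))" \
// RUN:   --split-input-file %s | FileCheck %s
```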