Commit 1a8d229

[GlobalOptimizations] Add a pass to simplify strided contraction-like ops (#20607)
This PR adds a pattern to manipulate generic ops which satisfy all "contraction" conditions except for the indexing maps being projected permutations. Namely, if the input indexing map has results of the form `dim * cst`, this pattern factors the original generic op into a `tensor.extract_slice` followed by a contraction `linalg.generic`. This addresses #20600. None of the included lit test examples compiled to mfma instructions before this patch.

---------

Signed-off-by: zjgarvey <[email protected]>
1 parent b4e3694 commit 1a8d229
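As a compact before/after sketch of the rewrite (condensed from the first lit test added in this patch, @strided_from_output_static, with unrelated detail elided as `...`):

// Before: the input indexing map strides d1 and d2 by 2.
%2 = linalg.generic {indexing_maps = [
       affine_map<(d0, d1, d2, d3, d4) -> (d0, 2 * d1, d2 * 2, d4)>,
       affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>,
       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], ...}
       ins(%input, %filter : tensor<2x118x182x448xbf16>, tensor<896x448xbf16>) ...

// After: the strides move into a tensor.extract_slice and the generic
// becomes a recognizable contraction.
%slice = tensor.extract_slice %input[0, 0, 0, 0] [2, 59, 91, 448] [1, 2, 2, 1]
           : tensor<2x118x182x448xbf16> to tensor<2x59x91x448xbf16>
%2 = linalg.generic {indexing_maps = [
       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, ...], ...}
       ins(%slice, %filter : tensor<2x59x91x448xbf16>, tensor<896x448xbf16>) ...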

File tree

9 files changed: +275 -0 lines changed


compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@ iree_compiler_cc_library(
     srcs = [
         "CleanupNumericNarrowing.cpp",
         "Convert1X1FilterConv2DToMatmul.cpp",
+        "ConvertStridedContractionToContraction.cpp",
         "DataLayoutPropagation.cpp",
         "DecomposeConcat.cpp",
         "DemoteContractionInputsToBF16.cpp",

compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ iree_cc_library(
   SRCS
     "CleanupNumericNarrowing.cpp"
     "Convert1X1FilterConv2DToMatmul.cpp"
+    "ConvertStridedContractionToContraction.cpp"
     "DataLayoutPropagation.cpp"
     "DecomposeConcat.cpp"
     "DemoteContractionInputsToBF16.cpp"

compiler/src/iree/compiler/GlobalOptimization/ConvertStridedContractionToContraction.cpp

Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@
// Copyright 2025 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/GlobalOptimization/Passes.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"

namespace mlir::iree_compiler::GlobalOptimization {

#define GEN_PASS_DEF_CONVERTSTRIDEDCONTRACTIONTOCONTRACTIONPASS
#include "iree/compiler/GlobalOptimization/Passes.h.inc"

namespace {

class ConvertStridedContractionToContraction
    : public OpRewritePattern<linalg::GenericOp> {
public:
  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Check if the generic op satisfies all other conditions for being a
    // contraction.
    if (op.getNumDpsInputs() != 2 || op.getNumDpsInits() != 1)
      return failure();
    if (op.getNumReductionLoops() == 0)
      return failure();
    if (!mlir::linalg::detail::isContractionBody(
            *op.getBlock(), [](Operation *first, Operation *second) {
              if ((isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
                  (isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second)))
                return true;
              return false;
            })) {
      return failure();
    }

    SmallVector<AffineMap> mapRange = op.getIndexingMapsArray();
    unsigned inputPos = op.getDpsInputOperand(0)->getOperandNumber();
    unsigned filterPos = op.getDpsInputOperand(1)->getOperandNumber();
    unsigned resInitPos = op.getDpsInitOperand(0)->getOperandNumber();
    AffineMap inputMap = mapRange[inputPos];
    AffineMap filterMap = mapRange[filterPos];
    AffineMap resultMap = mapRange[resInitPos];
    // For now, we are only handling the case where the first input is the
    // only non-projected permutation.
    if (!filterMap.isProjectedPermutation() ||
        !resultMap.isProjectedPermutation()) {
      return failure();
    }
    if (inputMap.isProjectedPermutation())
      return failure();
    SmallVector<int64_t, 4> staticShape = op.getStaticLoopRanges();

    llvm::SmallDenseMap<unsigned, int64_t> strides;
    SmallVector<AffineExpr> replacementExprs;
    Value input = op.getDpsInputs()[0];
    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
    if (!inputTy)
      return failure();
    SmallVector<int64_t> inputShape(inputTy.getShape());
    replacementExprs.reserve(inputMap.getNumResults());
    // Walk through input map and look for expressions of the form `dim * cst`.
    for (auto [pos, expr] : llvm::enumerate(inputMap.getResults())) {
      // Skip dim exprs and constant exprs.
      if (isa<AffineDimExpr>(expr) || isa<AffineConstantExpr>(expr)) {
        replacementExprs.push_back(expr);
        continue;
      }
      // Look at binary op expressions.
      auto binexpr = dyn_cast<AffineBinaryOpExpr>(expr);
      // Fail if we see some unexpected kind of expression.
      if (!binexpr)
        return failure();
      auto rhs = dyn_cast<AffineConstantExpr>(binexpr.getRHS());
      auto lhs = dyn_cast<AffineDimExpr>(binexpr.getLHS());
      // Binary expressions must be of the form `dim * cst`.
      if (!rhs || !lhs || binexpr.getKind() != AffineExprKind::Mul) {
        replacementExprs.push_back(expr);
        continue;
      }
      strides.insert(std::pair<unsigned, int64_t>(pos, rhs.getValue()));
      int64_t newSize = staticShape[lhs.getPosition()];
      if (newSize == ShapedType::kDynamic || newSize == 0)
        return failure();
      inputShape[pos] = newSize;
      replacementExprs.push_back(lhs);
    }

    // Fail if we don't have any work to do.
    if (strides.empty())
      return failure();

    mapRange[inputPos] =
        AffineMap::get(inputMap.getNumDims(), inputMap.getNumSymbols(),
                       replacementExprs, op.getContext());
    auto sliceTy = RankedTensorType::get(inputShape, inputTy.getElementType());

    unsigned rank = inputTy.getRank();
    SmallVector<OpFoldResult> vOffset(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> vSizes;
    SmallVector<OpFoldResult> vStride(rank, rewriter.getIndexAttr(1));
    Location loc = op.getLoc();
    for (unsigned i = 0; i < inputTy.getRank(); i++) {
      if (strides.contains(i)) {
        vStride[i] = rewriter.getIndexAttr(strides.at(i));
      }
      if (inputShape[i] != ShapedType::kDynamic) {
        vSizes.push_back(rewriter.getIndexAttr(inputShape[i]));
        continue;
      }
      vSizes.push_back(rewriter.createOrFold<tensor::DimOp>(loc, input, i));
    }
    Value extractedSlice = rewriter.create<tensor::ExtractSliceOp>(
        loc, sliceTy, input, vOffset, vSizes, vStride);
    rewriter.startOpModification(op);
    op.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(mapRange));
    op.setOperand(0, extractedSlice);
    rewriter.finalizeOpModification(op);
    return success();
  }
};

struct ConvertStridedContractionToContractionPass
    : public impl::ConvertStridedContractionToContractionPassBase<
          ConvertStridedContractionToContractionPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<arith::ArithDialect, tensor::TensorDialect>();
  }

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(&getContext());
    patterns.insert<ConvertStridedContractionToContraction>(context);
    walkAndApplyPatterns(getOperation(), std::move(patterns));
  }
};
} // namespace
} // namespace mlir::iree_compiler::GlobalOptimization
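One detail worth highlighting from the loop above, illustrated with the indexing maps of the @strided_from_output_partial_conv lit test added in this patch: results that are not a plain `dim * cst` product (for example the convolution-style `d2 * 2 + d4`) are passed through unchanged, so only the purely strided dimensions are folded into the slice.

// Input map before the rewrite (d1 and d2 both look strided):
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2, d2 * 2 + d4, d5)>
// Input map after the rewrite: d1 * 2 becomes d1 and its stride of 2 moves
// into the tensor.extract_slice (strides [1, 2, 1, 1], sizes [2, 59, 182, 448]),
// while d2 * 2 + d4 is left in place.
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 * 2 + d4, d5)>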

compiler/src/iree/compiler/GlobalOptimization/Passes.cpp

Lines changed: 2 additions & 0 deletions
@@ -155,6 +155,8 @@ void buildGlobalOptimizationPassPipeline(
   });

   mainPassManager.addPass(DispatchCreation::createFoldUnitExtentDimsPass());
+  mainPassManager.addPass(
+      GlobalOptimization::createConvertStridedContractionToContractionPass());
   FunctionLikeNest(mainPassManager)
       .addPredicatedPass(clEnableFuseSiluHorizontalMatmul,
                          createFuseSiluHorizontalMatmulPass)

compiler/src/iree/compiler/GlobalOptimization/Passes.td

Lines changed: 5 additions & 0 deletions
@@ -19,6 +19,11 @@ def Convert1X1FilterConv2DToMatmulPass:
   let summary = "Convert linalg convolution ops with 1x1 kernels into linalg matrix multiplication ops.";
 }

+def ConvertStridedContractionToContractionPass:
+    Pass<"iree-global-opt-convert-strided-contraction-to-contraction", ""> {
+  let summary = "Factors out an extract_slice from contraction-like ops with strided inputs.";
+}
+
 def DecomposeConcatPass :
     Pass<"iree-global-opt-decompose-concat", ""> {
   let summary = "Decomposes concatenations into a destination and a sequence of slice inserts.";

compiler/src/iree/compiler/GlobalOptimization/test/BUILD.bazel

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@ iree_lit_test_suite(
             "propagate_linalg_transpose.mlir",
             "raise_special_ops.mlir",
             "remove_zero_extent_tensors.mlir",
+            "strided_contraction_to_contraction.mlir",
             "transformation_pipeline.mlir",
             "transpose_and_decompose_concat.mlir",
             "warn_on_uninitialized_values.mlir",

compiler/src/iree/compiler/GlobalOptimization/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ iree_lit_test_suite(
     "propagate_linalg_transpose.mlir"
     "raise_special_ops.mlir"
     "remove_zero_extent_tensors.mlir"
+    "strided_contraction_to_contraction.mlir"
     "transformation_pipeline.mlir"
     "transpose_and_decompose_concat.mlir"
     "warn_on_uninitialized_values.mlir"

compiler/src/iree/compiler/GlobalOptimization/test/strided_contraction_to_contraction.mlir

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
// RUN: iree-opt --split-input-file --mlir-print-local-scope -iree-global-opt-convert-strided-contraction-to-contraction %s | FileCheck %s

util.func public @strided_from_output_static(%input: tensor<2x118x182x448xbf16>, %filter: tensor<896x448xbf16>) -> tensor<2x59x91x896xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<2x59x91x896xf32>
  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x59x91x896xf32>) -> tensor<2x59x91x896xf32>
  %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, 2 * d1, d2 * 2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%input, %filter : tensor<2x118x182x448xbf16>, tensor<896x448xbf16>) outs(%1 : tensor<2x59x91x896xf32>) {
  ^bb0(%in: bf16, %in_0: bf16, %out: f32):
    %3 = arith.extf %in : bf16 to f32
    %4 = arith.extf %in_0 : bf16 to f32
    %5 = arith.mulf %3, %4 : f32
    %6 = arith.addf %out, %5 : f32
    linalg.yield %6 : f32
  } -> tensor<2x59x91x896xf32>
  util.return %2 : tensor<2x59x91x896xf32>
}

// CHECK-LABEL: @strided_from_output_static(
// CHECK-SAME:    %[[INPUT:.*]]: tensor<2x118x182x448xbf16>
// CHECK-SAME:    %[[FILTER:.*]]: tensor<896x448xbf16>
// CHECK:         %[[SLICE:.*]] = tensor.extract_slice %[[INPUT]][0, 0, 0, 0] [2, 59, 91, 448] [1, 2, 2, 1]
// CHECK-SAME:      tensor<2x118x182x448xbf16> to tensor<2x59x91x448xbf16>
// CHECK:         %[[GEN:.*]] = linalg.generic
// CHECK-SAME:      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>
// CHECK-SAME:      affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>
// CHECK-SAME:      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
// CHECK-SAME:      ins(%[[SLICE]], %[[FILTER]]
// CHECK:         util.return %[[GEN]]

// -----

util.func public @strided_from_output_dynamic_batch(%input: tensor<?x118x182x448xbf16>, %filter: tensor<896x448xbf16>) -> tensor<?x59x91x896xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %c0 = arith.constant 0 : index
  %dim = tensor.dim %input, %c0 : tensor<?x118x182x448xbf16>
  %0 = tensor.empty(%dim) : tensor<?x59x91x896xf32>
  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?x59x91x896xf32>) -> tensor<?x59x91x896xf32>
  %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 * 2, d2 * 2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%input, %filter : tensor<?x118x182x448xbf16>, tensor<896x448xbf16>) outs(%1 : tensor<?x59x91x896xf32>) {
  ^bb0(%in: bf16, %in_0: bf16, %out: f32):
    %3 = arith.extf %in : bf16 to f32
    %4 = arith.extf %in_0 : bf16 to f32
    %5 = arith.mulf %3, %4 : f32
    %6 = arith.addf %out, %5 : f32
    linalg.yield %6 : f32
  } -> tensor<?x59x91x896xf32>
  util.return %2 : tensor<?x59x91x896xf32>
}

// CHECK-LABEL: @strided_from_output_dynamic_batch(
// CHECK-SAME:    %[[INPUT:.*]]: tensor<?x118x182x448xbf16>
// CHECK-SAME:    %[[FILTER:.*]]: tensor<896x448xbf16>
// CHECK:         %[[SLICE:.*]] = tensor.extract_slice %[[INPUT]][0, 0, 0, 0] [%[[DIM:.*]], 59, 91, 448] [1, 2, 2, 1]
// CHECK-SAME:      tensor<?x118x182x448xbf16> to tensor<?x59x91x448xbf16>
// CHECK:         %[[GEN:.*]] = linalg.generic
// CHECK-SAME:      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>
// CHECK-SAME:      affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>
// CHECK-SAME:      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
// CHECK-SAME:      ins(%[[SLICE]], %[[FILTER]]
// CHECK:         util.return %[[GEN]]

// -----

util.func public @strided_from_output_partial_conv(%input: tensor<2x118x182x448xbf16>, %filter: tensor<896x2x448xbf16>) -> tensor<2x59x91x896xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<2x59x91x896xf32>
  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x59x91x896xf32>) -> tensor<2x59x91x896xf32>
  %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2, d2 * 2 + d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<2x118x182x448xbf16>, tensor<896x2x448xbf16>) outs(%1 : tensor<2x59x91x896xf32>) {
  ^bb0(%in: bf16, %in_0: bf16, %out: f32):
    %3 = arith.extf %in : bf16 to f32
    %4 = arith.extf %in_0 : bf16 to f32
    %5 = arith.mulf %3, %4 : f32
    %6 = arith.addf %out, %5 : f32
    linalg.yield %6 : f32
  } -> tensor<2x59x91x896xf32>
  util.return %2 : tensor<2x59x91x896xf32>
}

// CHECK-LABEL: @strided_from_output_partial_conv
// CHECK-SAME:    %[[INPUT:.*]]: tensor<2x118x182x448xbf16>
// CHECK-SAME:    %[[FILTER:.*]]: tensor<896x2x448xbf16>
// CHECK:         %[[SLICE:.*]] = tensor.extract_slice %[[INPUT]][0, 0, 0, 0] [2, 59, 182, 448] [1, 2, 1, 1]
// CHECK-SAME:      tensor<2x118x182x448xbf16> to tensor<2x59x182x448xbf16>
// CHECK:         linalg.generic
// CHECK-SAME:      affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 * 2 + d4, d5)>
// CHECK-SAME:      affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
// CHECK-SAME:      affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
// CHECK-SAME:      ins(%[[SLICE]], %[[FILTER]]

// -----

util.func public @strided_from_filter_static(%input: tensor<896x118x16xbf16>, %filter: tensor<448x59x16xbf16>) -> tensor<896x448xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<896x448xf32>
  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<896x448xf32>) -> tensor<896x448xf32>
  %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2 * 2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<896x118x16xbf16>, tensor<448x59x16xbf16>) outs(%1 : tensor<896x448xf32>) {
  ^bb0(%in: bf16, %in_0: bf16, %out: f32):
    %3 = arith.extf %in : bf16 to f32
    %4 = arith.extf %in_0 : bf16 to f32
    %5 = arith.mulf %3, %4 : f32
    %6 = arith.addf %out, %5 : f32
    linalg.yield %6 : f32
  } -> tensor<896x448xf32>
  util.return %2 : tensor<896x448xf32>
}

// CHECK-LABEL: @strided_from_filter_static(
// CHECK-SAME:    %[[INPUT:.*]]: tensor<896x118x16xbf16>
// CHECK-SAME:    %[[FILTER:.*]]: tensor<448x59x16xbf16>
// CHECK:         %[[SLICE:.*]] = tensor.extract_slice %[[INPUT]][0, 0, 0] [896, 59, 16] [1, 2, 1]
// CHECK-SAME:      tensor<896x118x16xbf16> to tensor<896x59x16xbf16>
// CHECK:         %[[GEN:.*]] = linalg.generic
// CHECK-SAME:      affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
// CHECK-SAME:      affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
// CHECK-SAME:      affine_map<(d0, d1, d2, d3) -> (d0, d1)>
// CHECK-SAME:      ins(%[[SLICE]], %[[FILTER]]
// CHECK:         util.return %[[GEN]]

compiler/src/iree/compiler/Preprocessing/Passes.cpp

Lines changed: 2 additions & 0 deletions
@@ -147,6 +147,8 @@ buildMakeSingleDispatchPassPipeline(OpPassManager &passManager,
   // Generalize transposes and any other remaining named linalg ops that can
   // now be represented as generics.
   passManager.addPass(GlobalOptimization::createGeneralizeLinalgNamedOpsPass());
+  passManager.addPass(
+      GlobalOptimization::createConvertStridedContractionToContractionPass());
   passManager.addPass(DispatchCreation::createFusionPreprocessingPass());
   passManager.addPass(mlir::createCSEPass());
   DispatchCreation::BubbleUpExpandShapesPassOptions bubbleOptions;

0 commit comments
