Skip to content

Commit 40c28a1

Browse files
[Codegen][Common] Add InsertBatchDimForBatchlessConv pass for 2D Conv (#23351)
Most of this has been taken from #21955 (comment) (CC: @hanhanW); the logic was refactored, cleaned up, and tested locally against the generic conv tests.

Context: Upstream MLIR's linalg::inferConvolutionDims and related APIs (matchConvolutionOpOfType, isaConvolutionOpInterface) expect convolutions to have a batch dimension. These APIs are used by the DownscaleConv patterns and by vectorization to recognize and optimize convolution operations. However, IREE's dispatch formation pipeline strips unit dimensions (including N=1 batch dimensions) via fold-unit-extent-dims. So, after the generalization step, a conv_2d_nhwc_hwcf with N=1 becomes a 6-loop generic op instead of the expected 7-loop structure, causing the upstream APIs to fail pattern matching.

This pass restores the batch dimension for such "batchless" 2D convolutions by inserting tensor.expand_shape on inputs/outputs and tensor.collapse_shape on results. This allows the upstream convolution detection APIs to recognize the operation, enabling the DownscaleConv and vectorization patterns to apply. The reshape operations are propagated to dispatch boundaries and folded into dispatch tensor load/store operations, resulting in zero runtime cost.

Lit tests added for the supported operations:
- Conv2DNhwcHwcf, Conv2DNchwFchw
- PoolingNhwcSum/Max/Min, PoolingNhwcMaxUnsigned/MinUnsigned
- PoolingNchwSum/Max
- DepthwiseConv2DNhwcHwc

Signed-off-by: Abhishek Varma <abhvarma@amd.com>
1 parent 7802eb5 commit 40c28a1

File tree

7 files changed

+699
-0
lines changed

7 files changed

+699
-0
lines changed

compiler/src/iree/compiler/Codegen/Common/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ iree_compiler_cc_library(
124124
"IREEExpandStridedMetadata.cpp",
125125
"IREEInjectAssumeAlignment.cpp",
126126
"IREELoopInvariantCodeMotion.cpp",
127+
"InsertBatchDimForBatchlessConv.cpp",
127128
"InstrumentMemoryAccesses.cpp",
128129
"LinkTuningSpecsPass.cpp",
129130
"LowerExecutableUsingTransformDialect.cpp",
@@ -252,6 +253,7 @@ iree_compiler_cc_library(
252253
"@llvm-project//mlir:LLVMCommonConversion",
253254
"@llvm-project//mlir:LLVMDialect",
254255
"@llvm-project//mlir:LinalgDialect",
256+
"@llvm-project//mlir:LinalgInterfaces",
255257
"@llvm-project//mlir:LinalgTransforms",
256258
"@llvm-project//mlir:LinalgUtils",
257259
"@llvm-project//mlir:LoopLikeInterface",

compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ iree_cc_library(
117117
"IREEExpandStridedMetadata.cpp"
118118
"IREEInjectAssumeAlignment.cpp"
119119
"IREELoopInvariantCodeMotion.cpp"
120+
"InsertBatchDimForBatchlessConv.cpp"
120121
"InstrumentMemoryAccesses.cpp"
121122
"LinkTuningSpecsPass.cpp"
122123
"LowerExecutableUsingTransformDialect.cpp"
@@ -197,6 +198,7 @@ iree_cc_library(
197198
MLIRLLVMCommonConversion
198199
MLIRLLVMDialect
199200
MLIRLinalgDialect
201+
MLIRLinalgInterfacesIncGenLib
200202
MLIRLinalgTransforms
201203
MLIRLinalgUtils
202204
MLIRLoopLikeInterface
Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
// Copyright 2026 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#include "iree/compiler/Codegen/Common/Passes.h"
8+
#include "iree/compiler/Codegen/Common/Transforms.h"
9+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
10+
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
11+
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
12+
#include "mlir/Dialect/Linalg/Utils/Utils.h"
13+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
14+
#include "mlir/IR/AffineExpr.h"
15+
#include "mlir/IR/AffineMap.h"
16+
#include "mlir/IR/PatternMatch.h"
17+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18+
19+
#define DEBUG_TYPE "iree-codegen-insert-batch-dim-for-batchless-conv"
20+
21+
namespace mlir::iree_compiler {
22+
23+
#define GEN_PASS_DEF_INSERTBATCHDIMFORBATCHLESSCONVPASS
24+
#include "iree/compiler/Codegen/Common/Passes.h.inc"
25+
26+
namespace {
27+
28+
/// Detects if the given linalg op is a batch-less convolution (a convolution
29+
/// where the batch dimension N=1 was stripped by IREE's unit dim folding).
30+
///
31+
/// Returns true if this is a batch-less convolution that should be
32+
/// transformed, false otherwise.
33+
static bool isBatchlessConv(linalg::LinalgOp op) {
34+
// Must be a convolution operation (uses upstream MLIR's convolution
35+
// detection).
36+
if (!linalg::isaConvolutionOpInterface(op)) {
37+
return false;
38+
}
39+
40+
FailureOr<linalg::ConvolutionDimensions> maybeConvDims =
41+
linalg::inferConvolutionDims(op);
42+
if (failed(maybeConvDims)) {
43+
return false;
44+
}
45+
46+
// Detect operation kind based on channel dimensions.
47+
bool isRegularConv = !maybeConvDims->inputChannel.empty() &&
48+
!maybeConvDims->outputChannel.empty();
49+
bool isDepthwise = !maybeConvDims->depth.empty();
50+
bool isPooling = maybeConvDims->inputChannel.empty() &&
51+
maybeConvDims->outputChannel.empty() &&
52+
maybeConvDims->depth.empty();
53+
54+
// Check if batch dimension is missing.
55+
bool isBatchless = false;
56+
57+
if (isRegularConv || isDepthwise) {
58+
// For conv/depthwise: batch should be non-empty when present.
59+
isBatchless = maybeConvDims->batch.empty();
60+
} else if (isPooling) {
61+
// For pooling: batch contains [N, C], without N it's just [C].
62+
// So batch.size() == 1 means only channel, no real batch.
63+
isBatchless = (maybeConvDims->batch.size() == 1);
64+
}
65+
66+
return isBatchless;
67+
}
68+
69+
/// Builds reassociation indices for prepending a dimension at position 0.
70+
/// For a tensor of rank R, produces: [[0, 1], [2], [3], ..., [R]]
71+
static SmallVector<ReassociationIndices>
72+
buildPrependDimReassociation(int64_t rank) {
73+
SmallVector<ReassociationIndices> reassoc;
74+
reassoc.push_back({0, 1}); // New dim groups with first existing dim.
75+
for (int64_t i = 1; i < rank; ++i) {
76+
reassoc.push_back({i + 1});
77+
}
78+
return reassoc;
79+
}
80+
81+
/// Expands a tensor by prepending a unit dimension at position 0.
82+
/// tensor<AxBxC> -> tensor<1xAxBxC>
83+
static Value prependUnitDimToTensor(RewriterBase &rewriter, Location loc,
84+
Value tensor) {
85+
auto tensorType = cast<RankedTensorType>(tensor.getType());
86+
SmallVector<int64_t> newShape = {1};
87+
llvm::append_range(newShape, tensorType.getShape());
88+
auto newType = RankedTensorType::get(newShape, tensorType.getElementType());
89+
auto reassoc = buildPrependDimReassociation(tensorType.getRank());
90+
return tensor::ExpandShapeOp::create(rewriter, loc, newType, tensor, reassoc);
91+
}
92+
93+
/// Shifts all existing dimensions in an affine map by 1 and prepends d0.
94+
/// (d0, d1, ...) -> (...) becomes (d0, d1, d2, ...) -> (d0, ...)
95+
static AffineMap shiftAndPrependDimToMap(AffineMap oldMap, MLIRContext *ctx) {
96+
AffineMap shifted = oldMap.shiftDims(1);
97+
SmallVector<AffineExpr> newResults;
98+
newResults.push_back(getAffineDimExpr(0, ctx)); // New leading dim.
99+
llvm::append_range(newResults, shifted.getResults());
100+
return AffineMap::get(oldMap.getNumDims() + 1, 0, newResults, ctx);
101+
}
102+
103+
/// Inserts a unit batch dimension into a batchless convolution operation.
104+
///
105+
/// Transforms:
106+
/// linalg.generic (batchless conv, e.g., HWC -> HWF)
107+
/// Into:
108+
/// expand_shape(input) -> linalg.generic (with batch, NHWC -> NHWF) ->
109+
/// collapse_shape
110+
///
111+
/// Returns the newly created conv op with batch dimension.
112+
static linalg::GenericOp insertUnitBatchDimension(RewriterBase &rewriter,
113+
linalg::GenericOp op) {
114+
Location loc = op.getLoc();
115+
MLIRContext *ctx = rewriter.getContext();
116+
117+
// Process all input operands.
118+
// - Operand 0 (image): expand tensor and prepend dim to map.
119+
// - Other operands (filter, zero points, etc.): keep as-is, shift map.
120+
SmallVector<Value> newInputs;
121+
SmallVector<AffineMap> newMaps;
122+
123+
for (OpOperand *inputOperand : op.getDpsInputOperands()) {
124+
Value input = inputOperand->get();
125+
AffineMap oldMap = op.getMatchingIndexingMap(inputOperand);
126+
127+
if (inputOperand->getOperandNumber() == 0) {
128+
// Image input: expand and prepend dim to map.
129+
newInputs.push_back(prependUnitDimToTensor(rewriter, loc, input));
130+
newMaps.push_back(shiftAndPrependDimToMap(oldMap, ctx));
131+
} else {
132+
// Other inputs (filter, zero points): keep as-is, just shift map.
133+
newInputs.push_back(input);
134+
newMaps.push_back(oldMap.shiftDims(1));
135+
}
136+
}
137+
138+
// Output: expand and add batch dim to map.
139+
OpOperand *outputOperand = op.getDpsInitOperand(0);
140+
Value output = outputOperand->get();
141+
auto outputType = cast<RankedTensorType>(output.getType());
142+
Value expandedOutput = prependUnitDimToTensor(rewriter, loc, output);
143+
AffineMap newOutputMap =
144+
shiftAndPrependDimToMap(op.getMatchingIndexingMap(outputOperand), ctx);
145+
newMaps.push_back(newOutputMap);
146+
147+
// New iterator types: prepend parallel (batch) to existing types.
148+
SmallVector<utils::IteratorType> newIterTypes;
149+
newIterTypes.push_back(utils::IteratorType::parallel);
150+
llvm::append_range(newIterTypes, op.getIteratorTypesArray());
151+
152+
// Create new generic with batch dimension.
153+
auto newOutputType = cast<RankedTensorType>(expandedOutput.getType());
154+
auto newConvOp = linalg::GenericOp::create(
155+
rewriter, loc, TypeRange{newOutputType}, newInputs,
156+
ValueRange{expandedOutput}, newMaps, newIterTypes,
157+
[&](OpBuilder &b, Location nestedLoc, ValueRange args) {
158+
IRMapping mapping;
159+
for (auto [oldArg, newArg] :
160+
llvm::zip(op.getBody()->getArguments(), args)) {
161+
mapping.map(oldArg, newArg);
162+
}
163+
for (Operation &bodyOp : op.getBody()->without_terminator()) {
164+
b.clone(bodyOp, mapping);
165+
}
166+
auto yield = cast<linalg::YieldOp>(op.getBody()->getTerminator());
167+
linalg::YieldOp::create(b, nestedLoc,
168+
mapping.lookup(yield.getOperand(0)));
169+
});
170+
171+
// Collapse result to remove the batch dimension we added.
172+
auto reassoc = buildPrependDimReassociation(outputType.getRank());
173+
auto collapsed = tensor::CollapseShapeOp::create(
174+
rewriter, loc, outputType, newConvOp.getResult(0), reassoc);
175+
176+
rewriter.replaceOp(op, collapsed);
177+
return newConvOp;
178+
}
179+
180+
//===----------------------------------------------------------------------===//
181+
// Pass definition
182+
//===----------------------------------------------------------------------===//
183+
184+
struct InsertBatchDimForBatchlessConvPass final
185+
: impl::InsertBatchDimForBatchlessConvPassBase<
186+
InsertBatchDimForBatchlessConvPass> {
187+
void runOnOperation() override {
188+
MLIRContext *context = &getContext();
189+
190+
// Find the batchless conv op. We assume there is only one convolution-like
191+
// op per function (typical for dispatches).
192+
linalg::GenericOp batchlessConv = nullptr;
193+
getOperation()->walk([&](linalg::GenericOp op) {
194+
if (isBatchlessConv(op)) {
195+
if (batchlessConv) {
196+
// Multiple batchless convs found - bail out.
197+
batchlessConv = nullptr;
198+
return WalkResult::interrupt();
199+
}
200+
batchlessConv = op;
201+
}
202+
return WalkResult::advance();
203+
});
204+
205+
if (!batchlessConv) {
206+
return;
207+
}
208+
209+
// Insert unit batch dimension into the convolution.
210+
IRRewriter rewriter(context);
211+
rewriter.setInsertionPoint(batchlessConv);
212+
linalg::GenericOp newConvOp =
213+
insertUnitBatchDimension(rewriter, batchlessConv);
214+
215+
// Phase 1: Bubble up expand_shape (only for ops BEFORE conv).
216+
{
217+
RewritePatternSet reshapePatterns(context);
218+
populateReshapeToInterfaceTensorPatterns(reshapePatterns);
219+
220+
linalg::ControlFusionFn controlFn = [&](OpOperand *fusedOperand) {
221+
Operation *op = fusedOperand->getOwner();
222+
if (op->getBlock() != newConvOp->getBlock()) {
223+
return false;
224+
}
225+
return op->isBeforeInBlock(newConvOp);
226+
};
227+
228+
linalg::populateFoldReshapeOpsByExpansionPatterns(reshapePatterns,
229+
controlFn);
230+
tensor::populateFoldTensorEmptyPatterns(reshapePatterns);
231+
tensor::populateBubbleUpExpandShapePatterns(reshapePatterns);
232+
linalg::FillOp::getCanonicalizationPatterns(reshapePatterns, context);
233+
tensor::ExpandShapeOp::getCanonicalizationPatterns(reshapePatterns,
234+
context);
235+
236+
if (failed(applyPatternsGreedily(getOperation(),
237+
std::move(reshapePatterns)))) {
238+
return signalPassFailure();
239+
}
240+
}
241+
242+
// Phase 2: Sink down collapse_shape (only for ops AFTER conv).
243+
{
244+
RewritePatternSet reshapePatterns(context);
245+
populateReshapeToInterfaceTensorPatterns(reshapePatterns);
246+
247+
linalg::ControlFusionFn controlFn = [&](OpOperand *fusedOperand) {
248+
Operation *op = fusedOperand->getOwner();
249+
if (op->getBlock() != newConvOp->getBlock()) {
250+
return false;
251+
}
252+
return newConvOp->isBeforeInBlock(op);
253+
};
254+
255+
linalg::populateFoldReshapeOpsByExpansionPatterns(reshapePatterns,
256+
controlFn);
257+
tensor::populateFoldTensorEmptyPatterns(reshapePatterns);
258+
tensor::CollapseShapeOp::getCanonicalizationPatterns(reshapePatterns,
259+
context);
260+
261+
if (failed(applyPatternsGreedily(getOperation(),
262+
std::move(reshapePatterns)))) {
263+
return signalPassFailure();
264+
}
265+
}
266+
}
267+
};
268+
269+
} // namespace
270+
} // namespace mlir::iree_compiler

compiler/src/iree/compiler/Codegen/Common/Passes.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -745,6 +745,25 @@ def InstrumentMemoryAccessesPass :
745745
let summary = "Instruments memory reads and writes for address tracking when dispatch instrumentation is enabled.";
746746
}
747747

748+
// Declares the pass registered as
// `iree-codegen-insert-batch-dim-for-batchless-conv`. It is an InterfacePass
// anchored on ops implementing mlir::FunctionOpInterface and lists the Linalg
// and Tensor dialects as dependent dialects.
def InsertBatchDimForBatchlessConvPass :
749+
InterfacePass<"iree-codegen-insert-batch-dim-for-batchless-conv", "mlir::FunctionOpInterface"> {
750+
let summary = "Inserts batch dimension for batch-less convolutions.";
751+
let description = [{
752+
Detects convolution operations that have been generalized and had their
753+
batch dimension (N=1) stripped by upstream IREE transformations. It
754+
restores the batch dimension by inserting tensor.expand_shape and
755+
tensor.collapse_shape operations, enabling downstream convolution-specific
756+
patterns (like DownscaleConv and vectorization) to match and apply.
757+
758+
The reshape operations are propagated to the boundary and folded into
759+
dispatch tensor load/store operations, resulting in zero runtime cost.
760+
}];
761+
let dependentDialects = [
762+
"linalg::LinalgDialect",
763+
"tensor::TensorDialect"
764+
];
765+
}
766+
748767
def LinkTuningSpecsPass : Pass<"iree-codegen-link-tuning-specs", "ModuleOp"> {
749768
let summary =
750769
"Link nested transform dialect tuning specs named sequences into a single entry point";

compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ iree_lit_test_suite(
6666
"generic_vectorization.mlir",
6767
"hoist_statically_bound_allocations.mlir",
6868
"hoist_unrolled_vector_extract_insert_slice.mlir",
69+
"insert_batch_dim_for_batchless_conv.mlir",
6970
"iree_codegen_canonicalize.mlir",
7071
"iree_comprehensive_bufferize.mlir",
7172
"iree_expand_strided_metadata.mlir",

compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ iree_lit_test_suite(
6161
"generic_vectorization.mlir"
6262
"hoist_statically_bound_allocations.mlir"
6363
"hoist_unrolled_vector_extract_insert_slice.mlir"
64+
"insert_batch_dim_for_batchless_conv.mlir"
6465
"iree_codegen_canonicalize.mlir"
6566
"iree_comprehensive_bufferize.mlir"
6667
"iree_expand_strided_metadata.mlir"

0 commit comments

Comments
 (0)