Commit 8e374a6

[CPU][DT] Implement data layout propagation for CPU dispatches. (#21554)
The revision implements a specialized propagation pattern for the `tensor.collapse_shape->linalg.unpack` chain, which applies when the packed dimension and the corresponding inner dimension are collapsed together. The data layout propagation also populates patterns that sink `tensor.collapse_shape` ops down across `linalg.generic` ops, because matvec on CPU backends is materialized as a `linalg.mmt4d->tensor.collapse_shape` op chain. At the end, the pass folds the remaining reshapes into the bindings.

Fixes #21180

---------

Signed-off-by: hanhanW <[email protected]>
1 parent d6cdf25 commit 8e374a6
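
For context on the motivation: matvec on CPU backends with data tiling is materialized as a `linalg.mmt4d` whose result is collapsed and then unpacked. A minimal hand-written sketch of such a chain follows; the function name, shapes, and tile sizes are illustrative, not what the backend would actually choose:

func.func @matvec_like(%lhs: tensor<2x4x8x1xf32>, %rhs: tensor<1x4x1x1xf32>) -> tensor<13xf32> {
  %cst = arith.constant 0.0 : f32
  %init = tensor.empty() : tensor<2x1x8x1xf32>
  %acc = linalg.fill ins(%cst : f32) outs(%init : tensor<2x1x8x1xf32>) -> tensor<2x1x8x1xf32>
  // Matvec as mmt4d: the N dimensions are unit (N1 = N0 = 1).
  %mm = linalg.mmt4d ins(%lhs, %rhs : tensor<2x4x8x1xf32>, tensor<1x4x1x1xf32>)
                     outs(%acc : tensor<2x1x8x1xf32>) -> tensor<2x1x8x1xf32>
  // The unit dims are collapsed away before unpacking, yielding the
  // tensor.collapse_shape -> linalg.unpack chain this pass targets.
  %collapsed = tensor.collapse_shape %mm [[0, 1], [2, 3]] : tensor<2x1x8x1xf32> into tensor<2x8xf32>
  %empty = tensor.empty() : tensor<13xf32>
  %unpack = linalg.unpack %collapsed inner_dims_pos = [0] inner_tiles = [8] into %empty : tensor<2x8xf32> -> tensor<13xf32>
  return %unpack : tensor<13xf32>
}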

File tree

10 files changed: +345, -0 lines

compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel

Lines changed: 2 additions & 0 deletions
@@ -51,6 +51,7 @@ iree_compiler_cc_library(
     srcs = [
         "CPULowerToUKernels.cpp",
         "CPUPrepareUkernels.cpp",
+        "CPUPropagateDataLayout.cpp",
         "Passes.cpp",
     ],
     hdrs = [
@@ -78,6 +79,7 @@ iree_compiler_cc_library(
         "@llvm-project//mlir:BufferizationDialect",
         "@llvm-project//mlir:BufferizationInterfaces",
         "@llvm-project//mlir:DestinationStyleOpInterface",
+        "@llvm-project//mlir:DialectUtils",
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:FunctionInterfaces",
         "@llvm-project//mlir:IR",

compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ iree_cc_library(
   SRCS
     "CPULowerToUKernels.cpp"
     "CPUPrepareUkernels.cpp"
+    "CPUPropagateDataLayout.cpp"
     "Passes.cpp"
   DEPS
     ::PassHeaders
compiler/src/iree/compiler/Codegen/Common/CPU/CPUPropagateDataLayout.cpp

Lines changed: 182 additions & 0 deletions
@@ -0,0 +1,182 @@
+// Copyright 2025 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Common/CPU/Passes.h"
+#include "iree/compiler/Codegen/Common/Transforms.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/LogicalResult.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::iree_compiler {
+
+#define GEN_PASS_DEF_CPUPROPAGATEDATALAYOUTPASS
+#include "iree/compiler/Codegen/Common/CPU/Passes.h.inc"
+
+namespace {
+
+/// Sinks a tensor.collapse_shape down across a linalg.unpack op if the
+/// collapsed dims are two unit dims, one being an outer dimension and the
+/// other an inner dimension. In effect, the two ops are swapped by adjusting
+/// the packing metadata of the linalg.unpack op.
+/// Note that the pattern only supports the case where the destination tensor
+/// of the linalg.unpack op is a tensor.empty op. The constraint could be
+/// lifted by introducing a tensor.expand_shape op on the destination tensor;
+/// however, that is not common in practice, so it is not supported for now.
+struct SinkDownCollapsingUnitDimsAcrossUnpack final
+    : public OpRewritePattern<linalg::UnPackOp> {
+  using OpRewritePattern<linalg::UnPackOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(linalg::UnPackOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!isIdentityPermutation(op.getOuterDimsPerm())) {
+      return rewriter.notifyMatchFailure(
+          op, "expected identity (or unset) outer permutation");
+    }
+    if (op.getSourceRank() != op.getDestRank() + 1) {
+      return rewriter.notifyMatchFailure(
+          op, "expected unpacking exactly one dimension");
+    }
+    auto emptyOp = op.getDest().getDefiningOp<tensor::EmptyOp>();
+    if (!emptyOp) {
+      return rewriter.notifyMatchFailure(
+          op, "expected destination to be a tensor.empty op");
+    }
+    auto collapseOp = op.getSource().getDefiningOp<tensor::CollapseShapeOp>();
+    if (!collapseOp) {
+      return rewriter.notifyMatchFailure(
+          op, "expected the source to be a tensor.collapse_shape op");
+    }
+
+    SmallVector<ReassociationIndices, 4> ri =
+        collapseOp.getReassociationIndices();
+    ReassociationIndices outerRi, innerRi;
+    for (ArrayRef<int64_t> indices : ri) {
+      if (indices.size() == 1) {
+        continue;
+      }
+      if (indices.size() > 2) {
+        return rewriter.notifyMatchFailure(
+            op, "expected re-association map to have two dimensions");
+      }
+      if (outerRi.empty()) {
+        outerRi.assign(indices.begin(), indices.end());
+        continue;
+      }
+      if (innerRi.empty()) {
+        innerRi.assign(indices.begin(), indices.end());
+        continue;
+      }
+      return rewriter.notifyMatchFailure(
+          op, "expected only two re-association maps to have two dimensions");
+    }
+    if (outerRi.empty() || innerRi.empty()) {
+      return rewriter.notifyMatchFailure(
+          op, "expected only two re-association maps to have two dimensions");
+    }
+
+    RankedTensorType srcType = collapseOp.getSrcType();
+    if (innerRi.back() != srcType.getRank() - 1) {
+      return rewriter.notifyMatchFailure(
+          op, "expected that the two innermost dimensions are collapsed");
+    }
+    SmallVector<int64_t> innerDimPos(op.getInnerDimsPos());
+    if (!llvm::is_contained(outerRi, innerDimPos[0])) {
+      return rewriter.notifyMatchFailure(
+          op, "expected the packed dimension to be collapsed");
+    }
+
+    bool missLeadingUnitDim = srcType.getDimSize(outerRi[0]) == 1 &&
+                              srcType.getDimSize(innerRi[0]) == 1;
+    bool missTrailingUnitDim = srcType.getDimSize(outerRi[1]) == 1 &&
+                               srcType.getDimSize(innerRi[1]) == 1;
+    if (!missLeadingUnitDim && !missTrailingUnitDim) {
+      return rewriter.notifyMatchFailure(op,
+                                         "expected collapsing either leading "
+                                         "unit dims or trailing unit dims");
+    }
+
+    // We add unit dims either right before or right after the packed
+    // dimensions. E.g., AxBxNxCxDxn becomes AxBx1xNxCxDx1xn if
+    // `missLeadingUnitDim` is true, and AxBxNx1xCxDxnx1 if
+    // `missTrailingUnitDim` is true. If both are true, the former is
+    // prioritized because the choice does not matter in practice.
+    SmallVector<OpFoldResult> innerTiles(op.getMixedTiles());
+    SmallVector<OpFoldResult> destShape = emptyOp.getMixedSizes();
+    if (missLeadingUnitDim) {
+      // The unit dim is inserted right before the packed dimension, so advance
+      // innerDimPos[0] by one.
+      innerDimPos[0]++;
+      innerDimPos.insert(innerDimPos.begin(), outerRi[0]);
+      innerTiles.insert(innerTiles.begin(), rewriter.getIndexAttr(1));
+      destShape.insert(destShape.begin() + outerRi[0],
+                       rewriter.getIndexAttr(1));
+    } else {
+      innerDimPos.insert(innerDimPos.end(), outerRi[1]);
+      innerTiles.insert(innerTiles.end(), rewriter.getIndexAttr(1));
+      destShape.insert(destShape.end(), rewriter.getIndexAttr(1));
+    }
+
+    Location loc = op.getLoc();
+    auto newDestOp = rewriter.create<tensor::EmptyOp>(
+        loc, destShape, emptyOp.getType().getElementType());
+    auto newUnpackOp = rewriter.create<linalg::UnPackOp>(
+        loc, collapseOp.getSrc(), newDestOp, innerDimPos, innerTiles);
+    SmallVector<ReassociationIndices> newRi;
+    for (int64_t i = 0, e = op.getDestRank(); i < e; ++i) {
+      if (i == outerRi[0]) {
+        newRi.push_back(outerRi);
+        ++i;
+      } else {
+        newRi.push_back({i});
+      }
+    }
+    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+        op, newUnpackOp.getResult(), newRi);
+
+    return success();
+  }
+};
+
+struct CPUPropagateDataLayoutPass final
+    : public impl::CPUPropagateDataLayoutPassBase<CPUPropagateDataLayoutPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<linalg::LinalgDialect, tensor::TensorDialect>();
+  }
+
+  void runOnOperation() override;
+};
+
+} // namespace
+
+void CPUPropagateDataLayoutPass::runOnOperation() {
+  MLIRContext *ctx = &getContext();
+  FunctionOpInterface funcOp = getOperation();
+  RewritePatternSet patterns(ctx);
+  patterns.insert<SinkDownCollapsingUnitDimsAcrossUnpack>(ctx);
+  populateReshapeToInterfaceTensorPatterns(patterns);
+  tensor::populateFoldTensorEmptyPatterns(patterns, /*foldSingleUseOnly=*/1);
+  linalg::populateFoldReshapeOpsByExpansionPatterns(
+      patterns, [](OpOperand *fusedOperand) -> bool {
+        Operation *producer = fusedOperand->get().getDefiningOp();
+        auto consumerGenericOp =
+            dyn_cast_if_present<linalg::GenericOp>(fusedOperand->getOwner());
+        if (!isa_and_present<tensor::CollapseShapeOp>(producer) ||
+            !consumerGenericOp) {
+          return false;
+        }
+        return true;
+      });
+  if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
+    return signalPassFailure();
+  }
+}
+
+} // namespace mlir::iree_compiler
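
For intuition, the core rewrite swaps the reshape and the unpack by folding the collapsed unit dims into the packing metadata. Taking the first lit test below as the example, the pattern turns

  %collapsed = tensor.collapse_shape %src [[0, 1], [2, 3]] : tensor<1x2x1x16xi32> into tensor<2x16xi32>
  %empty = tensor.empty() : tensor<20xi32>
  %unpack = linalg.unpack %collapsed inner_dims_pos = [0] inner_tiles = [16] into %empty : tensor<2x16xi32> -> tensor<20xi32>

into

  %empty = tensor.empty() : tensor<1x20xi32>
  %unpack = linalg.unpack %src inner_dims_pos = [0, 1] inner_tiles = [1, 16] into %empty : tensor<1x2x1x16xi32> -> tensor<1x20xi32>
  %collapsed = tensor.collapse_shape %unpack [[0, 1]] : tensor<1x20xi32> into tensor<20xi32>

so the unpack consumes the packed layout directly, and the remaining unit-dim collapse can be folded away downstream (e.g. into the bindings).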

compiler/src/iree/compiler/Codegen/Common/CPU/Passes.td

Lines changed: 6 additions & 0 deletions
@@ -30,4 +30,10 @@ def CPUPrepareUkernelsPass :
         "For example, batch_mmt4d ops are decomposed to mmt4d ops";
 }
 
+def CPUPropagateDataLayoutPass :
+    InterfacePass<"iree-codegen-cpu-propagate-data-layout", "mlir::FunctionOpInterface"> {
+  let summary = "Propagates pack/unpack/reshape ops to make the whole dispatch "
+                "use the same layout.";
+}
+
 #endif // IREE_CODEGEN_COMMON_CPU_PASSES
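
The pass can be exercised in isolation with iree-opt, as the lit tests below do (`input.mlir` here is a placeholder for any file to process):

iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-cpu-propagate-data-layout))" input.mlir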

compiler/src/iree/compiler/Codegen/Common/CPU/test/BUILD.bazel

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ iree_lit_test_suite(
         [
             "lower_to_ukernel_ops.mlir",
             "prepare_ukernels.mlir",
+            "propagate_data_layout.mlir",
         ],
         include = ["*.mlir"],
     ),

compiler/src/iree/compiler/Codegen/Common/CPU/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ iree_lit_test_suite(
   SRCS
     "lower_to_ukernel_ops.mlir"
    "prepare_ukernels.mlir"
+    "propagate_data_layout.mlir"
   TOOLS
     FileCheck
     iree-opt
compiler/src/iree/compiler/Codegen/Common/CPU/test/propagate_data_layout.mlir

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-cpu-propagate-data-layout))" --split-input-file %s | FileCheck %s
+
+func.func @collapsing_unit_dim_0(%src: tensor<1x2x1x16xi32>) -> tensor<20xi32> {
+  %collapsed = tensor.collapse_shape %src [[0, 1], [2, 3]] : tensor<1x2x1x16xi32> into tensor<2x16xi32>
+  %1 = tensor.empty() : tensor<20xi32>
+  %unpack = linalg.unpack %collapsed inner_dims_pos = [0] inner_tiles = [16] into %1 : tensor<2x16xi32> -> tensor<20xi32>
+  return %unpack : tensor<20xi32>
+}
+// CHECK-LABEL: func.func @collapsing_unit_dim_0(
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
+// CHECK-SAME:      inner_dims_pos = [0, 1] inner_tiles = [1, 16]
+// CHECK-SAME:      : tensor<1x2x1x16xi32> -> tensor<1x20xi32>
+// CHECK-NEXT:    %[[COLLAPSED:.+]] = tensor.collapse_shape %[[UNPACK]]
+// CHECK-SAME:      : tensor<1x20xi32> into tensor<20xi32>
+// CHECK:         return %[[COLLAPSED]]
+
+// -----
+
+func.func @collapsing_unit_dim_1(%src: tensor<?x1x1x1x16xi32>, %batch_size: index) -> tensor<?x3xi32> {
+  %collapsed = tensor.collapse_shape %src [[0], [1, 2], [3, 4]] : tensor<?x1x1x1x16xi32> into tensor<?x1x16xi32>
+  %0 = tensor.empty(%batch_size) : tensor<?x3xi32>
+  %unpack = linalg.unpack %collapsed inner_dims_pos = [1] inner_tiles = [16] into %0 : tensor<?x1x16xi32> -> tensor<?x3xi32>
+  return %unpack : tensor<?x3xi32>
+}
+// CHECK-LABEL: func.func @collapsing_unit_dim_1(
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
+// CHECK-SAME:      inner_dims_pos = [1, 2] inner_tiles = [1, 16]
+// CHECK-SAME:      : tensor<?x1x1x1x16xi32> -> tensor<?x1x3xi32>
+// CHECK-NEXT:    %[[COLLAPSED:.+]] = tensor.collapse_shape %[[UNPACK]]
+// CHECK-SAME:      : tensor<?x1x3xi32> into tensor<?x3xi32>
+// CHECK:         return %[[COLLAPSED]]
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @collapsing_unit_dim_0_elem_unpack(%src: tensor<1x1x1x16xi32>) -> tensor<3xi32> {
+  %0 = tensor.empty() : tensor<1x16xi32>
+  %collapsed = tensor.collapse_shape %src [[0, 1], [2, 3]] : tensor<1x1x1x16xi32> into tensor<1x16xi32>
+  %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%collapsed : tensor<1x16xi32>) outs(%0 : tensor<1x16xi32>) {
+  ^bb0(%in: i32, %out: i32):
+    %3 = arith.addi %in, %in : i32
+    linalg.yield %3 : i32
+  } -> tensor<1x16xi32>
+  %2 = tensor.empty() : tensor<3xi32>
+  %unpack = linalg.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [16] into %2 : tensor<1x16xi32> -> tensor<3xi32>
+  return %unpack : tensor<3xi32>
+}
+// CHECK-LABEL: func.func @collapsing_unit_dim_0_elem_unpack(
+// CHECK:         %[[ELEM:.+]] = linalg.generic
+// CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[ELEM]]
+// CHECK-SAME:      inner_dims_pos = [0, 1] inner_tiles = [1, 16]
+// CHECK-SAME:      : tensor<1x1x1x16xi32> -> tensor<1x3xi32>
+// CHECK-NEXT:    %[[COLLAPSED:.+]] = tensor.collapse_shape %[[UNPACK]]
+// CHECK-SAME:      : tensor<1x3xi32> into tensor<3xi32>
+// CHECK:         return %[[COLLAPSED]]
+
+// -----
+
+func.func @negative_unpack_with_outer_dims_perm(%src: tensor<1x1x?x1x16xi32>, %batch_size: index) -> tensor<?x3xi32> {
+  %collapsed = tensor.collapse_shape %src [[0], [1, 2], [3, 4]] : tensor<1x1x?x1x16xi32> into tensor<1x?x16xi32>
+  %0 = tensor.empty(%batch_size) : tensor<?x3xi32>
+  %unpack = linalg.unpack %collapsed outer_dims_perm = [1, 0] inner_dims_pos = [1] inner_tiles = [16] into %0 : tensor<1x?x16xi32> -> tensor<?x3xi32>
+  return %unpack : tensor<?x3xi32>
+}
+// CHECK-LABEL: func.func @negative_unpack_with_outer_dims_perm(
+// CHECK:         tensor.collapse_shape
+// CHECK:         linalg.unpack
+
+// -----
+
+func.func @negative_unpack_multiple_dims(%src: tensor<?x1x1x1x16x8xi32>, %d0: index, %d1: index) -> tensor<?x?xi32> {
+  %collapsed = tensor.collapse_shape %src [[0], [1, 2], [3, 4], [5]] : tensor<?x1x1x1x16x8xi32> into tensor<?x1x16x8xi32>
+  %0 = tensor.empty(%d0, %d1) : tensor<?x?xi32>
+  %unpack = linalg.unpack %collapsed inner_dims_pos = [0, 1] inner_tiles = [16, 8] into %0 : tensor<?x1x16x8xi32> -> tensor<?x?xi32>
+  return %unpack : tensor<?x?xi32>
+}
+// CHECK-LABEL: func.func @negative_unpack_multiple_dims(
+// CHECK:         tensor.collapse_shape
+// CHECK:         linalg.unpack
+
+// -----
+
+func.func @negative_unpack_non_collapsed_dim(%src: tensor<?x1x1x1x16xi32>, %d0: index) -> tensor<?x1xi32> {
+  %collapsed = tensor.collapse_shape %src [[0], [1, 2], [3, 4]] : tensor<?x1x1x1x16xi32> into tensor<?x1x16xi32>
+  %0 = tensor.empty(%d0) : tensor<?x1xi32>
+  %unpack = linalg.unpack %collapsed inner_dims_pos = [0] inner_tiles = [16] into %0 : tensor<?x1x16xi32> -> tensor<?x1xi32>
+  return %unpack : tensor<?x1xi32>
+}
+// CHECK-LABEL: func.func @negative_unpack_non_collapsed_dim(
+// CHECK:         tensor.collapse_shape
+// CHECK:         linalg.unpack
+
+// -----
+
+func.func @negative_both_m_n_non_unit_dim(%src: tensor<3x4x2x8xi32>) -> tensor<180xi32> {
+  %collapsed = tensor.collapse_shape %src [[0, 1], [2, 3]] : tensor<3x4x2x8xi32> into tensor<12x16xi32>
+  %1 = tensor.empty() : tensor<180xi32>
+  %unpack = linalg.unpack %collapsed inner_dims_pos = [0] inner_tiles = [16] into %1 : tensor<12x16xi32> -> tensor<180xi32>
+  return %unpack : tensor<180xi32>
+}
+// CHECK-LABEL: func.func @negative_both_m_n_non_unit_dim(
+// CHECK:         tensor.collapse_shape
+// CHECK:         linalg.unpack
+
+// -----
+
+func.func @negative_innermost_dim_is_not_collapsed(%src: tensor<1x3x1x8x16xi32>) -> tensor<48x8xi32> {
+  %collapsed = tensor.collapse_shape %src [[0, 1], [2, 3], [4]] : tensor<1x3x1x8x16xi32> into tensor<3x8x16xi32>
+  %1 = tensor.empty() : tensor<48x8xi32>
+  %unpack = linalg.unpack %collapsed inner_dims_pos = [0] inner_tiles = [16] into %1 : tensor<3x8x16xi32> -> tensor<48x8xi32>
+  return %unpack : tensor<48x8xi32>
+}
+// CHECK-LABEL: func.func @negative_innermost_dim_is_not_collapsed(
+// CHECK:         tensor.collapse_shape
+// CHECK:         linalg.unpack

compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp

Lines changed: 1 addition & 0 deletions
@@ -843,6 +843,7 @@ void buildLLVMCPUCodegenConfigurationPassPipelineImpl(
       // way too late and should instead be done during lowering to LLVM.
       .addPass(createExpandF16OpToF32Pass)
       .addPass(createMaterializeDeviceEncodingPass)
+      .addPass(createCPUPropagateDataLayoutPass)
       .addPass(createConvertAccGEMMToGEMMPass)
       // TODO: Remove the following pass and plumb support for
       // #hal.descriptor_type memory space through the stack.

tests/e2e/linalg/BUILD.bazel

Lines changed: 17 additions & 0 deletions
@@ -51,6 +51,23 @@ iree_check_single_backend_test_suite(
     target_backend = "llvm-cpu",
 )
 
+# TODO(#19378): Delete the test suite once data-tiling fusion is default on.
+iree_check_single_backend_test_suite(
+    name = "check_llvm-cpu_dt_fusion_local-task",
+    srcs = ["narrow_n_matmuls.mlir"],
+    compiler_flags = [
+        "--iree-dispatch-creation-experimental-data-tiling",
+        "--iree-llvmcpu-target-cpu=generic",
+        "--iree-opt-data-tiling=false",
+    ],
+    driver = "local-task",
+    tags = [
+        # Subbyte support for wasm is not a priority.
+        "nowasm",
+    ],
+    target_backend = "llvm-cpu",
+)
+
 VMVX_SRCS = enforce_glob(
     # keep sorted
     [
