Skip to content

Commit 7724306

Browse files
authored
[Encoding] Implement matmul_k encoding propagation across reshapes. (iree-org#20367)
This revision ports the SDXL encoding-propagation work to the main branch. Ideally, the propagation should be implemented using interfaces and data-flow analysis; this change is only a first step, and we will incrementally enhance the encoding propagation pass. Co-authored-by: MaheshRavishankar [[email protected]](mailto:[email protected]) --------- Signed-off-by: hanhanW <[email protected]>
1 parent 918244c commit 7724306

File tree

8 files changed

+164
-0
lines changed

8 files changed

+164
-0
lines changed

compiler/src/iree/compiler/DispatchCreation/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ iree_compiler_cc_library(
3535
"HoistEncodingOps.cpp",
3636
"MaterializeDefaultWorkgroupCountRegion.cpp",
3737
"Passes.cpp",
38+
"PropagateEncodings.cpp",
3839
"SetEncoding.cpp",
3940
"SinkReshapes.cpp",
4041
"SplitReduction.cpp",

compiler/src/iree/compiler/DispatchCreation/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ iree_cc_library(
3737
"HoistEncodingOps.cpp"
3838
"MaterializeDefaultWorkgroupCountRegion.cpp"
3939
"Passes.cpp"
40+
"PropagateEncodings.cpp"
4041
"SetEncoding.cpp"
4142
"SinkReshapes.cpp"
4243
"SplitReduction.cpp"

compiler/src/iree/compiler/DispatchCreation/Passes.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ addDispatchRegionCreationPasses(OpPassManager &passManager,
253253
return DispatchCreation::createHoistEncodingOpsPass(
254254
HoistEncodingOpsPassOptions{clHoistEncodingsForConstExpr});
255255
})
256+
.addPass(DispatchCreation::createPropagateEncodingsPass)
256257
.addPass(
257258
DispatchCreation::createFuseEncodingOpsIntoDispatchRegionsPass);
258259
}

compiler/src/iree/compiler/DispatchCreation/Passes.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,14 @@ def HoistEncodingOpsPass :
297297
];
298298
}
299299

300+
def PropagateEncodingsPass :
301+
InterfacePass<"iree-dispatch-creation-propagate-encodings", "mlir::FunctionOpInterface"> {
302+
let summary = "Propagate encodings across other operations.";
303+
let dependentDialects = [
304+
"mlir::tensor::TensorDialect",
305+
"IREE::Encoding::IREEEncodingDialect",
306+
];
307+
}
300308

301309
def SetEncodingPass :
302310
InterfacePass<"iree-dispatch-creation-set-encoding", "mlir::FunctionOpInterface"> {
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
// Copyright 2025 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#include "iree/compiler/Dialect/Encoding/IR/EncodingDialect.h"
8+
#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
9+
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
10+
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
11+
#include "iree/compiler/DispatchCreation/Passes.h"
12+
#include "llvm/ADT/STLExtras.h"
13+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
14+
#include "mlir/IR/MLIRContext.h"
15+
#include "mlir/IR/PatternMatch.h"
16+
#include "mlir/Interfaces/FunctionInterfaces.h"
17+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18+
19+
#define DEBUG_TYPE "iree-dispatch-creation-propagate-encodings"
20+
21+
namespace mlir::iree_compiler::DispatchCreation {
22+
23+
#define GEN_PASS_DEF_PROPAGATEENCODINGSPASS
24+
#include "iree/compiler/DispatchCreation/Passes.h.inc"
25+
26+
namespace {
27+
28+
/// Pattern to swap `tensor.collapse_shape` -> `iree_encoding.set_encoding`
29+
struct SwapEncodingOpWithTensorCollapseShapeOp
30+
: public OpRewritePattern<IREE::Encoding::SetEncodingOp> {
31+
using Base = OpRewritePattern<IREE::Encoding::SetEncodingOp>;
32+
using Base::Base;
33+
LogicalResult matchAndRewrite(IREE::Encoding::SetEncodingOp encodingOp,
34+
PatternRewriter &rewriter) const override;
35+
};
36+
37+
// TODO(#20179): Support the propagation through interfaces. It is supposed to
38+
// be done with data-flow analysis.
39+
struct PropagateEncodingsPass
40+
: public DispatchCreation::impl::PropagateEncodingsPassBase<
41+
PropagateEncodingsPass> {
42+
void runOnOperation() override;
43+
};
44+
45+
} // namespace
46+
47+
/// Swaps `tensor.collapse_shape` -> `iree_encoding.set_encoding` into
/// `iree_encoding.set_encoding` -> `tensor.collapse_shape`.
///
/// The rewrite applies only when all of the following hold:
///   * the result encoding of `encodingOp` is a `matmul_k` encoding;
///   * the source of `encodingOp` is produced by a `tensor.collapse_shape`;
///   * `IREE::Flow::isNonNullAndOutsideDispatch` holds for both ops;
///   * every k dimension indexes a valid reassociation group, and that group
///     contains exactly one source dimension (the collapse does not touch k).
///
/// On success, a new `set_encoding` is created on the collapse source with the
/// k dimensions remapped to the source (pre-collapse) dimension indices, and a
/// new `collapse_shape` with the original reassociation produces the original
/// encoded result type. Returns failure (without modifying IR) otherwise.
LogicalResult SwapEncodingOpWithTensorCollapseShapeOp::matchAndRewrite(
    IREE::Encoding::SetEncodingOp encodingOp, PatternRewriter &rewriter) const {
  auto encoding = dyn_cast<IREE::Encoding::MatmulKAttr>(
      encodingOp.getResultType().getEncoding());
  if (!encoding) {
    return rewriter.notifyMatchFailure(encodingOp, "only matmul_k is handled");
  }
  auto collapseOp =
      encodingOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
  if (!collapseOp) {
    return rewriter.notifyMatchFailure(encodingOp,
                                       "expected a collapse_shape producer");
  }
  if (!IREE::Flow::isNonNullAndOutsideDispatch(encodingOp) ||
      !IREE::Flow::isNonNullAndOutsideDispatch(collapseOp)) {
    return rewriter.notifyMatchFailure(
        encodingOp, "expected that both operations are outside dispatch");
  }

  ArrayRef<int32_t> kDims = encoding.getKDims().asArrayRef();

  // One reassociation group per dimension of the collapsed (result) tensor;
  // each group lists the source dimensions folded into that result dimension.
  SmallVector<ReassociationIndices, 4> reassociationMaps =
      collapseOp.getReassociationIndices();

  // Map every k dimension of the collapsed tensor to the matching dimension of
  // the collapse source. Bail out if it is not propagable.
  // TODO: Relax the check to allow transforming innermost reduction dimensions.
  // We need to revisit the matmul_k encoding semantic.
  SmallVector<int32_t> newKDims;
  newKDims.reserve(kDims.size());
  for (int32_t kDim : kDims) {
    // Guard the indexing below against k dimensions that do not refer to a
    // dimension of the collapsed tensor.
    if (kDim < 0 || static_cast<size_t>(kDim) >= reassociationMaps.size()) {
      return rewriter.notifyMatchFailure(
          encodingOp,
          "expected k dimensions to be within the rank of the collapsed tensor");
    }
    const ReassociationIndices &group = reassociationMaps[kDim];
    if (group.size() != 1) {
      return rewriter.notifyMatchFailure(
          encodingOp,
          "expected collapse_shape ops to not transform k dimensions");
    }
    newKDims.push_back(static_cast<int32_t>(group.front()));
  }

  // Create the new encoding op on the collapse source, then re-create the
  // collapse on top of it so that the final result type is unchanged.
  MLIRContext *ctx = rewriter.getContext();
  auto newEncodingAttr = IREE::Encoding::MatmulKAttr::get(ctx, newKDims);
  RankedTensorType newEncodingType =
      collapseOp.getSrcType().cloneWithEncoding(newEncodingAttr);
  Value newEncodingOp = rewriter.create<IREE::Encoding::SetEncodingOp>(
      encodingOp.getLoc(), newEncodingType, collapseOp.getSrc());
  Value newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
      collapseOp.getLoc(), encodingOp.getResultType(), newEncodingOp,
      reassociationMaps);
  rewriter.replaceOp(encodingOp, newCollapseOp);
  return success();
}
102+
103+
void PropagateEncodingsPass::runOnOperation() {
104+
mlir::FunctionOpInterface funcOp = getOperation();
105+
MLIRContext *ctx = &getContext();
106+
RewritePatternSet propagationPatterns(ctx);
107+
propagationPatterns.insert<SwapEncodingOpWithTensorCollapseShapeOp>(ctx);
108+
GreedyRewriteConfig config;
109+
config.fold = true;
110+
config.cseConstants = false;
111+
if (failed(applyPatternsGreedily(funcOp, std::move(propagationPatterns),
112+
config))) {
113+
funcOp.emitOpError("failed to propagate encodings");
114+
return signalPassFailure();
115+
}
116+
}
117+
118+
} // namespace mlir::iree_compiler::DispatchCreation

compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ iree_lit_test_suite(
4343
"pad_fusion_with_consumer.mlir",
4444
"pad_fusion_with_producer.mlir",
4545
"pipeline_tests.mlir",
46+
"propagate_encodings.mlir",
4647
"set_encoding.mlir",
4748
"set_encoding_pipeline.mlir",
4849
"sink_reshapes.mlir",

compiler/src/iree/compiler/DispatchCreation/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ iree_lit_test_suite(
4141
"pad_fusion_with_consumer.mlir"
4242
"pad_fusion_with_producer.mlir"
4343
"pipeline_tests.mlir"
44+
"propagate_encodings.mlir"
4445
"set_encoding.mlir"
4546
"set_encoding_pipeline.mlir"
4647
"sink_reshapes.mlir"
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-propagate-encodings))" --split-input-file %s | FileCheck %s
2+
3+
#encoding = #iree_encoding.matmul_k<k_dims = [1]>
4+
util.func public @propagate_encoding_through_collapse_shape(%src: tensor<2x4096x640xf16>) -> tensor<8192x640xf16, #encoding> {
5+
%collapsed = tensor.collapse_shape %src [[0, 1], [2]] : tensor<2x4096x640xf16> into tensor<8192x640xf16>
6+
%0 = iree_encoding.set_encoding %collapsed : tensor<8192x640xf16> -> tensor<8192x640xf16, #encoding>
7+
util.return %0 : tensor<8192x640xf16, #encoding>
8+
}
9+
// CHECK-DAG: #[[$ENCODING0:.+]] = #iree_encoding.matmul_k<k_dims = [1]>
10+
// CHECK-DAG: #[[$ENCODING1:.+]] = #iree_encoding.matmul_k<k_dims = [2]>
11+
// CHECK-LABEL: @propagate_encoding_through_collapse_shape(
12+
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
13+
// CHECK: %[[SET_ENCODING:.+]] = iree_encoding.set_encoding %[[SRC]] : tensor<2x4096x640xf16> -> tensor<2x4096x640xf16, #[[$ENCODING1]]>
14+
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[SET_ENCODING]] {{\[}}[0, 1], [2]] : tensor<2x4096x640xf16, #[[$ENCODING1]]> into tensor<8192x640xf16, #[[$ENCODING0]]>
15+
// CHECK: util.return %[[COLLAPSED]]
16+
17+
// -----
18+
19+
#encoding = #iree_encoding.matmul_k<k_dims = [1]>
20+
util.func public @propagate_encoding_through_collapse_shape_chain(%src: tensor<2x4096x64x10xf16>) -> tensor<8192x640xf16, #encoding> {
21+
%collapsed = tensor.collapse_shape %src [[0], [1], [2, 3]] : tensor<2x4096x64x10xf16> into tensor<2x4096x640xf16>
22+
%collapsed_0 = tensor.collapse_shape %collapsed [[0, 1], [2]] : tensor<2x4096x640xf16> into tensor<8192x640xf16>
23+
%0 = iree_encoding.set_encoding %collapsed_0 : tensor<8192x640xf16> -> tensor<8192x640xf16, #encoding>
24+
util.return %0 : tensor<8192x640xf16, #encoding>
25+
}
26+
// CHECK-DAG: #[[$ENCODING0:.+]] = #iree_encoding.matmul_k<k_dims = [1]>
27+
// CHECK-DAG: #[[$ENCODING1:.+]] = #iree_encoding.matmul_k<k_dims = [2]>
28+
// CHECK-LABEL: @propagate_encoding_through_collapse_shape_chain(
29+
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
30+
// CHECK: %[[COLLAPSED_0:.+]] = tensor.collapse_shape %[[SRC]] {{\[}}[0], [1], [2, 3]] : tensor<2x4096x64x10xf16> into tensor<2x4096x640xf16>
31+
// CHECK: %[[SET_ENCODING:.+]] = iree_encoding.set_encoding %[[COLLAPSED_0]] : tensor<2x4096x640xf16> -> tensor<2x4096x640xf16, #[[$ENCODING1]]>
32+
// CHECK: %[[COLLAPSED_1:.+]] = tensor.collapse_shape %[[SET_ENCODING]] {{\[}}[0, 1], [2]] : tensor<2x4096x640xf16, #[[$ENCODING1]]> into tensor<8192x640xf16, #[[$ENCODING0]]>
33+
// CHECK: util.return %[[COLLAPSED_1]]

0 commit comments

Comments
 (0)