Skip to content

Commit 2a123d7

Browse files
authored
[DT] Set encodings if iree.opt.data_tiling unit attribute is attached. (#21676)
The revision moves the filter to AnnotateDataTilingHints pass that adds the unit attribute to target gemms. If any operation has the data-tiling hint, it does nothing. It allows users to perform data-tiling selectively in preprocessing. E.g., user can run a transform dialect script that matches linalg ops and attaches the unit attribute. E.g., the below transform script only attaches `iree.opt.data_tiling` to the first matmul. ```mlir func.func @matmuls(%lhs: tensor<?x?xi8>, %rhs: tensor<?x?xi8>, %acc: tensor<?x?xi32>) -> (tensor<?x?xi32>, tensor<?x?xi32>) { %res0 = linalg.matmul_transpose_b ins(%lhs, %lhs : tensor<?x?xi8>, tensor<?x?xi8>) outs(%acc : tensor<?x?xi32>) -> tensor<?x?xi32> %res1 = linalg.matmul_transpose_b ins(%lhs, %rhs : tensor<?x?xi8>, tensor<?x?xi8>) outs(%acc : tensor<?x?xi32>) -> tensor<?x?xi32> return %res0, %res1 : tensor<?x?xi32>, tensor<?x?xi32> } module attributes {transform.with_named_sequence} { transform.named_sequence @match_matmul_repeated_operand(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op { %inputs, %outputs = transform.iree.match.cast_compatible_dag_from_root %arg0 { ^bb0(%arg1: tensor<?x?xi8>, %arg2: tensor<?x?xi32>): %1 = linalg.matmul_transpose_b ins(%arg1, %arg1 : tensor<?x?xi8>, tensor<?x?xi8>) outs(%arg2 : tensor<?x?xi32>) -> tensor<?x?xi32> } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) transform.yield %arg0 : !transform.any_op } transform.named_sequence @Annotate(%generic: !transform.any_op {transform.readonly}) { transform.annotate %generic "iree.opt.data_tiling" : !transform.any_op transform.yield } transform.named_sequence @__transform_main(%module: !transform.any_op) { %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.any_op transform.foreach_match in %module @match_matmul_repeated_operand -> @Annotate : (!transform.any_op) -> (!transform.any_op) transform.yield } } ``` Fixes #21246 --------- Signed-off-by: hanhanW <[email 
protected]>
1 parent 4afe274 commit 2a123d7

File tree

12 files changed

+221
-106
lines changed

12 files changed

+221
-106
lines changed

compiler/src/iree/compiler/Dialect/Encoding/Utils/Utils.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,23 @@
99

1010
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
1111
#include "mlir/IR/Attributes.h"
12+
#include "mlir/IR/BuiltinAttributes.h"
1213
#include "mlir/IR/BuiltinTypes.h"
1314

1415
namespace mlir::iree_compiler::IREE::Encoding {
1516

17+
constexpr char kDataTilingHint[] = "iree.opt.data_tiling";
18+
19+
/// Returns true if the operation has data-tiling hint attribute.
20+
inline bool hasDataTilingHint(Operation *op) {
21+
return op->getAttr(kDataTilingHint) ? true : false;
22+
}
23+
24+
/// Adds a unit attribute with the `kDataTilingHint` key to the operation.
25+
inline void setDataTilingHint(Operation *op) {
26+
op->setAttr(kDataTilingHint, UnitAttr::get(op->getContext()));
27+
}
28+
1629
/// Returns the encoding attribute from the type if there is an encoding that
1730
/// implements SerializableAttr. Otherwise, returns null.
1831
SerializableAttr getSerializableAttr(RankedTensorType type);
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
// Copyright 2025 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
8+
#include "iree/compiler/Dialect/Encoding/Utils/Utils.h"
9+
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
10+
#include "iree/compiler/DispatchCreation/Passes.h"
11+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
12+
#include "mlir/Dialect/Linalg/Utils/Utils.h"
13+
#include "mlir/Interfaces/FunctionInterfaces.h"
14+
#include "mlir/Support/LLVM.h"
15+
#include "mlir/Support/WalkResult.h"
16+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17+
18+
#define DEBUG_TYPE "iree-dispatch-creation-annotate-data-tiling-hints"
19+
20+
namespace mlir::iree_compiler::DispatchCreation {
21+
#define GEN_PASS_DEF_ANNOTATEDATATILINGHINTSPASS
22+
#include "iree/compiler/DispatchCreation/Passes.h.inc"
23+
24+
namespace {
25+
struct AnnotateDataTilingHintsPass final
26+
: impl::AnnotateDataTilingHintsPassBase<AnnotateDataTilingHintsPass> {
27+
using Base::Base;
28+
void runOnOperation() override;
29+
};
30+
} // namespace
31+
32+
/// Returns true iff the linalgOp has a body like a regular matmul, i.e.
33+
/// yield(add(out, mul(cast(in0), cast(in1))))
34+
static bool hasMatmulLikeBody(linalg::LinalgOp linalgOp) {
35+
auto outBlockArg =
36+
linalgOp.getMatchingBlockArgument(linalgOp.getDpsInitOperand(0));
37+
auto yieldOp =
38+
dyn_cast<linalg::YieldOp>(outBlockArg.getParentBlock()->getTerminator());
39+
if (!yieldOp) {
40+
return false;
41+
}
42+
Operation *addOp = yieldOp->getOperand(0).getDefiningOp();
43+
if (!addOp || !isa<arith::AddIOp, arith::AddFOp>(addOp)) {
44+
return false;
45+
}
46+
Value addLhs = addOp->getOperand(0);
47+
Value addRhs = addOp->getOperand(1);
48+
Operation *addLhsOp = addLhs.getDefiningOp();
49+
Operation *addRhsOp = addRhs.getDefiningOp();
50+
if (!(addLhsOp && addRhs == outBlockArg) &&
51+
!(addRhsOp && addLhs == outBlockArg)) {
52+
return false;
53+
}
54+
Operation *mulOp = addLhsOp ? addLhsOp : addRhsOp;
55+
if (!isa<arith::MulFOp, arith::MulIOp>(mulOp)) {
56+
return false;
57+
}
58+
Value mulLhs = mulOp->getOperand(0);
59+
Value mulRhs = mulOp->getOperand(1);
60+
auto mulLhsOp = mulLhs.getDefiningOp<CastOpInterface>();
61+
auto mulRhsOp = mulRhs.getDefiningOp<CastOpInterface>();
62+
if (!isa<BlockArgument>(mulLhs) && !mulLhsOp && !isa<BlockArgument>(mulRhs) &&
63+
!mulRhsOp) {
64+
return false;
65+
}
66+
if ((mulLhsOp && !isa<BlockArgument>(mulLhsOp->getOperand(0))) ||
67+
(mulRhsOp && !isa<BlockArgument>(mulRhsOp->getOperand(0)))) {
68+
return false;
69+
}
70+
return true;
71+
}
72+
73+
/// Not all contractions are supported by data tiling, so return true if:
74+
/// 1) linalgOp has pure tensor semantics.
75+
/// 2) linalgOp does not have a preset compilation info.
76+
/// 3) The workgroup count is not present if linalgOp is wrapped within
77+
/// Flow::DispatchRegionOp.
78+
/// 4) All the operands do not have encodings.
79+
/// 5) linalgOp has contraction indexingMaps.
80+
/// 6) There are not more than one of each contraction dimension.
81+
/// 7) There is an M or N dimension, and there is a K dimension.
82+
/// 8) linalgOp has the same body as an ordinary int or float matmul.
83+
///
84+
/// These restrictions are required because data tiling currently creates
85+
/// an Mmt4DOp or BatchMmt4DOp on the packed inputs.
86+
///
87+
/// TODO(#16176): Loosen restrictions on contraction ops once data tiling
88+
/// can support more cases.
89+
static bool isSupportedContractionOp(linalg::LinalgOp linalgOp) {
90+
if (!linalgOp.hasPureTensorSemantics()) {
91+
return false;
92+
}
93+
if (getCompilationInfo(linalgOp)) {
94+
return false;
95+
}
96+
auto hasWorkgroupCounts = [](Operation *op) -> bool {
97+
auto parentDispatchOp = op->getParentOfType<IREE::Flow::DispatchRegionOp>();
98+
return parentDispatchOp && !parentDispatchOp.getWorkgroupCount().empty();
99+
};
100+
if (hasWorkgroupCounts(linalgOp)) {
101+
return false;
102+
}
103+
auto hasEncoding = [](Value operand) -> bool {
104+
auto type = llvm::dyn_cast<RankedTensorType>(operand.getType());
105+
return type && type.getEncoding();
106+
};
107+
if (llvm::any_of(linalgOp.getDpsInputs(), hasEncoding) ||
108+
llvm::any_of(linalgOp.getDpsInits(), hasEncoding)) {
109+
return false;
110+
}
111+
112+
if (!linalg::isaContractionOpInterface(linalgOp)) {
113+
return false;
114+
}
115+
auto cDims = linalg::inferContractionDims(linalgOp);
116+
if (failed(cDims) || cDims->batch.size() > 1 || cDims->m.size() > 1 ||
117+
cDims->n.size() > 1 || cDims->k.size() > 1) {
118+
return false;
119+
}
120+
if ((cDims->n.empty() && cDims->m.empty()) || cDims->k.empty()) {
121+
return false;
122+
}
123+
if (!hasMatmulLikeBody(linalgOp)) {
124+
return false;
125+
}
126+
return true;
127+
}
128+
129+
void AnnotateDataTilingHintsPass::runOnOperation() {
130+
FunctionOpInterface funcOp = getOperation();
131+
SmallVector<Operation *> candidates;
132+
WalkResult result = funcOp.walk([&](Operation *op) -> WalkResult {
133+
if (IREE::Encoding::hasDataTilingHint(op)) {
134+
return WalkResult::interrupt();
135+
}
136+
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
137+
if (linalgOp && isSupportedContractionOp(linalgOp)) {
138+
candidates.push_back(op);
139+
return WalkResult::advance();
140+
}
141+
return WalkResult::advance();
142+
});
143+
if (result.wasInterrupted()) {
144+
return;
145+
}
146+
for (Operation *op : candidates) {
147+
IREE::Encoding::setDataTilingHint(op);
148+
}
149+
}
150+
151+
} // namespace mlir::iree_compiler::DispatchCreation

compiler/src/iree/compiler/DispatchCreation/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package(
2020
iree_compiler_cc_library(
2121
name = "DispatchCreation",
2222
srcs = [
23+
"AnnotateDataTilingHints.cpp",
2324
"BitcastUnsupportedElementTypes.cpp",
2425
"BubbleUpExpandShapes.cpp",
2526
"CloneProducersIntoDispatchRegions.cpp",
@@ -57,6 +58,7 @@ iree_compiler_cc_library(
5758
":PassesIncGen",
5859
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
5960
"//compiler/src/iree/compiler/Dialect/Encoding/IR",
61+
"//compiler/src/iree/compiler/Dialect/Encoding/Utils",
6062
"//compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow",
6163
"//compiler/src/iree/compiler/Dialect/Flow/IR",
6264
"//compiler/src/iree/compiler/Dialect/Flow/Transforms",

compiler/src/iree/compiler/DispatchCreation/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ iree_cc_library(
1717
"FusionUtils.h"
1818
"Passes.h"
1919
SRCS
20+
"AnnotateDataTilingHints.cpp"
2021
"BitcastUnsupportedElementTypes.cpp"
2122
"BubbleUpExpandShapes.cpp"
2223
"CloneProducersIntoDispatchRegions.cpp"
@@ -83,6 +84,7 @@ iree_cc_library(
8384
MLIRTransforms
8485
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
8586
iree::compiler::Dialect::Encoding::IR
87+
iree::compiler::Dialect::Encoding::Utils
8688
iree::compiler::Dialect::Flow::Conversion::TensorToFlow
8789
iree::compiler::Dialect::Flow::IR
8890
iree::compiler::Dialect::Flow::Transforms

compiler/src/iree/compiler/DispatchCreation/Passes.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,7 @@ addDispatchRegionCreationPasses(OpPassManager &passManager,
271271
options.cseConstants = false;
272272
return IREE::Flow::createCanonicalizePass(options);
273273
})
274+
.addPass(createAnnotateDataTilingHintsPass)
274275
// Set encodings on all eligible ops. All ops should be in compiler
275276
// formed dispatch regions, so encodings will be placed inside of the
276277
// dispatch regions with the data-tiled op.

compiler/src/iree/compiler/DispatchCreation/Passes.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,16 @@ def PropagateEncodingsPass :
326326
];
327327
}
328328

329+
def AnnotateDataTilingHintsPass :
330+
InterfacePass<"iree-dispatch-creation-annotate-data-tiling-hints", "mlir::FunctionOpInterface"> {
331+
let summary = "Adds data-tiling hint attribute to linalg operations.";
332+
let description = [{
333+
The pass does nothing if any operation already has the data-tiling hint
334+
attribute. Otherwise, it aggressively filters linalg ops and adds the
335+
data-tiling hint attribute to the operations.
336+
}];
337+
}
338+
329339
def SetEncodingPass : InterfacePass<"iree-dispatch-creation-set-encoding",
330340
"mlir::FunctionOpInterface"> {
331341
let summary = "Introduces tensor encoding for flow dispatch regions.";

compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp

Lines changed: 2 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
88
#include "iree/compiler/Dialect/Encoding/IR/EncodingDialect.h"
99
#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
10+
#include "iree/compiler/Dialect/Encoding/Utils/Utils.h"
1011
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
1112
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
1213
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
@@ -73,108 +74,11 @@ static Type getContractionInputTypeWithSignedness(OpBuilder &builder,
7374
return elemType;
7475
}
7576

76-
/// Returns true iff the linalgOp has a body like a regular matmul, i.e.
77-
/// yield(add(out, mul(cast(in0), cast(in1))))
78-
static bool hasMatmulLikeBody(linalg::LinalgOp linalgOp) {
79-
auto outBlockArg =
80-
linalgOp.getMatchingBlockArgument(linalgOp.getDpsInitOperand(0));
81-
auto yieldOp =
82-
dyn_cast<linalg::YieldOp>(outBlockArg.getParentBlock()->getTerminator());
83-
if (!yieldOp) {
84-
return false;
85-
}
86-
Operation *addOp = yieldOp->getOperand(0).getDefiningOp();
87-
if (!addOp || !isa<arith::AddIOp, arith::AddFOp>(addOp)) {
88-
return false;
89-
}
90-
Value addLhs = addOp->getOperand(0);
91-
Value addRhs = addOp->getOperand(1);
92-
Operation *addLhsOp = addLhs.getDefiningOp();
93-
Operation *addRhsOp = addRhs.getDefiningOp();
94-
if (!(addLhsOp && addRhs == outBlockArg) &&
95-
!(addRhsOp && addLhs == outBlockArg)) {
96-
return false;
97-
}
98-
Operation *mulOp = addLhsOp ? addLhsOp : addRhsOp;
99-
if (!isa<arith::MulFOp, arith::MulIOp>(mulOp)) {
100-
return false;
101-
}
102-
Value mulLhs = mulOp->getOperand(0);
103-
Value mulRhs = mulOp->getOperand(1);
104-
auto mulLhsOp = mulLhs.getDefiningOp<CastOpInterface>();
105-
auto mulRhsOp = mulRhs.getDefiningOp<CastOpInterface>();
106-
if (!isa<BlockArgument>(mulLhs) && !mulLhsOp && !isa<BlockArgument>(mulRhs) &&
107-
!mulRhsOp) {
108-
return false;
109-
}
110-
if ((mulLhsOp && !isa<BlockArgument>(mulLhsOp->getOperand(0))) ||
111-
(mulRhsOp && !isa<BlockArgument>(mulRhsOp->getOperand(0)))) {
112-
return false;
113-
}
114-
return true;
115-
}
116-
117-
/// Not all contractions are supported by data tiling, so return true if:
118-
/// 1) linalgOp has pure tensor semantics.
119-
/// 2) linalgOp does not have a preset compilation info.
120-
/// 3) The workgroup count is not present if linalgOp is wrapped within
121-
/// Flow::DispatchRegionOp.
122-
/// 4) All the operands do not have encodings.
123-
/// 5) linalgOp has contraction indexingMaps.
124-
/// 6) There are not more than one of each contraction dimension.
125-
/// 7) There is an M or N dimension, and there is a K dimension.
126-
/// 8) linalgOp has the same body as an ordinary int or float matmul.
127-
///
128-
/// These restrictions are required because data tiling currently creates
129-
/// an Mmt4DOp or BatchMmt4DOp on the packed inputs.
130-
///
131-
/// TODO(#16176): Loosen restrictions on contraction ops once data tiling
132-
/// can support more cases.
133-
static bool isSupportedContractionOp(linalg::LinalgOp linalgOp) {
134-
if (!linalgOp.hasPureTensorSemantics()) {
135-
return false;
136-
}
137-
if (getCompilationInfo(linalgOp)) {
138-
return false;
139-
}
140-
auto hasWorkgroupCounts = [](Operation *op) -> bool {
141-
auto parentDispatchOp = op->getParentOfType<IREE::Flow::DispatchRegionOp>();
142-
return parentDispatchOp && !parentDispatchOp.getWorkgroupCount().empty();
143-
};
144-
if (hasWorkgroupCounts(linalgOp)) {
145-
return false;
146-
}
147-
auto hasEncoding = [](Value operand) -> bool {
148-
auto type = llvm::dyn_cast<RankedTensorType>(operand.getType());
149-
return type && type.getEncoding();
150-
};
151-
if (llvm::any_of(linalgOp.getDpsInputs(), hasEncoding) ||
152-
llvm::any_of(linalgOp.getDpsInits(), hasEncoding)) {
153-
return false;
154-
}
155-
156-
if (!linalg::isaContractionOpInterface(linalgOp)) {
157-
return false;
158-
}
159-
auto cDims = linalg::inferContractionDims(linalgOp);
160-
if (failed(cDims) || cDims->batch.size() > 1 || cDims->m.size() > 1 ||
161-
cDims->n.size() > 1 || cDims->k.size() > 1) {
162-
return false;
163-
}
164-
if ((cDims->n.empty() && cDims->m.empty()) || cDims->k.empty()) {
165-
return false;
166-
}
167-
if (!hasMatmulLikeBody(linalgOp)) {
168-
return false;
169-
}
170-
return true;
171-
}
172-
17377
static SmallVector<linalg::LinalgOp>
17478
getDataTilingCandidates(FunctionOpInterface funcOp) {
17579
SmallVector<linalg::LinalgOp> result;
17680
funcOp.walk([&](linalg::LinalgOp op) {
177-
if (!isSupportedContractionOp(op)) {
81+
if (!IREE::Encoding::hasDataTilingHint(op)) {
17882
return;
17983
}
18084
result.push_back(op);

compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ iree_lit_test_suite(
1616
name = "lit",
1717
srcs = enforce_glob(
1818
[
19+
"annotate_data_tiling_hints.mlir",
1920
"bitcast_unsupported_element_types.mlir",
2021
"clone_producers_into_dispatch_regions.mlir",
2122
"collapse_dimensions.mlir",

compiler/src/iree/compiler/DispatchCreation/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ iree_lit_test_suite(
1414
NAME
1515
lit
1616
SRCS
17+
"annotate_data_tiling_hints.mlir"
1718
"bitcast_unsupported_element_types.mlir"
1819
"bubble_up_expand_shapes.mlir"
1920
"bubble_up_extract_slice.mlir"
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-annotate-data-tiling-hints))" --split-input-file %s | FileCheck %s
2+
3+
util.func public @matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
4+
%0 = linalg.matmul
5+
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
6+
outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
7+
util.return %0 : tensor<?x?xf32>
8+
}
9+
// CHECK-LABEL: @matmul(
10+
// CHECK: linalg.matmul
11+
// CHECK-SAME: iree.opt.data_tiling
12+
13+
// -----
14+
15+
util.func public @matmul_with_preset_hints(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
16+
%0 = linalg.matmul {"iree.opt.data_tiling"}
17+
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
18+
outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
19+
%1 = linalg.matmul
20+
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
21+
outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
22+
util.return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
23+
}
24+
// CHECK-LABEL: @matmul_with_preset_hints(
25+
// CHECK: linalg.matmul
26+
// CHECK-SAME: iree.opt.data_tiling
27+
// CHECK-NOT: iree.opt.data_tiling

0 commit comments

Comments
 (0)