Skip to content

Commit 624a9fa

Browse files
authored
[Preprocessing] Fix bug in TD dag matching op (#19945)
Fixes a bug in the `transform.iree.match.cast_compatible_dag_from_root` op failing to match when there are repeated operands. --------- Signed-off-by: Max Dawkins <[email protected]>
1 parent 21234ed commit 624a9fa

File tree

2 files changed

+50
-1
lines changed

2 files changed

+50
-1
lines changed

compiler/src/iree/compiler/Preprocessing/Common/test/preprocessing_match_ops.mlir

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,3 +136,52 @@ module attributes {transform.with_named_sequence} {
136136
transform.yield
137137
}
138138
}
139+
140+
// -----
141+
142+
module attributes {transform.with_named_sequence} {
143+
144+
// CHECK: func.func @matmul_repeated_operand
145+
func.func @matmul_repeated_operand(%input: tensor<32x64xi8>, %dest: tensor<32x32xi32>) -> tensor<32x32xi32> {
146+
// CHECK-NEXT: linalg.matmul_transpose_b
147+
// CHECK-SAME: match_status = "matched"
148+
%res = linalg.matmul_transpose_b {match_status = "unmatched"}
149+
ins(%input, %input : tensor<32x64xi8>, tensor<32x64xi8>)
150+
outs(%dest : tensor<32x32xi32>) -> tensor<32x32xi32>
151+
return %res : tensor<32x32xi32>
152+
}
153+
154+
// CHECK: func.func @matmul_non_repeated_operand
155+
func.func @matmul_non_repeated_operand(%input0: tensor<32x64xi8>, %input1: tensor<32x64xi8>, %dest: tensor<32x32xi32>) -> tensor<32x32xi32> {
156+
// CHECK-NEXT: linalg.matmul_transpose_b
157+
// CHECK-SAME: match_status = "unmatched"
158+
%res = linalg.matmul_transpose_b {match_status = "unmatched"}
159+
ins(%input0, %input1 : tensor<32x64xi8>, tensor<32x64xi8>)
160+
outs(%dest : tensor<32x32xi32>) -> tensor<32x32xi32>
161+
return %res : tensor<32x32xi32>
162+
}
163+
164+
transform.named_sequence @match_matmul_repeated_operand(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
165+
%inputs, %outputs = transform.iree.match.cast_compatible_dag_from_root %arg0 {
166+
^bb0(%arg1: tensor<32x64xi8>, %arg2: tensor<32x32xi32>):
167+
%1 = linalg.matmul_transpose_b {match_status = "unmatched"}
168+
ins(%arg1, %arg1 : tensor<32x64xi8>, tensor<32x64xi8>)
169+
outs(%arg2 : tensor<32x32xi32>) -> tensor<32x32xi32>
170+
} : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
171+
transform.yield %arg0 : !transform.any_op
172+
}
173+
174+
transform.named_sequence @annotate(%generic: !transform.any_op {transform.readonly}) {
175+
%0 = transform.param.constant "matched" -> !transform.any_param
176+
transform.annotate %generic "match_status" = %0 : !transform.any_op, !transform.any_param
177+
transform.yield
178+
}
179+
180+
transform.named_sequence @__transform_main(%module: !transform.any_op) {
181+
%func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.any_op
182+
transform.foreach_match in %module
183+
@match_matmul_repeated_operand -> @annotate
184+
: (!transform.any_op) -> (!transform.any_op)
185+
transform.yield
186+
}
187+
}

compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ IREE::transform_dialect::MatchCastCompatibleDagFromRootOp::matchOperation(
177177
return emitDefiniteFailure() << "Invalid block argument in target";
178178
}
179179
int64_t argIdx = targetBlockArg.getArgNumber();
180-
if (inputs[argIdx] && inputs[argIdx] != targetOperand) {
180+
if (inputs[argIdx] && inputs[argIdx] != payloadOperand) {
181181
return emitSilenceableError()
182182
<< "input operand with conflicting uses";
183183
}

0 commit comments

Comments
 (0)