Skip to content

Commit 6b5fecf

Browse files
authored
[mlir] transform dialect: don't crash in verifiers (#161098)
Fix crashes in the verifier of `transform.with_named_sequence` attribute attached to a symbol table operation caused by it constructing a call graph inside the symbol table. The call graph construction assumes calls and callables, such as functions or named sequences, have been verified, but it is not yet the case when the attribute verifier on the (parent) symbol table operation runs. Trigger such verification manually before constructing the call graph. This adds redundancy in verification, but there is currently no mechanism to change the order of verificaiton. In performance-critical scenarios, verification can be disabled altogether. Remove unnecessary verfificaton from `transform::IncludeOp::getEffects`. It was introduced along with the op definition as the op used to inspect the body of callee, which assumed the body existed, to identify handle consumption behavior. This was later evolved to having explicit argument attributes on the callee, which handles the absence of such attributes gracefully without the need for verification, but the verification was never removed. It would have been causing infinite recursion if kept in place. Fixes #159646. Fixes #159734. Fixes #159736.
1 parent 47d74ca commit 6b5fecf

File tree

3 files changed

+75
-13
lines changed

3 files changed

+75
-13
lines changed

mlir/lib/Dialect/Transform/IR/TransformDialect.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/Dialect/Transform/IR/Utils.h"
1414
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
1515
#include "mlir/IR/DialectImplementation.h"
16+
#include "mlir/IR/Verifier.h"
1617
#include "llvm/ADT/SCCIterator.h"
1718
#include "llvm/ADT/TypeSwitch.h"
1819

@@ -140,6 +141,20 @@ LogicalResult transform::TransformDialect::verifyOperationAttribute(
140141
"operations with symbol tables";
141142
}
142143

144+
// Pre-verify calls and callables because call graph construction below
145+
// assumes they are valid, but this verifier runs before verifying the
146+
// nested operations.
147+
WalkResult walkResult = op->walk([](Operation *nested) {
148+
if (!isa<CallableOpInterface, CallOpInterface>(nested))
149+
return WalkResult::advance();
150+
151+
if (failed(verify(nested, /*verifyRecursively=*/false)))
152+
return WalkResult::interrupt();
153+
return WalkResult::advance();
154+
});
155+
if (walkResult.wasInterrupted())
156+
return failure();
157+
143158
const mlir::CallGraph callgraph(op);
144159
for (auto scc = llvm::scc_begin(&callgraph); !scc.isAtEnd(); ++scc) {
145160
if (!scc.hasCycle())

mlir/lib/Dialect/Transform/IR/TransformOps.cpp

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2097,17 +2097,11 @@ void transform::IncludeOp::getEffects(
20972097
getOperation(), getTarget());
20982098
if (!callee)
20992099
return defaultEffects();
2100-
DiagnosedSilenceableFailure earlyVerifierResult =
2101-
verifyNamedSequenceOp(callee, /*emitWarnings=*/false);
2102-
if (!earlyVerifierResult.succeeded()) {
2103-
(void)earlyVerifierResult.silence();
2104-
return defaultEffects();
2105-
}
21062100

21072101
for (unsigned i = 0, e = getNumOperands(); i < e; ++i) {
21082102
if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
21092103
consumesHandle(getOperation()->getOpOperand(i), effects);
2110-
else
2104+
else if (callee.getArgAttr(i, TransformDialect::kArgReadOnlyAttrName))
21112105
onlyReadsHandle(getOperation()->getOpOperand(i), effects);
21122106
}
21132107
}

mlir/test/Dialect/Transform/ops-invalid.mlir

Lines changed: 59 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -369,20 +369,21 @@ module attributes { transform.with_named_sequence } {
369369
// expected-error @below {{recursion not allowed in named sequences}}
370370
transform.named_sequence @self_recursion() -> () {
371371
transform.include @self_recursion failures(suppress) () : () -> ()
372+
transform.yield
372373
}
373374
}
374375

375376
// -----
376377

377378
module @mutual_recursion attributes { transform.with_named_sequence } {
378379
// expected-note @below {{operation on recursion stack}}
379-
transform.named_sequence @foo(%arg0: !transform.any_op) -> () {
380+
transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> () {
380381
transform.include @bar failures(suppress) (%arg0) : (!transform.any_op) -> ()
381382
transform.yield
382383
}
383384

384385
// expected-error @below {{recursion not allowed in named sequences}}
385-
transform.named_sequence @bar(%arg0: !transform.any_op) -> () {
386+
transform.named_sequence @bar(%arg0: !transform.any_op {transform.readonly}) -> () {
386387
transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
387388
transform.yield
388389
}
@@ -430,7 +431,7 @@ module attributes { transform.with_named_sequence } {
430431
// -----
431432

432433
module attributes { transform.with_named_sequence } {
433-
transform.named_sequence @foo(%arg0: !transform.any_op) -> () {
434+
transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> () {
434435
transform.yield
435436
}
436437

@@ -444,7 +445,7 @@ module attributes { transform.with_named_sequence } {
444445
// -----
445446

446447
module attributes { transform.with_named_sequence } {
447-
transform.named_sequence @foo(%arg0: !transform.any_op) -> (!transform.any_op) {
448+
transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> (!transform.any_op) {
448449
transform.yield %arg0 : !transform.any_op
449450
}
450451

@@ -458,7 +459,7 @@ module attributes { transform.with_named_sequence } {
458459
// -----
459460

460461
module attributes { transform.with_named_sequence } {
461-
transform.named_sequence @foo(%arg0: !transform.any_op) -> (!transform.any_op) {
462+
transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> (!transform.any_op) {
462463
transform.yield %arg0 : !transform.any_op
463464
}
464465

@@ -543,14 +544,15 @@ module attributes { transform.with_named_sequence } {
543544
// -----
544545

545546
module attributes { transform.with_named_sequence } {
546-
// expected-error @below {{must provide consumed/readonly status for arguments of external or called ops}}
547547
transform.named_sequence @foo(%op: !transform.any_op) {
548548
transform.debug.emit_remark_at %op, "message" : !transform.any_op
549549
transform.yield
550550
}
551551

552552
transform.sequence failures(propagate) {
553553
^bb0(%arg0: !transform.any_op):
554+
// expected-error @below {{TransformOpInterface requires memory effects on operands to be specified}}
555+
// expected-note @below {{no effects specified for operand #0}}
554556
transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
555557
transform.yield
556558
}
@@ -908,3 +910,54 @@ module attributes { transform.with_named_sequence } {
908910
transform.yield
909911
}
910912
}
913+
914+
// -----
915+
916+
module attributes { transform.with_named_sequence } {
917+
transform.named_sequence @__transform_main(%arg0: !transform.any_op) -> () {
918+
// Intentionally malformed func with no region. This shouldn't crash the
919+
// verifier of `with_named_sequence` that runs before we get to the
920+
// function.
921+
// expected-error @below {{requires one region}}
922+
"func.func"() : () -> ()
923+
transform.yield
924+
}
925+
}
926+
927+
// -----
928+
929+
module attributes { transform.with_named_sequence } {
930+
transform.named_sequence @__transform_main(%arg0: !transform.any_op) -> () {
931+
// Intentionally malformed call with a region. This shouldn't crash the
932+
// verifier of `with_named_sequence` that runs before we get to the call.
933+
// expected-error @below {{requires zero regions}}
934+
"func.call"() <{
935+
function_type = () -> (),
936+
sym_name = "lambda_function"
937+
}> ({
938+
^bb0:
939+
"func.return"() : () -> ()
940+
}) : () -> ()
941+
transform.yield
942+
}
943+
}
944+
945+
// -----
946+
947+
module attributes { transform.with_named_sequence } {
948+
// Intentionally malformed sequence where the verifier should not crash.
949+
// expected-error @below {{ op expects argument attribute array to have the same number of elements as the number of function arguments, got 1, but expected 3}}
950+
"transform.named_sequence"() <{
951+
arg_attrs = [{transform.readonly}],
952+
function_type = (i1, tensor<f32>, tensor<f32>) -> (),
953+
sym_name = "print_message"
954+
}> ({}) : () -> ()
955+
"transform.named_sequence"() <{
956+
function_type = (!transform.any_op) -> (),
957+
sym_name = "reference_other_module"
958+
}> ({
959+
^bb0(%arg0: !transform.any_op):
960+
"transform.include"(%arg0) <{target = @print_message}> : (!transform.any_op) -> ()
961+
"transform.yield"() : () -> ()
962+
}) : () -> ()
963+
}

0 commit comments

Comments
 (0)