diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 062606e7e10b6..86233b0bc4f03 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -2062,6 +2062,10 @@ transform::IncludeOp::apply(transform::TransformRewriter &rewriter, DiagnosedSilenceableFailure result = applySequenceBlock( callee.getBody().front(), getFailurePropagationMode(), state, results); + + if (!result.succeeded()) + return result; + mappings.clear(); detail::prepareValueMappings( mappings, callee.getBody().front().getTerminator()->getOperands(), state); diff --git a/mlir/test/Dialect/Transform/include-failure-propagation.mlir b/mlir/test/Dialect/Transform/include-failure-propagation.mlir new file mode 100644 index 0000000000000..94e9d8f27c233 --- /dev/null +++ b/mlir/test/Dialect/Transform/include-failure-propagation.mlir @@ -0,0 +1,38 @@ +// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect --verify-diagnostics + +module attributes { transform.with_named_sequence } { + // Callee returns a silenceable failure when given a module instead of func.func. + transform.named_sequence @callee(%root: !transform.any_op {transform.consumed}) -> (!transform.any_op) { + transform.test_consume_operand_of_op_kind_or_fail %root, "func.func" : !transform.any_op + transform.yield %root : !transform.any_op + } + + transform.named_sequence @__transform_main(%root: !transform.any_op) { + %res = transform.sequence %root : !transform.any_op -> !transform.any_op failures(suppress) { + ^bb0(%arg0: !transform.any_op): + // This include returns a silenceable failure; it must not remap results. + %included = transform.include @callee failures(propagate) (%arg0) : (!transform.any_op) -> (!transform.any_op) + transform.yield %included : !transform.any_op + } + + %count = transform.num_associations %res : (!transform.any_op) -> !transform.param + // expected-remark @below {{0}} + transform.debug.emit_param_as_remark %count : !transform.param + + // If the include incorrectly forwarded mappings on failure, this would run + // and produce an unexpected remark under --verify-diagnostics. + transform.foreach %res : !transform.any_op { + ^bb0(%it: !transform.any_op): + transform.debug.emit_remark_at %it, "include result unexpectedly populated" : !transform.any_op + } + transform.yield + } +} + +// ----- + +module { + func.func @payload() { + return + } +}