diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td index b946fc8875860..1eebddca3df4d 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -1062,36 +1062,37 @@ def SplitHandleOp : TransformDialectOp<"split_handle", [FunctionalStyleTransformOpTrait, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { - let summary = "Splits a handle of payload ops into handles with a single op"; + let summary = "Splits a handle or parameter into multiple values"; let description = [{ Splits `handle` into one or multiple handles, as specified by the number of results of this operation. `handle` should be mapped to as many payload - ops as there are results. Otherwise, this transform will fail produces a - silenceable failure by default. Each result handle is mapped to exactly one - payload op. The order of the payload ops is preserved, i.e., the i-th - payload op is mapped to the i-th result handle. + ops, values or parameteres as there are results. Otherwise, this transform + will fail producing a silenceable failure by default. Each result handle + is mapped to exactly one payload unless specified otherwise by attributes + described below. The order of the payloads is preserved, i.e., the i-th + payload is mapped to the i-th result handle. This operation is useful for ensuring a statically known number of - operations are tracked by the source `handle` and to extract them into + payloads are tracked by the source `handle` and to extract them into individual handles that can be further manipulated in isolation. - If there are more payload ops than results, the remaining ops are mapped to + If there are more payloads than results, the remaining payloads are mapped to the result with index `overflow_result`. If no `overflow_result` is specified, the transform produces a silenceable failure. If there are fewer payload ops than results, the transform produces a silenceable failure if `fail_on_payload_too_small` is set to "true". Otherwise, it succeeds and the remaining result handles are not mapped to - any op. It also succeeds if `handle` is empty and + anything. It also succeeds if `handle` is empty and `pass_through_empty_handle` is set to "true", regardless of `fail_on_payload_too_small`. }]; - let arguments = (ins TransformHandleTypeInterface:$handle, + let arguments = (ins Transform_AnyHandleOrParamType:$handle, DefaultValuedAttr:$pass_through_empty_handle, DefaultValuedAttr:$fail_on_payload_too_small, OptionalAttr:$overflow_result); - let results = (outs Variadic:$results); + let results = (outs Variadic:$results); let hasVerifier = 1; let builders = [ diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 590cae9aa0d66..1f0f183e29f9a 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -2415,32 +2415,62 @@ DiagnosedSilenceableFailure transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { - int64_t numPayloadOps = llvm::range_size(state.getPayloadOps(getHandle())); + int64_t numPayloads = + llvm::TypeSwitch(getHandle().getType()) + .Case([&](auto x) { + return llvm::range_size(state.getPayloadOps(getHandle())); + }) + .Case([&](auto x) { + return llvm::range_size(state.getPayloadValues(getHandle())); + }) + .Case([&](auto x) { + return llvm::range_size(state.getParams(getHandle())); + }) + .Default([](auto x) { + llvm_unreachable("unknown transform dialect type interface"); + return -1; + }); + auto produceNumOpsError = [&]() { return emitSilenceableError() << getHandle() << " expected to contain " << this->getNumResults() - << " payload ops but it contains " << numPayloadOps - << " payload ops"; + << " payloads but it contains " << numPayloads << " payloads"; }; // Fail if there are more payload ops than results and no overflow result was // specified. - if (numPayloadOps > getNumResults() && !getOverflowResult().has_value()) + if (numPayloads > getNumResults() && !getOverflowResult().has_value()) return produceNumOpsError(); // Fail if there are more results than payload ops. Unless: // - "fail_on_payload_too_small" is set to "false", or // - "pass_through_empty_handle" is set to "true" and there are 0 payload ops. - if (numPayloadOps < getNumResults() && getFailOnPayloadTooSmall() && - (numPayloadOps != 0 || !getPassThroughEmptyHandle())) + if (numPayloads < getNumResults() && getFailOnPayloadTooSmall() && + (numPayloads != 0 || !getPassThroughEmptyHandle())) return produceNumOpsError(); - // Distribute payload ops. - SmallVector> resultHandles(getNumResults(), {}); + // Distribute payloads. + SmallVector> resultHandles(getNumResults(), {}); if (getOverflowResult()) - resultHandles[*getOverflowResult()].reserve(numPayloadOps - - getNumResults()); - for (auto &&en : llvm::enumerate(state.getPayloadOps(getHandle()))) { + resultHandles[*getOverflowResult()].reserve(numPayloads - getNumResults()); + + auto container = [&]() { + if (isa(getHandle().getType())) { + return llvm::map_to_vector( + state.getPayloadOps(getHandle()), + [](Operation *op) -> MappedValue { return op; }); + } + if (isa(getHandle().getType())) { + return llvm::map_to_vector(state.getPayloadValues(getHandle()), + [](Value v) -> MappedValue { return v; }); + } + assert(isa(getHandle().getType()) && + "unsupported kind of transform dialect type"); + return llvm::map_to_vector(state.getParams(getHandle()), + [](Attribute a) -> MappedValue { return a; }); + }(); + + for (auto &&en : llvm::enumerate(container)) { int64_t resultNum = en.index(); if (resultNum >= getNumResults()) resultNum = *getOverflowResult(); @@ -2449,7 +2479,8 @@ transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter, // Set transform op results. for (auto &&it : llvm::enumerate(resultHandles)) - results.set(llvm::cast(getResult(it.index())), it.value()); + results.setMappedValues(llvm::cast(getResult(it.index())), + it.value()); return DiagnosedSilenceableFailure::success(); } @@ -2466,6 +2497,15 @@ LogicalResult transform::SplitHandleOp::verify() { if (getOverflowResult().has_value() && !(*getOverflowResult() < getNumResults())) return emitOpError("overflow_result is not a valid result index"); + + for (Type resultType : getResultTypes()) { + if (implementSameTransformInterface(getHandle().getType(), resultType)) + continue; + + return emitOpError("expects result types to implement the same transform " + "interface as the operand type"); + } + return success(); } diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir index 4fe2dbedff56e..ecc234587cda9 100644 --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -1094,7 +1094,7 @@ module attributes {transform.with_named_sequence} { // expected-remark @below {{1}} transform.debug.emit_param_as_remark %p : !transform.param %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op - // expected-error @below {{expected to contain 3 payload ops but it contains 2 payload ops}} + // expected-error @below {{expected to contain 3 payloads but it contains 2 payloads}} %h_2:3 = transform.split_handle %muli_2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) transform.yield } @@ -1180,6 +1180,71 @@ module attributes {transform.with_named_sequence} { // ----- +func.func private @opaque() -> (i32, i32) + +func.func @split_handle() { + func.call @opaque() : () -> (i32, i32) + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%fun: !transform.any_op) { + %op = transform.structured.match ops{["func.call"]} in %fun : (!transform.any_op) -> !transform.any_op + %val = transform.get_result %op[all] : (!transform.any_op) -> !transform.any_value + %p = transform.num_associations %val : (!transform.any_value) -> !transform.any_param + // expected-remark @below {{total 2}} + transform.debug.emit_param_as_remark %p, "total" : !transform.any_param + %h:2 = transform.split_handle %val : (!transform.any_value) -> (!transform.any_value, !transform.any_value) + %p1 = transform.num_associations %h#0 : (!transform.any_value) -> !transform.any_param + %p2 = transform.num_associations %h#1 : (!transform.any_value) -> !transform.any_param + // expected-remark @below {{first 1}} + transform.debug.emit_param_as_remark %p1, "first" : !transform.any_param + // expected-remark @below {{second 1}} + transform.debug.emit_param_as_remark %p1, "second" : !transform.any_param + transform.yield + } +} + +// ----- + +func.func private @opaque() -> (i32, i32) + +func.func @split_handle() { + func.call @opaque() : () -> (i32, i32) + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%fun: !transform.any_op) { + %op = transform.structured.match ops{["func.call"]} in %fun : (!transform.any_op) -> !transform.any_op + %val = transform.get_result %op[all] : (!transform.any_op) -> !transform.any_value + %type = transform.get_type %val : (!transform.any_value) -> !transform.any_param + %p = transform.num_associations %type : (!transform.any_param) -> !transform.any_param + // expected-remark @below {{total 2}} + transform.debug.emit_param_as_remark %p, "total" : !transform.any_param + %h:2 = transform.split_handle %type : (!transform.any_param) -> (!transform.any_param, !transform.any_param) + %p1 = transform.num_associations %h#0 : (!transform.any_param) -> !transform.any_param + %p2 = transform.num_associations %h#1 : (!transform.any_param) -> !transform.any_param + // expected-remark @below {{first 1}} + transform.debug.emit_param_as_remark %p1, "first" : !transform.any_param + // expected-remark @below {{second 1}} + transform.debug.emit_param_as_remark %p1, "second" : !transform.any_param + transform.yield + } +} + +// ----- + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%fun: !transform.any_op) { + // expected-error @below {{op expects result types to implement the same transform interface as the operand type}} + transform.split_handle %fun : (!transform.any_op) -> (!transform.any_op, !transform.any_value) + transform.yield + } +} + +// ----- + "test.some_op"() : () -> () "other_dialect.other_op"() : () -> () @@ -1324,7 +1389,7 @@ module attributes {transform.with_named_sequence} { transform.sequence %root : !transform.any_op -> !transform.any_op failures(propagate) { ^bb1(%fun: !transform.any_op): %muli = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op - // expected-error @below {{expected to contain 3 payload ops but it contains 2 payload ops}} + // expected-error @below {{expected to contain 3 payloads but it contains 2 payloads}} %h_2:3 = split_handle %muli : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) /// Test that yield does not crash in the presence of silenceable error in /// propagate mode.