Skip to content

Commit 288529e

Browse files
[mlir][transform] Clean up SplitHandlesOp
* Rename to `SplitHandleOp`: it splits a single handle. * Drop `num_result_handles` attribute: it is redundant and can be inferred from the number of results. * Improve documentation and minor code cleanups. Differential Revision: https://reviews.llvm.org/D149937
1 parent 1b9d0de commit 288529e

File tree

3 files changed

+50
-54
lines changed

3 files changed

+50
-54
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformOps.td

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -509,35 +509,36 @@ def NamedSequenceOp : TransformDialectOp<"named_sequence",
509509
}];
510510
}
511511

512-
def SplitHandlesOp : TransformDialectOp<"split_handles",
512+
def SplitHandleOp : TransformDialectOp<"split_handle",
513513
[FunctionalStyleTransformOpTrait,
514514
DeclareOpInterfaceMethods<TransformOpInterface>,
515515
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
516-
let summary = "Splits handles from a union of payload ops to a list";
516+
let summary = "Splits a handle of payload ops into handles with a single op";
517517
let description = [{
518-
Creates `num_result_handles` transform IR handles extracted from the
519-
`handle` operand. The resulting Payload IR operation handles are listed
520-
in the same order as the operations appear in the source `handle`.
521-
This is useful for ensuring a statically known number of operations are
522-
tracked by the source `handle` and to extract them into individual handles
523-
that can be further manipulated in isolation.
524-
525-
This operation succeeds and returns `num_result_handles` if the statically
526-
specified `num_result_handles` corresponds to the dynamic number of
527-
operations contained in the source `handle`. Otherwise it silently fails.
518+
Splits `handle` into one or multiple handles, as specified by the number
519+
of results of this operation. `handle` should be mapped to as many payload
520+
ops as there are results. Otherwise, this transform will fail silently.
521+
Each result handle is mapped to exactly one payload op. The order
522+
of the payload ops is preserved, i.e., the i-th payload op is mapped to the
523+
i-th result handle.
524+
525+
This operation is useful for ensuring a statically known number of
526+
operations are tracked by the source `handle` and to extract them into
527+
individual handles that can be further manipulated in isolation.
528+
529+
If `handle` is empty, this transform will succeed and all result handles
530+
are empty.
528531
}];
529532

530-
let arguments = (ins TransformHandleTypeInterface:$handle,
531-
I64Attr:$num_result_handles);
533+
let arguments = (ins TransformHandleTypeInterface:$handle);
532534
let results = (outs Variadic<TransformHandleTypeInterface>:$results);
533535

534536
let builders = [
535537
OpBuilder<(ins "Value":$handle, "int64_t":$numResultHandles)>
536538
];
537539

538540
let assemblyFormat = [{
539-
$handle `in` `[` $num_result_handles `]`
540-
attr-dict `:` functional-type(operands, results)
541+
$handle attr-dict `:` functional-type(operands, results)
541542
}];
542543
}
543544

mlir/lib/Dialect/Transform/IR/TransformOps.cpp

Lines changed: 22 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1488,48 +1488,43 @@ LogicalResult transform::NamedSequenceOp::verify() {
14881488
}
14891489

14901490
//===----------------------------------------------------------------------===//
1491-
// SplitHandlesOp
1491+
// SplitHandleOp
14921492
//===----------------------------------------------------------------------===//
14931493

1494-
void transform::SplitHandlesOp::build(OpBuilder &builder,
1495-
OperationState &result, Value target,
1496-
int64_t numResultHandles) {
1494+
void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result,
1495+
Value target, int64_t numResultHandles) {
14971496
result.addOperands(target);
1498-
result.addAttribute(SplitHandlesOp::getNumResultHandlesAttrName(result.name),
1499-
builder.getI64IntegerAttr(numResultHandles));
15001497
auto pdlOpType = pdl::OperationType::get(builder.getContext());
15011498
result.addTypes(SmallVector<pdl::OperationType>(numResultHandles, pdlOpType));
15021499
}
15031500

15041501
DiagnosedSilenceableFailure
1505-
transform::SplitHandlesOp::apply(transform::TransformResults &results,
1506-
transform::TransformState &state) {
1507-
int64_t numResultHandles =
1508-
getHandle() ? state.getPayloadOps(getHandle()).size() : 0;
1509-
int64_t expectedNumResultHandles = getNumResultHandles();
1510-
if (numResultHandles != expectedNumResultHandles) {
1511-
// Empty input handle corner case: always propagates empty handles in both
1512-
// suppress and propagate modes.
1513-
if (numResultHandles == 0) {
1514-
for (OpResult result : getResults())
1515-
results.set(result, {});
1516-
return DiagnosedSilenceableFailure::success();
1517-
}
1502+
transform::SplitHandleOp::apply(transform::TransformResults &results,
1503+
transform::TransformState &state) {
1504+
int64_t numPayloadOps = state.getPayloadOps(getHandle()).size();
15181505

1519-
// If the input handle was not empty and the number of result handles does
1520-
// not match, this is a legit silenceable error.
1521-
return emitSilenceableError()
1522-
<< getHandle() << " expected to contain " << expectedNumResultHandles
1523-
<< " operation handles but it contains " << numResultHandles
1524-
<< " handles";
1506+
// Empty handle corner case: all result handles are empty.
1507+
if (numPayloadOps == 0) {
1508+
for (OpResult result : getResults())
1509+
results.set(result, {});
1510+
return DiagnosedSilenceableFailure::success();
15251511
}
1526-
// Normal successful case.
1512+
1513+
// If the input handle was not empty and the number of payload ops does not
1514+
// match, this is a legit silenceable error.
1515+
if (numPayloadOps != getNumResults())
1516+
return emitSilenceableError()
1517+
<< getHandle() << " expected to contain " << getNumResults()
1518+
<< " payload ops but it contains " << numPayloadOps
1519+
<< " payload ops";
1520+
15271521
for (const auto &en : llvm::enumerate(state.getPayloadOps(getHandle())))
15281522
results.set(getResults()[en.index()].cast<OpResult>(), en.value());
1523+
15291524
return DiagnosedSilenceableFailure::success();
15301525
}
15311526

1532-
void transform::SplitHandlesOp::getEffects(
1527+
void transform::SplitHandleOp::getEffects(
15331528
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
15341529
onlyReadsHandle(getHandle(), effects);
15351530
producesHandle(getResults(), effects);

mlir/test/Dialect/Transform/test-interpreter.mlir

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -818,7 +818,7 @@ transform.sequence failures(propagate) {
818818

819819
// -----
820820

821-
func.func @split_handles(%a: index, %b: index, %c: index) {
821+
func.func @split_handle(%a: index, %b: index, %c: index) {
822822
%0 = arith.muli %a, %b : index
823823
%1 = arith.muli %a, %c : index
824824
return
@@ -827,17 +827,17 @@ func.func @split_handles(%a: index, %b: index, %c: index) {
827827
transform.sequence failures(propagate) {
828828
^bb1(%fun: !pdl.operation):
829829
%muli = transform.structured.match ops{["arith.muli"]} in %fun : (!pdl.operation) -> !pdl.operation
830-
%h:2 = split_handles %muli in [2] : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
830+
%h:2 = split_handle %muli : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
831831
// expected-remark @below {{1}}
832832
transform.test_print_number_of_associated_payload_ir_ops %h#0
833833
%muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!pdl.operation) -> !pdl.operation
834-
// expected-error @below {{expected to contain 3 operation handles but it contains 2 handles}}
835-
%h_2:3 = split_handles %muli_2 in [3] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
834+
// expected-error @below {{expected to contain 3 payload ops but it contains 2 payload ops}}
835+
%h_2:3 = split_handle %muli_2 : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
836836
}
837837

838838
// -----
839839

840-
func.func @split_handles(%a: index, %b: index, %c: index) {
840+
func.func @split_handle(%a: index, %b: index, %c: index) {
841841
%0 = arith.muli %a, %b : index
842842
%1 = arith.muli %a, %c : index
843843
return
@@ -846,12 +846,12 @@ func.func @split_handles(%a: index, %b: index, %c: index) {
846846
transform.sequence failures(suppress) {
847847
^bb1(%fun: !pdl.operation):
848848
%muli = transform.structured.match ops{["arith.muli"]} in %fun : (!pdl.operation) -> !pdl.operation
849-
%h:2 = split_handles %muli in [2] : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
849+
%h:2 = split_handle %muli : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
850850
// expected-remark @below {{1}}
851851
transform.test_print_number_of_associated_payload_ir_ops %h#0
852852
%muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!pdl.operation) -> !pdl.operation
853853
// Silenceable failure and all handles are now empty.
854-
%h_2:3 = split_handles %muli_2 in [3] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
854+
%h_2:3 = split_handle %muli_2 : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
855855
// expected-remark @below {{0}}
856856
transform.test_print_number_of_associated_payload_ir_ops %h_2#0
857857
}
@@ -966,7 +966,7 @@ transform.with_pdl_patterns {
966966

967967
// -----
968968

969-
func.func @split_handles(%a: index, %b: index, %c: index) {
969+
func.func @split_handle(%a: index, %b: index, %c: index) {
970970
%0 = arith.muli %a, %b : index
971971
%1 = arith.muli %a, %c : index
972972
return
@@ -975,8 +975,8 @@ func.func @split_handles(%a: index, %b: index, %c: index) {
975975
transform.sequence -> !pdl.operation failures(propagate) {
976976
^bb1(%fun: !pdl.operation):
977977
%muli = transform.structured.match ops{["arith.muli"]} in %fun : (!pdl.operation) -> !pdl.operation
978-
// expected-error @below {{expected to contain 3 operation handles but it contains 2 handles}}
979-
%h_2:3 = split_handles %muli in [3] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
978+
// expected-error @below {{expected to contain 3 payload ops but it contains 2 payload ops}}
979+
%h_2:3 = split_handle %muli : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
980980
/// Test that yield does not crash in the presence of silenceable error in
981981
/// propagate mode.
982982
yield %fun : !pdl.operation
@@ -988,7 +988,7 @@ transform.sequence -> !transform.any_op failures(suppress) {
988988
^bb0(%arg0: !transform.any_op):
989989
%muli = transform.structured.match ops{["arith.muli"]} in %arg0 : (!transform.any_op) -> !transform.any_op
990990
// Edge case propagating empty handles in splitting.
991-
%0:3 = split_handles %muli in [3] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
991+
%0:3 = split_handle %muli : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
992992
// Test does not crash when accessing the empty handle.
993993
yield %0#0 : !transform.any_op
994994
}

0 commit comments

Comments
 (0)