From 7899a205d25d86066f69eceaacf5df9ccfe22ed4 Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Fri, 13 Jun 2025 00:13:42 -0700 Subject: [PATCH 1/7] [MLIR][Transform] apply_registered_pass: support ListOptions as params Interpret the multiple values associated to a param as a comma-separated list, i.e. as the analog of a ListOption on a pass. --- .../lib/Dialect/Transform/IR/TransformOps.cpp | 28 +++++----- .../Transform/test-pass-application.mlir | 52 ++++++++++++------- 2 files changed, 45 insertions(+), 35 deletions(-) diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 582d082153bef..bfe6416987629 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -791,6 +791,12 @@ transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter, std::string options; llvm::raw_string_ostream optionsStream(options); // For "printing" attrs. + auto appendValueAttr = [&](Attribute valueAttr) { + if (auto strAttr = dyn_cast(valueAttr)) + optionsStream << strAttr.getValue().str(); + else + valueAttr.print(optionsStream, /*elideType=*/true); + }; OperandRange dynamicOptions = getDynamicOptions(); for (auto [idx, namedAttribute] : llvm::enumerate(getOptions())) { @@ -799,7 +805,6 @@ transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter, optionsStream << namedAttribute.getName().str(); // Append the key. optionsStream << "="; // And the key-value separator. - Attribute valueAttrToAppend; if (auto paramOperandIndex = dyn_cast(namedAttribute.getValue())) { // The corresponding value attribute is passed in via a param. @@ -810,22 +815,15 @@ transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter, "should be the same as the number of options passed as params"); ArrayRef dynamicOption = state.getParams(dynamicOptions[dynamicOptionIdx]); - if (dynamicOption.size() != 1) - return emitSilenceableError() - << "options passed as a param must have " - "a single value associated, param " - << dynamicOptionIdx << " associates " << dynamicOption.size(); - valueAttrToAppend = dynamicOption[0]; + // Append all attributes associated to the param, separated by commas. + for (auto [idx, associatedAttr] : llvm::enumerate(dynamicOption)) { + if (idx > 0) + optionsStream << ","; + appendValueAttr(associatedAttr); + } } else { // Value is a static attribute. - valueAttrToAppend = namedAttribute.getValue(); - } - - // Append string representation of value attribute. - if (auto strAttr = dyn_cast(valueAttrToAppend)) { - optionsStream << strAttr.getValue().str(); - } else { - valueAttrToAppend.print(optionsStream, /*elideType=*/true); + appendValueAttr(namedAttribute.getValue()); } } optionsStream.flush(); diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir index 1d1be9eda3496..407dfa3823436 100644 --- a/mlir/test/Dialect/Transform/test-pass-application.mlir +++ b/mlir/test/Dialect/Transform/test-pass-application.mlir @@ -164,6 +164,38 @@ module attributes {transform.with_named_sequence} { // ----- +// CHECK-LABEL: func private @valid_dynamic_pass_list_option() +module { + func.func @valid_dynamic_pass_list_option() { + return + } + + // CHECK: func @a() + func.func @a() { + return + } + // CHECK: func @b() + func.func @b() { + return + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op) { + %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %2 = transform.get_parent_op %1 { deduplicate } : (!transform.any_op) -> !transform.any_op + %symbol_a = transform.param.constant "a" -> !transform.any_param + %symbol_b = transform.param.constant "b" -> !transform.any_param + %multiple_symbol_names = transform.merge_handles %symbol_a, %symbol_b : !transform.any_param + transform.apply_registered_pass "symbol-privatize" + with options = { exclude = %multiple_symbol_names } to %2 + : (!transform.any_op, !transform.any_param) -> !transform.any_op + transform.yield + } +} + +// ----- + func.func @invalid_options_as_str() { return } @@ -262,26 +294,6 @@ module attributes {transform.with_named_sequence} { // ----- -func.func @too_many_pass_option_params() { - return -} - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op) { - %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %x = transform.param.constant true -> !transform.any_param - %y = transform.param.constant false -> !transform.any_param - %topdown_options = transform.merge_handles %x, %y : !transform.any_param - // expected-error @below {{options passed as a param must have a single value associated, param 0 associates 2}} - transform.apply_registered_pass "canonicalize" - with options = { "top-down" = %topdown_options } to %1 - : (!transform.any_op, !transform.any_param) -> !transform.any_op - transform.yield - } -} - -// ----- - module attributes {transform.with_named_sequence} { // expected-error @below {{trying to schedule a pass on an unsupported operation}} // expected-note @below {{target op}} From f23b042aef030303e0e92f177f409cdec14d2aa0 Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Fri, 13 Jun 2025 00:45:10 -0700 Subject: [PATCH 2/7] Also support passing in an ArrayAttr as a ListOption, also through params --- .../lib/Dialect/Transform/IR/TransformOps.cpp | 12 +++- .../Transform/test-pass-application.mlir | 63 ++++++++++++++++++- 2 files changed, 70 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index bfe6416987629..0538faf5b3ba8 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -791,11 +791,17 @@ transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter, std::string options; llvm::raw_string_ostream optionsStream(options); // For "printing" attrs. - auto appendValueAttr = [&](Attribute valueAttr) { - if (auto strAttr = dyn_cast(valueAttr)) + std::function appendValueAttr = [&](Attribute valueAttr) { + if (auto arrayAttr = dyn_cast(valueAttr)) { + for (auto [idx, eltAttr] : llvm::enumerate(arrayAttr)) { + appendValueAttr(eltAttr); + optionsStream << ","; + } + } else if (auto strAttr = dyn_cast(valueAttr)) { optionsStream << strAttr.getValue().str(); - else + } else { valueAttr.print(optionsStream, /*elideType=*/true); + } }; OperandRange dynamicOptions = getDynamicOptions(); diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir index 407dfa3823436..f7909f4c035d9 100644 --- a/mlir/test/Dialect/Transform/test-pass-application.mlir +++ b/mlir/test/Dialect/Transform/test-pass-application.mlir @@ -164,9 +164,9 @@ module attributes {transform.with_named_sequence} { // ----- -// CHECK-LABEL: func private @valid_dynamic_pass_list_option() +// CHECK-LABEL: func private @valid_multiple_params_as_list_option() module { - func.func @valid_dynamic_pass_list_option() { + func.func @valid_multiple_params_as_list_option() { return } @@ -196,6 +196,65 @@ module attributes {transform.with_named_sequence} { // ----- +// CHECK-LABEL: func private @valid_array_attr_as_list_option() +module { + func.func @valid_array_attr_param_as_list_option() { + return + } + + // CHECK: func @a() + func.func @a() { + return + } + // CHECK: func @b() + func.func @b() { + return + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op) { + %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %2 = transform.get_parent_op %1 { deduplicate } : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "symbol-privatize" + with options = { exclude = ["a", "b"] } to %2 + : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func private @valid_array_attr_param_as_list_option() +module { + func.func @valid_array_attr_param_as_list_option() { + return + } + + // CHECK: func @a() + func.func @a() { + return + } + // CHECK: func @b() + func.func @b() { + return + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op) { + %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %2 = transform.get_parent_op %1 { deduplicate } : (!transform.any_op) -> !transform.any_op + %multiple_symbol_names = transform.param.constant ["a","b"] -> !transform.any_param + transform.apply_registered_pass "symbol-privatize" + with options = { exclude = %multiple_symbol_names } to %2 + : (!transform.any_op, !transform.any_param) -> !transform.any_op + transform.yield + } +} + +// ----- + func.func @invalid_options_as_str() { return } From c2a8659af0cf99fdf4d846e5b3c2d40c42eb9988 Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Fri, 13 Jun 2025 01:02:21 -0700 Subject: [PATCH 3/7] Minor clean-up --- mlir/lib/Dialect/Transform/IR/TransformOps.cpp | 18 +++++------------- .../Transform/test-pass-application.mlir | 2 +- 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 0538faf5b3ba8..651462ee6ad03 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -792,16 +792,12 @@ transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter, std::string options; llvm::raw_string_ostream optionsStream(options); // For "printing" attrs. std::function appendValueAttr = [&](Attribute valueAttr) { - if (auto arrayAttr = dyn_cast(valueAttr)) { - for (auto [idx, eltAttr] : llvm::enumerate(arrayAttr)) { - appendValueAttr(eltAttr); - optionsStream << ","; - } - } else if (auto strAttr = dyn_cast(valueAttr)) { + if (auto arrayAttr = dyn_cast(valueAttr)) + llvm::interleave(arrayAttr, optionsStream, appendValueAttr, ","); + else if (auto strAttr = dyn_cast(valueAttr)) optionsStream << strAttr.getValue().str(); - } else { + else valueAttr.print(optionsStream, /*elideType=*/true); - } }; OperandRange dynamicOptions = getDynamicOptions(); @@ -822,11 +818,7 @@ transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter, ArrayRef dynamicOption = state.getParams(dynamicOptions[dynamicOptionIdx]); // Append all attributes associated to the param, separated by commas. - for (auto [idx, associatedAttr] : llvm::enumerate(dynamicOption)) { - if (idx > 0) - optionsStream << ","; - appendValueAttr(associatedAttr); - } + llvm::interleave(dynamicOption, optionsStream, appendValueAttr, ","); } else { // Value is a static attribute. appendValueAttr(namedAttribute.getValue()); diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir index f7909f4c035d9..7262a8fe9faee 100644 --- a/mlir/test/Dialect/Transform/test-pass-application.mlir +++ b/mlir/test/Dialect/Transform/test-pass-application.mlir @@ -198,7 +198,7 @@ module attributes {transform.with_named_sequence} { // CHECK-LABEL: func private @valid_array_attr_as_list_option() module { - func.func @valid_array_attr_param_as_list_option() { + func.func @valid_array_attr_as_list_option() { return } From fb370bd473ddea31abb09ea59ae08c1a0a6782b9 Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Fri, 13 Jun 2025 07:17:32 -0700 Subject: [PATCH 4/7] Allow ArrayAttrs options to contain param-operands --- .../mlir/Dialect/Transform/IR/TransformOps.td | 2 +- .../lib/Dialect/Transform/IR/TransformOps.cpp | 164 ++++++++++++------ .../Transform/test-pass-application.mlir | 39 ++++- 3 files changed, 148 insertions(+), 57 deletions(-) diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td index 0aa750e625436..140c9c66f3918 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -418,7 +418,7 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass", with options = { "top-down" = false, "max-iterations" = %max_iter, "test-convergence" = true, - "max-num-rewrites" = %max_rewrites } + "max-num-rewrites" = %max_rewrites } to %module : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op ``` diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 651462ee6ad03..bb9bdd70625e4 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -788,42 +788,47 @@ transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter, // Obtain a single options-string to pass to the pass(-pipeline) from options // passed in as a dictionary of keys mapping to values which are either // attributes or param-operands pointing to attributes. + OperandRange dynamicOptions = getDynamicOptions(); std::string options; llvm::raw_string_ostream optionsStream(options); // For "printing" attrs. - std::function appendValueAttr = [&](Attribute valueAttr) { - if (auto arrayAttr = dyn_cast(valueAttr)) - llvm::interleave(arrayAttr, optionsStream, appendValueAttr, ","); - else if (auto strAttr = dyn_cast(valueAttr)) - optionsStream << strAttr.getValue().str(); - else - valueAttr.print(optionsStream, /*elideType=*/true); - }; - OperandRange dynamicOptions = getDynamicOptions(); - for (auto [idx, namedAttribute] : llvm::enumerate(getOptions())) { - if (idx > 0) - optionsStream << " "; // Interleave options separator. - optionsStream << namedAttribute.getName().str(); // Append the key. - optionsStream << "="; // And the key-value separator. - - if (auto paramOperandIndex = - dyn_cast(namedAttribute.getValue())) { - // The corresponding value attribute is passed in via a param. + // A helper to convert an option's attribute value into a corresponding + // string representation, with the ability to obtain the attr(s) from a param. + std::function appendValueAttr = [&](Attribute valueAttr) { + if (auto paramOperand = dyn_cast(valueAttr)) { + // The corresponding value attribute(s) is/are passed in via a param. // Obtain the param-operand via its specified index. - size_t dynamicOptionIdx = paramOperandIndex.getIndex().getInt(); + size_t dynamicOptionIdx = paramOperand.getIndex().getInt(); assert(dynamicOptionIdx < dynamicOptions.size() && - "number of dynamic option markers (UnitAttr) in options ArrayAttr " + "the number of ParamOperandAttrs in the options DictionaryAttr" "should be the same as the number of options passed as params"); - ArrayRef dynamicOption = + ArrayRef attrsAssociatedToParam = state.getParams(dynamicOptions[dynamicOptionIdx]); - // Append all attributes associated to the param, separated by commas. - llvm::interleave(dynamicOption, optionsStream, appendValueAttr, ","); + // Recursive so as to append all attrs associated to the param. + llvm::interleave(attrsAssociatedToParam, optionsStream, appendValueAttr, + ","); + } else if (auto arrayAttr = dyn_cast(valueAttr)) { + // Recursive so as to append all nested attrs of the array. + llvm::interleave(arrayAttr, optionsStream, appendValueAttr, ","); + } else if (auto strAttr = dyn_cast(valueAttr)) { + // Convert to unquoted string. + optionsStream << strAttr.getValue().str(); } else { - // Value is a static attribute. - appendValueAttr(namedAttribute.getValue()); + // For all other attributes, ask the attr to print itself (without type). + valueAttr.print(optionsStream, /*elideType=*/true); } - } + }; + + // Convert the options DictionaryAttr into a single string. + llvm::interleave( + getOptions(), optionsStream, + [&](auto namedAttribute) { + optionsStream << namedAttribute.getName().str(); // Append the key. + optionsStream << "="; // And the key-value separator. + appendValueAttr(namedAttribute.getValue()); // And the attr's str repr. + }, + " "); optionsStream.flush(); // Get pass or pass pipeline from registry. @@ -874,23 +879,30 @@ static ParseResult parseApplyRegisteredPassOptions( SmallVectorImpl &dynamicOptions) { // Construct the options DictionaryAttr per a `{ key = value, ... }` syntax. SmallVector keyValuePairs; - size_t dynamicOptionsIdx = 0; - auto parseKeyValuePair = [&]() -> ParseResult { - // Parse items of the form `key = value` where `key` is a bare identifier or - // a string and `value` is either an attribute or an operand. - std::string key; - Attribute valueAttr; - if (parser.parseOptionalKeywordOrString(&key)) - return parser.emitError(parser.getCurrentLocation()) - << "expected key to either be an identifier or a string"; - if (key.empty()) - return failure(); + // Helper for allowing parsing of option values which can be of the form: + // - a normal attribute + // - an operand (which would be converted to an attr referring to the operand) + // - ArrayAttrs containing the foregoing (in correspondence with ListOptions) + std::function parseValue = + [&](Attribute &valueAttr) -> ParseResult { + // Allow for array syntax, e.g. `[0 : i64, %param, true, %other_param]`: + if (succeeded(parser.parseOptionalLSquare())) { + SmallVector attrs; - if (parser.parseEqual()) - return parser.emitError(parser.getCurrentLocation()) - << "expected '=' after key in key-value pair"; + // Recursively parse the array's elements, which might be operands. + if (parser.parseCommaSeparatedList( + AsmParser::Delimiter::None, + [&]() -> ParseResult { return parseValue(attrs.emplace_back()); }, + " in options dictionary") || + parser.parseRSquare()) + return failure(); // NB: Attempted parse should've output error message. + + valueAttr = ArrayAttr::get(parser.getContext(), attrs); + + return success(); + } // Parse the value, which can be either an attribute or an operand. OptionalParseResult parsedValueAttr = @@ -899,9 +911,7 @@ static ParseResult parseApplyRegisteredPassOptions( OpAsmParser::UnresolvedOperand operand; ParseResult parsedOperand = parser.parseOperand(operand); if (failed(parsedOperand)) - return parser.emitError(parser.getCurrentLocation()) - << "expected a valid attribute or operand as value associated " - << "to key '" << key << "'"; + return failure(); // NB: Attempted parse should've output error message. // To make use of the operand, we need to store it in the options dict. // As SSA-values cannot occur in attributes, what we do instead is store // an attribute in its place that contains the index of the param-operand, @@ -920,7 +930,30 @@ static ParseResult parseApplyRegisteredPassOptions( << "in the generic print format"; } + return success(); + }; + + // Helper for `key = value`-pair parsing where `key` is a bare identifier or a + // string and `value` looks like either an attribute or an operand-in-an-attr. + std::function parseKeyValuePair = [&]() -> ParseResult { + std::string key; + Attribute valueAttr; + + if (failed(parser.parseOptionalKeywordOrString(&key)) || key.empty()) + return parser.emitError(parser.getCurrentLocation()) + << "expected key to either be an identifier or a string"; + + if (failed(parser.parseEqual())) + return parser.emitError(parser.getCurrentLocation()) + << "expected '=' after key in key-value pair"; + + if (failed(parseValue(valueAttr))) + return parser.emitError(parser.getCurrentLocation()) + << "expected a valid attribute or operand as value associated " + << "to key '" << key << "'"; + keyValuePairs.push_back(NamedAttribute(key, valueAttr)); + return success(); }; @@ -947,16 +980,27 @@ static void printApplyRegisteredPassOptions(OpAsmPrinter &printer, if (options.empty()) return; - printer << "{"; - llvm::interleaveComma(options, printer, [&](NamedAttribute namedAttribute) { - printer << namedAttribute.getName() << " = "; - Attribute value = namedAttribute.getValue(); - if (auto indexAttr = dyn_cast(value)) { + std::function printOptionValue = [&](Attribute valueAttr) { + if (auto paramOperandAttr = + dyn_cast(valueAttr)) { // Resolve index of param-operand to its actual SSA-value and print that. - printer.printOperand(dynamicOptions[indexAttr.getIndex().getInt()]); + printer.printOperand( + dynamicOptions[paramOperandAttr.getIndex().getInt()]); + } else if (auto arrayAttr = dyn_cast(valueAttr)) { + // This case is so that ArrayAttr-contained operands are pretty-printed. + printer << "["; + llvm::interleaveComma(arrayAttr, printer, printOptionValue); + printer << "]"; } else { - printer.printAttribute(value); + printer.printAttribute(valueAttr); } + }; + + printer << "{"; + llvm::interleaveComma(options, printer, [&](NamedAttribute namedAttribute) { + printer << namedAttribute.getName(); + printer << " = "; + printOptionValue(namedAttribute.getValue()); }); printer << "}"; } @@ -966,9 +1010,11 @@ LogicalResult transform::ApplyRegisteredPassOp::verify() { // and references to dynamic options in the options dictionary. auto dynamicOptions = SmallVector(getDynamicOptions()); - for (NamedAttribute namedAttr : getOptions()) - if (auto paramOperand = - dyn_cast(namedAttr.getValue())) { + + // Helper for option values to mark seen operands as having been seen (once). + std::function checkOptionValue = + [&](Attribute valueAttr) -> LogicalResult { + if (auto paramOperand = dyn_cast(valueAttr)) { size_t dynamicOptionIdx = paramOperand.getIndex().getInt(); if (dynamicOptionIdx < 0 || dynamicOptionIdx >= dynamicOptions.size()) return emitOpError() @@ -979,8 +1025,20 @@ LogicalResult transform::ApplyRegisteredPassOp::verify() { return emitOpError() << "dynamic option index " << dynamicOptionIdx << " is already used in options"; dynamicOptions[dynamicOptionIdx] = nullptr; // Mark this option as used. + } else if (auto arrayAttr = dyn_cast(valueAttr)) { + // Recurse into ArrayAttrs as they may contain references to operands. + for (auto eltAttr : arrayAttr) + if (failed(checkOptionValue(eltAttr))) + return failure(); } + return success(); + }; + + for (NamedAttribute namedAttr : getOptions()) + if (failed(checkOptionValue(namedAttr.getValue()))) + return failure(); + // All dynamicOptions-params seen in the dict will have been set to null. for (Value dynamicOption : dynamicOptions) if (dynamicOption) return emitOpError() << "a param operand does not have a corresponding " diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir index 7262a8fe9faee..e21e750011ce7 100644 --- a/mlir/test/Dialect/Transform/test-pass-application.mlir +++ b/mlir/test/Dialect/Transform/test-pass-application.mlir @@ -164,9 +164,9 @@ module attributes {transform.with_named_sequence} { // ----- -// CHECK-LABEL: func private @valid_multiple_params_as_list_option() +// CHECK-LABEL: func private @valid_multiple_values_as_list_option_single_param() module { - func.func @valid_multiple_params_as_list_option() { + func.func @valid_multiple_values_as_list_option_single_param() { return } @@ -253,6 +253,38 @@ module attributes {transform.with_named_sequence} { } } +// ----- + +// CHECK-LABEL: func private @valid_multiple_params_as_single_list_option() +module { + func.func @valid_multiple_params_as_single_list_option() { + return + } + + // CHECK: func @a() + func.func @a() { + return + } + // CHECK: func @b() + func.func @b() { + return + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op) { + %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %2 = transform.get_parent_op %1 { deduplicate } : (!transform.any_op) -> !transform.any_op + %symbol_a = transform.param.constant "a" -> !transform.any_param + %symbol_b = transform.param.constant "b" -> !transform.any_param + transform.apply_registered_pass "symbol-privatize" + with options = { exclude = [%symbol_a, %symbol_b] } to %2 + : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op + transform.yield + } +} + + // ----- func.func @invalid_options_as_str() { @@ -294,7 +326,8 @@ func.func @invalid_options_due_to_reserved_attr() { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op) { %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - // expected-error @+2 {{the param_operand attribute is a marker reserved for indicating a value will be passed via params and is only used in the generic print format}} + // expected-error @+3 {{the param_operand attribute is a marker reserved for indicating a value will be passed via params and is only used in the generic print format}} + // expected-error @+2 {{expected a valid attribute or operand as value associated to key 'top-down'}} %2 = transform.apply_registered_pass "canonicalize" with options = { "top-down" = #transform.param_operand } to %1 : (!transform.any_op) -> !transform.any_op transform.yield From 6620ca4fbfa6e1c60f20719e93bf2afcfffab684 Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Fri, 13 Jun 2025 07:59:32 -0700 Subject: [PATCH 5/7] Update Python-bindings --- .../mlir/dialects/transform/__init__.py | 41 ++++++------- .../Transform/test-pass-application.mlir | 1 - mlir/test/python/dialects/transform.py | 60 ++++++++++++++----- 3 files changed, 65 insertions(+), 37 deletions(-) diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py index bfe96b1b3e5d4..b075919d1ef0f 100644 --- a/mlir/python/mlir/dialects/transform/__init__.py +++ b/mlir/python/mlir/dialects/transform/__init__.py @@ -219,6 +219,11 @@ def __init__( super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip) +OptionValueTypes = Union[ + Sequence["OptionValueTypes"], Attribute, Value, Operation, OpView, str, int, bool +] + + @_ods_cext.register_operation(_Dialect, replace=True) class ApplyRegisteredPassOp(ApplyRegisteredPassOp): def __init__( @@ -227,12 +232,7 @@ def __init__( target: Union[Operation, Value, OpView], pass_name: Union[str, StringAttr], *, - options: Optional[ - Dict[ - Union[str, StringAttr], - Union[Attribute, Value, Operation, OpView, str, int, bool], - ] - ] = None, + options: Optional[Dict[Union[str, StringAttr], OptionValueTypes]] = None, loc=None, ip=None, ): @@ -243,26 +243,32 @@ def __init__( context = (loc and loc.context) or Context.current cur_param_operand_idx = 0 - for key, value in options.items() if options is not None else {}: - if isinstance(key, StringAttr): - key = key.value + def option_value_to_attr(value): + nonlocal cur_param_operand_idx if isinstance(value, (Value, Operation, OpView)): dynamic_options.append(_get_op_result_or_value(value)) - options_dict[key] = ParamOperandAttr(cur_param_operand_idx, context) cur_param_operand_idx += 1 + return ParamOperandAttr(cur_param_operand_idx - 1, context) elif isinstance(value, Attribute): - options_dict[key] = value + return value # The following cases auto-convert Python values to attributes. elif isinstance(value, bool): - options_dict[key] = BoolAttr.get(value) + return BoolAttr.get(value) elif isinstance(value, int): default_int_type = IntegerType.get_signless(64, context) - options_dict[key] = IntegerAttr.get(default_int_type, value) + return IntegerAttr.get(default_int_type, value) elif isinstance(value, str): - options_dict[key] = StringAttr.get(value) + return StringAttr.get(value) + elif isinstance(value, Sequence): + return ArrayAttr.get([option_value_to_attr(elt) for elt in value]) else: raise TypeError(f"Unsupported option type: {type(value)}") + + for key, value in options.items() if options is not None else {}: + if isinstance(key, StringAttr): + key = key.value + options_dict[key] = option_value_to_attr(value) super().__init__( result, _get_op_result_or_value(target), @@ -279,12 +285,7 @@ def apply_registered_pass( target: Union[Operation, Value, OpView], pass_name: Union[str, StringAttr], *, - options: Optional[ - Dict[ - Union[str, StringAttr], - Union[Attribute, Value, Operation, OpView, str, int, bool], - ] - ] = None, + options: Optional[Dict[Union[str, StringAttr], OptionValueTypes]] = None, loc=None, ip=None, ) -> Value: diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir index e21e750011ce7..ce8f69c58701d 100644 --- a/mlir/test/Dialect/Transform/test-pass-application.mlir +++ b/mlir/test/Dialect/Transform/test-pass-application.mlir @@ -284,7 +284,6 @@ module attributes {transform.with_named_sequence} { } } - // ----- func.func @invalid_options_as_str() { diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py index eeb95605d7a9a..aeadfcb596526 100644 --- a/mlir/test/python/dialects/transform.py +++ b/mlir/test/python/dialects/transform.py @@ -256,30 +256,45 @@ def testReplicateOp(module: Module): # CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]] +# CHECK-LABEL: TEST: testApplyRegisteredPassOp @run def testApplyRegisteredPassOp(module: Module): + # CHECK: transform.sequence sequence = transform.SequenceOp( transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() ) with InsertionPoint(sequence.body): + # CHECK: %{{.*}} = apply_registered_pass "canonicalize" to {{.*}} : (!transform.any_op) -> !transform.any_op mod = transform.ApplyRegisteredPassOp( transform.AnyOpType.get(), sequence.bodyTarget, "canonicalize" ) + # CHECK: %{{.*}} = apply_registered_pass "canonicalize" + # CHECK-SAME: with options = {"top-down" = false} + # CHECK-SAME: to {{.*}} : (!transform.any_op) -> !transform.any_op mod = transform.ApplyRegisteredPassOp( transform.AnyOpType.get(), mod.result, "canonicalize", options={"top-down": BoolAttr.get(False)}, ) + # CHECK: %[[MAX_ITER:.+]] = transform.param.constant max_iter = transform.param_constant( transform.AnyParamType.get(), IntegerAttr.get(IntegerType.get_signless(64), 10), ) + # CHECK: %[[MAX_REWRITE:.+]] = transform.param.constant max_rewrites = transform.param_constant( transform.AnyParamType.get(), IntegerAttr.get(IntegerType.get_signless(64), 1), ) - transform.apply_registered_pass( + # CHECK: %{{.*}} = apply_registered_pass "canonicalize" + # NB: MLIR has sorted the dict lexicographically by key: + # CHECK-SAME: with options = {"max-iterations" = %[[MAX_ITER]], + # CHECK-SAME: "max-rewrites" = %[[MAX_REWRITE]], + # CHECK-SAME: "test-convergence" = true, + # CHECK-SAME: "top-down" = false} + # CHECK-SAME: to %{{.*}} : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op + mod = transform.apply_registered_pass( transform.AnyOpType.get(), mod, "canonicalize", @@ -290,19 +305,32 @@ def testApplyRegisteredPassOp(module: Module): "max-rewrites": max_rewrites, }, ) + # CHECK: %{{.*}} = apply_registered_pass "symbol-privatize" + # CHECK-SAME: with options = {"exclude" = ["a", "b"]} + # CHECK-SAME: to %{{.*}} : (!transform.any_op) -> !transform.any_op + mod = transform.apply_registered_pass( + transform.AnyOpType.get(), + mod, + "symbol-privatize", + options={ "exclude": ("a", "b") }, + ) + # CHECK: %[[SYMBOL_A:.+]] = transform.param.constant + symbol_a = transform.param_constant( + transform.AnyParamType.get(), + StringAttr.get("a") + ) + # CHECK: %[[SYMBOL_B:.+]] = transform.param.constant + symbol_b = transform.param_constant( + transform.AnyParamType.get(), + StringAttr.get("b") + ) + # CHECK: %{{.*}} = apply_registered_pass "symbol-privatize" + # CHECK-SAME: with options = {"exclude" = [%[[SYMBOL_A]], %[[SYMBOL_B]]]} + # CHECK-SAME: to %{{.*}} : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op + mod = transform.apply_registered_pass( + transform.AnyOpType.get(), + mod, + "symbol-privatize", + options={ "exclude": (symbol_a, symbol_b) }, + ) transform.YieldOp() - # CHECK-LABEL: TEST: testApplyRegisteredPassOp - # CHECK: transform.sequence - # CHECK: %{{.*}} = apply_registered_pass "canonicalize" to {{.*}} : (!transform.any_op) -> !transform.any_op - # CHECK: %{{.*}} = apply_registered_pass "canonicalize" - # CHECK-SAME: with options = {"top-down" = false} - # CHECK-SAME: to {{.*}} : (!transform.any_op) -> !transform.any_op - # CHECK: %[[MAX_ITER:.+]] = transform.param.constant - # CHECK: %[[MAX_REWRITE:.+]] = transform.param.constant - # CHECK: %{{.*}} = apply_registered_pass "canonicalize" - # NB: MLIR has sorted the dict lexicographically by key: - # CHECK-SAME: with options = {"max-iterations" = %[[MAX_ITER]], - # CHECK-SAME: "max-rewrites" = %[[MAX_REWRITE]], - # CHECK-SAME: "test-convergence" = true, - # CHECK-SAME: "top-down" = false} - # CHECK-SAME: to %{{.*}} : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op From 1f161bad3e717633c2505c0e7cb9c34772d9451d Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Fri, 13 Jun 2025 08:15:03 -0700 Subject: [PATCH 6/7] Fix Python formatting --- mlir/test/python/dialects/transform.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py index aeadfcb596526..6c5e4e5505b1c 100644 --- a/mlir/test/python/dialects/transform.py +++ b/mlir/test/python/dialects/transform.py @@ -312,17 +312,15 @@ def testApplyRegisteredPassOp(module: Module): transform.AnyOpType.get(), mod, "symbol-privatize", - options={ "exclude": ("a", "b") }, + options={"exclude": ("a", "b")}, ) # CHECK: %[[SYMBOL_A:.+]] = transform.param.constant symbol_a = transform.param_constant( - transform.AnyParamType.get(), - StringAttr.get("a") + transform.AnyParamType.get(), StringAttr.get("a") ) # CHECK: %[[SYMBOL_B:.+]] = transform.param.constant symbol_b = transform.param_constant( - transform.AnyParamType.get(), - StringAttr.get("b") + transform.AnyParamType.get(), StringAttr.get("b") ) # CHECK: %{{.*}} = apply_registered_pass "symbol-privatize" # CHECK-SAME: with options = {"exclude" = [%[[SYMBOL_A]], %[[SYMBOL_B]]]} @@ -331,6 +329,6 @@ def testApplyRegisteredPassOp(module: Module): transform.AnyOpType.get(), mod, "symbol-privatize", - options={ "exclude": (symbol_a, symbol_b) }, + options={"exclude": (symbol_a, symbol_b)}, ) transform.YieldOp() From 83553bcb9bd5b3e0f5230980ef2ff6296a1fa558 Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Mon, 16 Jun 2025 03:46:20 -0700 Subject: [PATCH 7/7] Update docs --- mlir/include/mlir/Dialect/Transform/IR/TransformOps.td | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td index 140c9c66f3918..62e66b3dabee8 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -423,6 +423,9 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass", : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op ``` + Options' values which are `ArrayAttr`s are converted to comma-separated + lists of options. Likewise for params which associate multiple values. + This op first looks for a pass pipeline with the specified name. If no such pipeline exists, it looks for a pass with the specified name. If no such pass exists either, this op fails definitely.