diff --git a/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt index df5af7ae710da..9acab9228f100 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt @@ -20,6 +20,10 @@ mlir_tablegen(TransformDialectEnums.h.inc -gen-enum-decls) mlir_tablegen(TransformDialectEnums.cpp.inc -gen-enum-defs) add_public_tablegen_target(MLIRTransformDialectEnumIncGen) add_dependencies(mlir-headers MLIRTransformDialectEnumIncGen) +mlir_tablegen(TransformAttrs.h.inc -gen-attrdef-decls) +mlir_tablegen(TransformAttrs.cpp.inc -gen-attrdef-defs) +add_public_tablegen_target(MLIRTransformDialectAttributesIncGen) +add_dependencies(mlir-headers MLIRTransformDialectAttributesIncGen) add_mlir_dialect(TransformOps transform) add_mlir_doc(TransformOps TransformOps Dialects/ -gen-op-doc -dialect=transform) diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.h b/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.h index 3cb935003b4c4..379af932ca484 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.h @@ -17,4 +17,7 @@ #include "mlir/Dialect/Transform/IR/TransformDialectEnums.h.inc" +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/Transform/IR/TransformAttrs.h.inc" + #endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMATTRS_H diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td b/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td index ebad2994880e7..e67a9444c24a8 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td @@ -10,6 +10,14 @@ #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMATTRS include "mlir/IR/EnumAttr.td" +include "mlir/Dialect/Transform/IR/TransformDialect.td" + +class Transform_Attr traits = [], + string baseCppClass = "::mlir::Attribute"> + : AttrDef { + let mnemonic = attrMnemonic; +} def PropagateFailuresCase : I32EnumAttrCase<"Propagate", 1, "propagate">; def SuppressFailuresCase : I32EnumAttrCase<"Suppress", 2, "suppress">; @@ -33,4 +41,15 @@ def MatchCmpIPredicateAttr : I32EnumAttr< let cppNamespace = "::mlir::transform"; } +def ParamOperandAttr : Transform_Attr<"ParamOperand", "param_operand"> { + let description = [{ + Used to refer to a specific param-operand (via its index) from within an + attribute on a transform operation. + }]; + let parameters = (ins + "IntegerAttr":$index + ); + let assemblyFormat = "`<` `index` `=` $index `>`"; +} + #endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMATTRS diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td index d03049e186f94..c7ea5ade72ace 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td @@ -19,6 +19,7 @@ def Transform_Dialect : Dialect { let cppNamespace = "::mlir::transform"; let hasOperationAttrVerify = 1; + let useDefaultAttributePrinterParser = 1; let extraClassDeclaration = [{ /// Symbol name for the default entry point "named sequence". constexpr const static ::llvm::StringLiteral diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td index e864a65f8ceac..f75ba27e58e76 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -405,10 +405,23 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass", let description = [{ This transform applies the specified pass or pass pipeline to the targeted ops. The name of the pass/pipeline is specified as a string attribute, as - set during pass/pipeline registration. Optionally, pass options may be - specified as (space-separated) string attributes with the option to pass - these attributes via params. The pass options syntax is identical to the one - used with "mlir-opt". + set during pass/pipeline registration. + + Optionally, pass options may be specified via a DictionaryAttr. This + dictionary is converted to a string -- formatted `key=value ...` -- which + is expected to be in the exact format used by the pass on the commandline. + Values are either attributes or (SSA-values of) Transform Dialect params. + For example: + + ```mlir + transform.apply_registered_pass "canonicalize" + with options = { "top-down" = false, + "max-iterations" = %max_iter, + "test-convergence" = true, + "max-num-rewrites" = %max_rewrites } + to %module + : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op + ``` This op first looks for a pass pipeline with the specified name. If no such pipeline exists, it looks for a pass with the specified name. If no such @@ -422,7 +435,7 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass", }]; let arguments = (ins StrAttr:$pass_name, - DefaultValuedAttr:$options, + DefaultValuedAttr:$options, Variadic:$dynamic_options, TransformHandleTypeInterface:$target); let results = (outs TransformHandleTypeInterface:$result); diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp index 497ceb19f1a21..4a95fe7459e8c 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp @@ -8,17 +8,22 @@ #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Analysis/CallGraph.h" +#include "mlir/Dialect/Transform/IR/TransformAttrs.h" #include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/Dialect/Transform/IR/Utils.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/IR/DialectImplementation.h" #include "llvm/ADT/SCCIterator.h" +#include "llvm/ADT/TypeSwitch.h" using namespace mlir; #include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc" +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/Transform/IR/TransformAttrs.cpp.inc" + #ifndef NDEBUG void transform::detail::checkImplementsTransformOpInterface( StringRef name, MLIRContext *context) { @@ -66,6 +71,10 @@ void transform::TransformDialect::initialize() { #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" >(); initializeTypes(); + addAttributes< +#define GET_ATTRDEF_LIST +#include "mlir/Dialect/Transform/IR/TransformAttrs.cpp.inc" + >(); initializeLibraryModule(); } diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index a0f9518e3d12f..582d082153bef 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -54,10 +54,11 @@ using namespace mlir; static ParseResult parseApplyRegisteredPassOptions( - OpAsmParser &parser, ArrayAttr &options, + OpAsmParser &parser, DictionaryAttr &options, SmallVectorImpl &dynamicOptions); static void printApplyRegisteredPassOptions(OpAsmPrinter &printer, - Operation *op, ArrayAttr options, + Operation *op, + DictionaryAttr options, ValueRange dynamicOptions); static ParseResult parseSequenceOpOperands( OpAsmParser &parser, std::optional &root, @@ -784,41 +785,50 @@ DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { - // Obtain a single options-string from options passed statically as - // string attributes as well as "dynamically" through params. + // Obtain a single options-string to pass to the pass(-pipeline) from options + // passed in as a dictionary of keys mapping to values which are either + // attributes or param-operands pointing to attributes. + std::string options; + llvm::raw_string_ostream optionsStream(options); // For "printing" attrs. + OperandRange dynamicOptions = getDynamicOptions(); - size_t dynamicOptionsIdx = 0; - for (auto [idx, optionAttr] : llvm::enumerate(getOptions())) { + for (auto [idx, namedAttribute] : llvm::enumerate(getOptions())) { if (idx > 0) - options += " "; // Interleave options seperator. - - if (auto strAttr = dyn_cast(optionAttr)) { - options += strAttr.getValue(); - } else if (isa(optionAttr)) { - assert(dynamicOptionsIdx < dynamicOptions.size() && + optionsStream << " "; // Interleave options separator. + optionsStream << namedAttribute.getName().str(); // Append the key. + optionsStream << "="; // And the key-value separator. + + Attribute valueAttrToAppend; + if (auto paramOperandIndex = + dyn_cast(namedAttribute.getValue())) { + // The corresponding value attribute is passed in via a param. + // Obtain the param-operand via its specified index. + size_t dynamicOptionIdx = paramOperandIndex.getIndex().getInt(); + assert(dynamicOptionIdx < dynamicOptions.size() && "number of dynamic option markers (UnitAttr) in options ArrayAttr " "should be the same as the number of options passed as params"); ArrayRef dynamicOption = - state.getParams(dynamicOptions[dynamicOptionsIdx++]); + state.getParams(dynamicOptions[dynamicOptionIdx]); if (dynamicOption.size() != 1) - return emitSilenceableError() << "options passed as a param must have " - "a single value associated, param " - << dynamicOptionsIdx - 1 << " associates " - << dynamicOption.size(); - - if (auto dynamicOptionStr = dyn_cast(dynamicOption[0])) { - options += dynamicOptionStr.getValue(); - } else { return emitSilenceableError() - << "options passed as a param must be a string, got " - << dynamicOption[0]; - } + << "options passed as a param must have " + "a single value associated, param " + << dynamicOptionIdx << " associates " << dynamicOption.size(); + valueAttrToAppend = dynamicOption[0]; + } else { + // Value is a static attribute. + valueAttrToAppend = namedAttribute.getValue(); + } + + // Append string representation of value attribute. + if (auto strAttr = dyn_cast(valueAttrToAppend)) { + optionsStream << strAttr.getValue().str(); } else { - llvm_unreachable( - "expected options element to be either StringAttr or UnitAttr"); + valueAttrToAppend.print(optionsStream, /*elideType=*/true); } } + optionsStream.flush(); // Get pass or pass pipeline from registry. const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName()); @@ -864,84 +874,121 @@ transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter, } static ParseResult parseApplyRegisteredPassOptions( - OpAsmParser &parser, ArrayAttr &options, + OpAsmParser &parser, DictionaryAttr &options, SmallVectorImpl &dynamicOptions) { - auto dynamicOptionMarker = UnitAttr::get(parser.getContext()); - SmallVector optionsArray; - - auto parseOperandOrString = [&]() -> OptionalParseResult { - OpAsmParser::UnresolvedOperand operand; - OptionalParseResult parsedOperand = parser.parseOptionalOperand(operand); - if (parsedOperand.has_value()) { - if (failed(parsedOperand.value())) - return failure(); - - dynamicOptions.push_back(operand); - optionsArray.push_back( - dynamicOptionMarker); // Placeholder for knowing where to - // inject the dynamic option-as-param. - return success(); - } + // Construct the options DictionaryAttr per a `{ key = value, ... }` syntax. + SmallVector keyValuePairs; - StringAttr stringAttr; - OptionalParseResult parsedStringAttr = - parser.parseOptionalAttribute(stringAttr); - if (parsedStringAttr.has_value()) { - if (failed(parsedStringAttr.value())) - return failure(); - optionsArray.push_back(stringAttr); - return success(); - } + size_t dynamicOptionsIdx = 0; + auto parseKeyValuePair = [&]() -> ParseResult { + // Parse items of the form `key = value` where `key` is a bare identifier or + // a string and `value` is either an attribute or an operand. + + std::string key; + Attribute valueAttr; + if (parser.parseOptionalKeywordOrString(&key)) + return parser.emitError(parser.getCurrentLocation()) + << "expected key to either be an identifier or a string"; + if (key.empty()) + return failure(); - return std::nullopt; + if (parser.parseEqual()) + return parser.emitError(parser.getCurrentLocation()) + << "expected '=' after key in key-value pair"; + + // Parse the value, which can be either an attribute or an operand. + OptionalParseResult parsedValueAttr = + parser.parseOptionalAttribute(valueAttr); + if (!parsedValueAttr.has_value()) { + OpAsmParser::UnresolvedOperand operand; + ParseResult parsedOperand = parser.parseOperand(operand); + if (failed(parsedOperand)) + return parser.emitError(parser.getCurrentLocation()) + << "expected a valid attribute or operand as value associated " + << "to key '" << key << "'"; + // To make use of the operand, we need to store it in the options dict. + // As SSA-values cannot occur in attributes, what we do instead is store + // an attribute in its place that contains the index of the param-operand, + // so that an attr-value associated to the param can be resolved later on. + dynamicOptions.push_back(operand); + auto wrappedIndex = IntegerAttr::get( + IntegerType::get(parser.getContext(), 64), dynamicOptionsIdx++); + valueAttr = + transform::ParamOperandAttr::get(parser.getContext(), wrappedIndex); + } else if (failed(parsedValueAttr.value())) { + return failure(); // NB: Attempted parse should have output error message. + } else if (isa(valueAttr)) { + return parser.emitError(parser.getCurrentLocation()) + << "the param_operand attribute is a marker reserved for " + << "indicating a value will be passed via params and is only used " + << "in the generic print format"; + } + + keyValuePairs.push_back(NamedAttribute(key, valueAttr)); + return success(); }; - OptionalParseResult parsedOptionsElement = parseOperandOrString(); - while (parsedOptionsElement.has_value()) { - if (failed(parsedOptionsElement.value())) - return failure(); - parsedOptionsElement = parseOperandOrString(); - } + if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Braces, + parseKeyValuePair, + " in options dictionary")) + return failure(); // NB: Attempted parse should have output error message. - if (optionsArray.empty()) { + if (DictionaryAttr::findDuplicate( + keyValuePairs, /*isSorted=*/false) // Also sorts the keyValuePairs. + .has_value()) return parser.emitError(parser.getCurrentLocation()) - << "expected at least one option (either a string or a param)"; - } - options = parser.getBuilder().getArrayAttr(optionsArray); + << "duplicate keys found in options dictionary"; + + options = DictionaryAttr::getWithSorted(parser.getContext(), keyValuePairs); + return success(); } static void printApplyRegisteredPassOptions(OpAsmPrinter &printer, - Operation *op, ArrayAttr options, + Operation *op, + DictionaryAttr options, ValueRange dynamicOptions) { - size_t currentDynamicOptionIdx = 0; - for (auto [idx, optionAttr] : llvm::enumerate(options)) { - if (idx > 0) - printer << " "; // Interleave options separator. + if (options.empty()) + return; - if (isa(optionAttr)) - printer.printOperand(dynamicOptions[currentDynamicOptionIdx++]); - else if (auto strAttr = dyn_cast(optionAttr)) - printer.printAttribute(strAttr); - else - llvm_unreachable("each option should be either a StringAttr or UnitAttr"); - } + printer << "{"; + llvm::interleaveComma(options, printer, [&](NamedAttribute namedAttribute) { + printer << namedAttribute.getName() << " = "; + Attribute value = namedAttribute.getValue(); + if (auto indexAttr = dyn_cast(value)) { + // Resolve index of param-operand to its actual SSA-value and print that. + printer.printOperand(dynamicOptions[indexAttr.getIndex().getInt()]); + } else { + printer.printAttribute(value); + } + }); + printer << "}"; } LogicalResult transform::ApplyRegisteredPassOp::verify() { - size_t numUnitsInOptions = 0; - for (Attribute optionsElement : getOptions()) { - if (isa(optionsElement)) - numUnitsInOptions++; - else if (!isa(optionsElement)) - return emitOpError() << "expected each option to be either a StringAttr " - << "or a UnitAttr, got " << optionsElement; - } - - if (getDynamicOptions().size() != numUnitsInOptions) - return emitOpError() - << "expected the same number of options passed as params as " - << "UnitAttr elements in options ArrayAttr"; + // Check that there is a one-to-one correspondence between param operands + // and references to dynamic options in the options dictionary. + + auto dynamicOptions = SmallVector(getDynamicOptions()); + for (NamedAttribute namedAttr : getOptions()) + if (auto paramOperand = + dyn_cast(namedAttr.getValue())) { + size_t dynamicOptionIdx = paramOperand.getIndex().getInt(); + if (dynamicOptionIdx < 0 || dynamicOptionIdx >= dynamicOptions.size()) + return emitOpError() + << "dynamic option index " << dynamicOptionIdx + << " is out of bounds for the number of dynamic options: " + << dynamicOptions.size(); + if (dynamicOptions[dynamicOptionIdx] == nullptr) + return emitOpError() << "dynamic option index " << dynamicOptionIdx + << " is already used in options"; + dynamicOptions[dynamicOptionIdx] = nullptr; // Mark this option as used. + } + + for (Value dynamicOption : dynamicOptions) + if (dynamicOption) + return emitOpError() << "a param operand does not have a corresponding " + << "param_operand attr in the options dict"; return success(); } diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py index 5b158ec6b65fd..10a04b0cc14e0 100644 --- a/mlir/python/mlir/dialects/transform/__init__.py +++ b/mlir/python/mlir/dialects/transform/__init__.py @@ -18,7 +18,12 @@ except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e -from typing import Optional, Sequence, Union, NewType +from typing import Dict, Optional, Sequence, Union, NewType + + +@register_attribute_builder("ParamOperandAttr") +def _paramOperandAttr(x: int, context) -> Attribute: + return Attribute.parse(f"#transform.param_operand", context=context) @_ods_cext.register_operation(_Dialect, replace=True) @@ -214,6 +219,81 @@ def __init__( super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip) +@_ods_cext.register_operation(_Dialect, replace=True) +class ApplyRegisteredPassOp(ApplyRegisteredPassOp): + def __init__( + self, + result: Type, + pass_name: Union[str, StringAttr], + target: Union[Operation, Value, OpView], + *, + options: Optional[ + Dict[ + Union[str, StringAttr], + Union[Attribute, Value, Operation, OpView], + ] + ] = None, + loc=None, + ip=None, + ): + options_dict = {} + dynamic_options = [] + + ParamOperandAttr = AttrBuilder.get("ParamOperandAttr") + context = (loc and loc.context) or Context.current + + cur_param_operand_idx = 0 + for key, value in options.items() if options is not None else {}: + if isinstance(key, StringAttr): + key = key.value + + if isinstance(value, (Value, Operation, OpView)): + dynamic_options.append(_get_op_result_or_value(value)) + options_dict[key] = ParamOperandAttr(cur_param_operand_idx, context) + cur_param_operand_idx += 1 + elif isinstance(value, Attribute): + options_dict[key] = value + elif isinstance(value, str): + options_dict[key] = StringAttr.get(value) + else: + raise TypeError(f"Unsupported option type: {type(value)}") + if len(options_dict) > 0: + print(options_dict, cur_param_operand_idx) + super().__init__( + result, + pass_name, + dynamic_options, + target=_get_op_result_or_value(target), + options=DictAttr.get(options_dict), + loc=loc, + ip=ip, + ) + + +def apply_registered_pass( + result: Type, + pass_name: Union[str, StringAttr], + target: Union[Operation, Value, OpView], + *, + options: Optional[ + Dict[ + Union[str, StringAttr], + Union[Attribute, Value, Operation, OpView], + ] + ] = None, + loc=None, + ip=None, +) -> Value: + return ApplyRegisteredPassOp( + result=result, + pass_name=pass_name, + target=target, + options=options, + loc=loc, + ip=ip, + ).result + + AnyOpTypeT = NewType("AnyOpType", AnyOpType) diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir index 463fd98afa65c..6e6d4eb7e249f 100644 --- a/mlir/test/Dialect/Transform/test-pass-application.mlir +++ b/mlir/test/Dialect/Transform/test-pass-application.mlir @@ -80,7 +80,7 @@ module attributes {transform.with_named_sequence} { // expected-error @below {{failed to add pass or pass pipeline to pipeline: canonicalize}} // expected-error @below {{: no such option invalid-option}} transform.apply_registered_pass "canonicalize" - with options = "invalid-option=1" to %1 + with options = { "invalid-option" = 1 } to %1 : (!transform.any_op) -> !transform.any_op transform.yield } @@ -97,7 +97,7 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op) { %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op transform.apply_registered_pass "canonicalize" - with options = "top-down=false" to %1 + with options = { "top-down" = false } to %1 : (!transform.any_op) -> !transform.any_op transform.yield } @@ -115,7 +115,7 @@ module attributes {transform.with_named_sequence} { %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op //transform.apply_registered_pass "canonicalize" with options = "top-down=false,max-iterations=10" to %1 : (!transform.any_op) -> !transform.any_op transform.apply_registered_pass "canonicalize" - with options = "top-down=false test-convergence=true" to %1 + with options = { "top-down" = false, "test-convergence" =true } to %1 : (!transform.any_op) -> !transform.any_op transform.yield } @@ -132,7 +132,7 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op) { %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op transform.apply_registered_pass "canonicalize" - with options = "top-down=false" "max-iterations=0" to %1 + with options = { "top-down" = false, "max-iterations" = 0 } to %1 : (!transform.any_op) -> !transform.any_op transform.yield } @@ -148,10 +148,15 @@ func.func @valid_dynamic_pass_options() { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op) { %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %max_iter = transform.param.constant "max-iterations=10" -> !transform.any_param - %max_rewrites = transform.param.constant "max-num-rewrites=1" -> !transform.any_param - %2 = transform.apply_registered_pass "canonicalize" - with options = "top-down=false" %max_iter "test-convergence=true" %max_rewrites to %1 + %max_iter = transform.param.constant 10 -> !transform.any_param + %max_rewrites = transform.param.constant 1 -> !transform.any_param + %2 = transform.apply_registered_pass + "canonicalize" + with options = { "top-down" = false, + "max-iterations" = %max_iter, + "test-convergence" = true, + "max-num-rewrites" = %max_rewrites } + to %1 : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op transform.yield } @@ -159,7 +164,7 @@ module attributes {transform.with_named_sequence} { // ----- -func.func @invalid_dynamic_options_as_array() { +func.func @invalid_options_as_str() { return } @@ -167,34 +172,80 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op) { %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %max_iter = transform.param.constant "max-iterations=10" -> !transform.any_param - // expected-error @+2 {{expected at least one option (either a string or a param)}} + // expected-error @+2 {{expected '{' in options dictionary}} %2 = transform.apply_registered_pass "canonicalize" - with options = ["top-down=false" %max_iter] to %1 - : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op + with options = "top-down=false" to %1 : (!transform.any_op) -> !transform.any_op transform.yield } } // ----- -func.func @invalid_options_as_pairs() { +func.func @invalid_options_as_pairs_without_braces() { return } module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op) { %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - // expected-error @+2 {{expected 'to'}} + // expected-error @+2 {{expected '{' in options dictionary}} %2 = transform.apply_registered_pass "canonicalize" - with options = "top-down=" false to %1 - : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op + with options = "top-down"=false to %1 : (!transform.any_op) -> !transform.any_op transform.yield } } // ----- -func.func @invalid_pass_option_param() { +func.func @invalid_options_due_to_reserved_attr() { + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op) { + %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // expected-error @+2 {{the param_operand attribute is a marker reserved for indicating a value will be passed via params and is only used in the generic print format}} + %2 = transform.apply_registered_pass "canonicalize" + with options = { "top-down" = #transform.param_operand } to %1 : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +func.func @invalid_options_due_duplicated_key() { + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op) { + %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // expected-error @+2 {{duplicate keys found in options dictionary}} + %2 = transform.apply_registered_pass "canonicalize" + with options = {"top-down"=false,"top-down"=true} to %1 : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +func.func @invalid_options_due_invalid_key() { + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op) { + %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // expected-error @+2 {{expected key to either be an identifier or a string}} + %2 = transform.apply_registered_pass "canonicalize" + with options = { @label = 0 } to %1 : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +func.func @invalid_pass_option_bare_param() { return } @@ -202,7 +253,7 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op) { %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op %pass_options = transform.param.constant 42 -> !transform.any_param - // expected-error @below {{options passed as a param must be a string, got 42}} + // expected-error @+2 {{expected '{' in options dictionary}} transform.apply_registered_pass "canonicalize" with options = %pass_options to %1 : (!transform.any_param, !transform.any_op) -> !transform.any_op @@ -219,12 +270,12 @@ func.func @too_many_pass_option_params() { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op) { %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %x = transform.param.constant "x" -> !transform.any_param - %y = transform.param.constant "y" -> !transform.any_param - %pass_options = transform.merge_handles %x, %y : !transform.any_param + %x = transform.param.constant true -> !transform.any_param + %y = transform.param.constant false -> !transform.any_param + %topdown_options = transform.merge_handles %x, %y : !transform.any_param // expected-error @below {{options passed as a param must have a single value associated, param 0 associates 2}} transform.apply_registered_pass "canonicalize" - with options = %pass_options to %1 + with options = { "top-down" = %topdown_options } to %1 : (!transform.any_param, !transform.any_op) -> !transform.any_op transform.yield } @@ -248,3 +299,77 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +///////////////////////////////////////////////////////////////////// +// Check that the following cases are caugh in the generic format. // +///////////////////////////////////////////////////////////////////// + +// Invalid due to param_operand occurences in options dict not being +// one-to-one with the dynamic options provided as params: +// param_operand_index out of bounds w.r.t. the number of options provided via params. + +"builtin.module"() ({ + "transform.named_sequence"() <{function_type = (!transform.any_op) -> (), sym_name = "__transform_main"}> ({ + ^bb0(%arg0: !transform.any_op): + %0 = "transform.structured.match"(%arg0) <{ops = ["func.func"]}> : (!transform.any_op) -> !transform.any_op + %1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param + // expected-error @below {{dynamic option index 1 is out of bounds for the number of dynamic options: 1}} + %2 = "transform.apply_registered_pass"(%1, %0) <{ + options = {"max-iterations" = #transform.param_operand, + "test-convergence" = true, + "top-down" = false}, + pass_name = "canonicalize"}> + : (!transform.any_param, !transform.any_op) -> !transform.any_op + "transform.yield"() : () -> () + }) : () -> () +}) {transform.with_named_sequence} : () -> () + +// ----- + +// Invalid due to param_operand occurences in options dict not being +// one-to-one with the dynamic options provided as params: +// the first option-param is referred to twice and the second one not at all. +// (In the pretty-printed format, if you want to refer to a param SSA-value twice, it counts as two param arguments.) + +"builtin.module"() ({ + "transform.named_sequence"() <{function_type = (!transform.any_op) -> (), sym_name = "__transform_main"}> ({ + ^bb0(%arg0: !transform.any_op): + %0 = "transform.structured.match"(%arg0) <{ops = ["func.func"]}> : (!transform.any_op) -> !transform.any_op + %1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param + %2 = "transform.param.constant"() <{value = 1 : i64}> : () -> !transform.any_param + // expected-error @below {{dynamic option index 0 is already used in options}} + %3 = "transform.apply_registered_pass"(%1, %2, %0) <{ + options = {"max-iterations" = #transform.param_operand, + "max-num-rewrites" = #transform.param_operand, + "test-convergence" = true, + "top-down" = false}, + pass_name = "canonicalize"}> + : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op + "transform.yield"() : () -> () + }) : () -> () +}) {transform.with_named_sequence} : () -> () + +// ----- + +// Invalid due to param_operand occurences in options dict not being +// one-to-one with the dynamic options provided as params: +// two option-params are provide though only the first one is referred to from the options-dict. + +"builtin.module"() ({ + "transform.named_sequence"() <{function_type = (!transform.any_op) -> (), sym_name = "__transform_main"}> ({ + ^bb0(%arg0: !transform.any_op): + %0 = "transform.structured.match"(%arg0) <{ops = ["func.func"]}> : (!transform.any_op) -> !transform.any_op + %1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param + %2 = "transform.param.constant"() <{value = 1 : i64}> : () -> !transform.any_param + // expected-error @below {{a param operand does not have a corresponding param_operand attr in the options dict}} + %3 = "transform.apply_registered_pass"(%1, %2, %0) <{ + options = {"max-iterations" = #transform.param_operand, + "test-convergence" = true, + "top-down" = false}, + pass_name = "canonicalize"}> + : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op + "transform.yield"() : () -> () + }) : () -> () +}) {transform.with_named_sequence} : () -> () diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py index 6ed4818fc9d2f..48bc9bad37a1e 100644 --- a/mlir/test/python/dialects/transform.py +++ b/mlir/test/python/dialects/transform.py @@ -254,3 +254,55 @@ def testReplicateOp(module: Module): # CHECK: %[[FIRST:.+]] = pdl_match # CHECK: %[[SECOND:.+]] = pdl_match # CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]] + + +@run +def testApplyRegisteredPassOp(module: Module): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() + ) + with InsertionPoint(sequence.body): + mod = transform.ApplyRegisteredPassOp( + transform.AnyOpType.get(), "canonicalize", sequence.bodyTarget + ) + mod = transform.ApplyRegisteredPassOp( + transform.AnyOpType.get(), + "canonicalize", + mod.result, + options={"top-down": BoolAttr.get(False)}, + ) + max_iter = transform.param_constant( + transform.AnyParamType.get(), + IntegerAttr.get(IntegerType.get_signless(64), 10), + ) + max_rewrites = transform.param_constant( + transform.AnyParamType.get(), + IntegerAttr.get(IntegerType.get_signless(64), 1), + ) + transform.apply_registered_pass( + transform.AnyOpType.get(), + "canonicalize", + mod, + options={ + "top-down": BoolAttr.get(False), + "max-iterations": max_iter, + "test-convergence": BoolAttr.get(True), + "max-rewrites": max_rewrites, + }, + ) + transform.YieldOp() + # CHECK-LABEL: TEST: testApplyRegisteredPassOp + # CHECK: transform.sequence + # CHECK: %{{.*}} = apply_registered_pass "canonicalize" to {{.*}} : (!transform.any_op) -> !transform.any_op + # CHECK: %{{.*}} = apply_registered_pass "canonicalize" + # CHECK-SAME: with options = {"top-down" = false} + # CHECK-SAME: to {{.*}} : (!transform.any_op) -> !transform.any_op + # CHECK: %[[MAX_ITER:.+]] = transform.param.constant + # CHECK: %[[MAX_REWRITE:.+]] = transform.param.constant + # CHECK: %{{.*}} = apply_registered_pass "canonicalize" + # NB: MLIR has sorted the dict lexicographically by key: + # CHECK-SAME: with options = {"max-iterations" = %[[MAX_ITER]], + # CHECK-SAME: "max-rewrites" = %[[MAX_REWRITE]], + # CHECK-SAME: "test-convergence" = true, + # CHECK-SAME: "top-down" = false} + # CHECK-SAME: to %{{.*}} : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op