@@ -788,46 +788,47 @@ transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
788788 // Obtain a single options-string to pass to the pass(-pipeline) from options
789789 // passed in as a dictionary of keys mapping to values which are either
790790 // attributes or param-operands pointing to attributes.
791+ OperandRange dynamicOptions = getDynamicOptions ();
791792
792793 std::string options;
793794 llvm::raw_string_ostream optionsStream (options); // For "printing" attrs.
794795
795- OperandRange dynamicOptions = getDynamicOptions ();
796- for (auto [idx, namedAttribute] : llvm::enumerate (getOptions ())) {
797- if (idx > 0 )
798- optionsStream << " " ; // Interleave options separator.
799- optionsStream << namedAttribute.getName ().str (); // Append the key.
800- optionsStream << " =" ; // And the key-value separator.
801-
802- Attribute valueAttrToAppend;
803- if (auto paramOperandIndex =
804- dyn_cast<transform::ParamOperandAttr>(namedAttribute.getValue ())) {
805- // The corresponding value attribute is passed in via a param.
796+ // A helper to convert an option's attribute value into a corresponding
797+ // string representation, with the ability to obtain the attr(s) from a param.
798+ std::function<void (Attribute)> appendValueAttr = [&](Attribute valueAttr) {
799+ if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
800+ // The corresponding value attribute(s) is/are passed in via a param.
806801 // Obtain the param-operand via its specified index.
807- size_t dynamicOptionIdx = paramOperandIndex .getIndex ().getInt ();
802+ size_t dynamicOptionIdx = paramOperand .getIndex ().getInt ();
808803 assert (dynamicOptionIdx < dynamicOptions.size () &&
809- " number of dynamic option markers (UnitAttr) in options ArrayAttr "
804+ " the number of ParamOperandAttrs in the options DictionaryAttr "
810805 " should be the same as the number of options passed as params" );
811- ArrayRef<Attribute> dynamicOption =
806+ ArrayRef<Attribute> attrsAssociatedToParam =
812807 state.getParams (dynamicOptions[dynamicOptionIdx]);
813- if (dynamicOption.size () != 1 )
814- return emitSilenceableError ()
815- << " options passed as a param must have "
816- " a single value associated, param "
817- << dynamicOptionIdx << " associates " << dynamicOption.size ();
818- valueAttrToAppend = dynamicOption[0 ];
819- } else {
820- // Value is a static attribute.
821- valueAttrToAppend = namedAttribute.getValue ();
822- }
823-
824- // Append string representation of value attribute.
825- if (auto strAttr = dyn_cast<StringAttr>(valueAttrToAppend)) {
808+ // Recursive so as to append all attrs associated to the param.
809+ llvm::interleave (attrsAssociatedToParam, optionsStream, appendValueAttr,
810+ " ," );
811+ } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
812+ // Recursive so as to append all nested attrs of the array.
813+ llvm::interleave (arrayAttr, optionsStream, appendValueAttr, " ," );
814+ } else if (auto strAttr = dyn_cast<StringAttr>(valueAttr)) {
815+ // Convert to unquoted string.
826816 optionsStream << strAttr.getValue ().str ();
827817 } else {
828- valueAttrToAppend.print (optionsStream, /* elideType=*/ true );
818+ // For all other attributes, ask the attr to print itself (without type).
819+ valueAttr.print (optionsStream, /* elideType=*/ true );
829820 }
830- }
821+ };
822+
823+ // Convert the options DictionaryAttr into a single string.
824+ llvm::interleave (
825+ getOptions (), optionsStream,
826+ [&](auto namedAttribute) {
827+ optionsStream << namedAttribute.getName ().str (); // Append the key.
828+ optionsStream << " =" ; // And the key-value separator.
829+ appendValueAttr (namedAttribute.getValue ()); // And the attr's str repr.
830+ },
831+ " " );
831832 optionsStream.flush ();
832833
833834 // Get pass or pass pipeline from registry.
@@ -878,23 +879,30 @@ static ParseResult parseApplyRegisteredPassOptions(
878879 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions) {
879880 // Construct the options DictionaryAttr per a `{ key = value, ... }` syntax.
880881 SmallVector<NamedAttribute> keyValuePairs;
881-
882882 size_t dynamicOptionsIdx = 0 ;
883- auto parseKeyValuePair = [&]() -> ParseResult {
884- // Parse items of the form `key = value` where `key` is a bare identifier or
885- // a string and `value` is either an attribute or an operand.
886883
887- std::string key;
888- Attribute valueAttr;
889- if (parser.parseOptionalKeywordOrString (&key))
890- return parser.emitError (parser.getCurrentLocation ())
891- << " expected key to either be an identifier or a string" ;
892- if (key.empty ())
893- return failure ();
884+ // Helper for allowing parsing of option values which can be of the form:
885+ // - a normal attribute
886+ // - an operand (which would be converted to an attr referring to the operand)
887+ // - ArrayAttrs containing the foregoing (in correspondence with ListOptions)
888+ std::function<ParseResult (Attribute &)> parseValue =
889+ [&](Attribute &valueAttr) -> ParseResult {
890+ // Allow for array syntax, e.g. `[0 : i64, %param, true, %other_param]`:
891+ if (succeeded (parser.parseOptionalLSquare ())) {
892+ SmallVector<Attribute> attrs;
894893
895- if (parser.parseEqual ())
896- return parser.emitError (parser.getCurrentLocation ())
897- << " expected '=' after key in key-value pair" ;
894+ // Recursively parse the array's elements, which might be operands.
895+ if (parser.parseCommaSeparatedList (
896+ AsmParser::Delimiter::None,
897+ [&]() -> ParseResult { return parseValue (attrs.emplace_back ()); },
898+ " in options dictionary" ) ||
899+ parser.parseRSquare ())
900+ return failure (); // NB: Attempted parse should've output error message.
901+
902+ valueAttr = ArrayAttr::get (parser.getContext (), attrs);
903+
904+ return success ();
905+ }
898906
899907 // Parse the value, which can be either an attribute or an operand.
900908 OptionalParseResult parsedValueAttr =
@@ -903,9 +911,7 @@ static ParseResult parseApplyRegisteredPassOptions(
903911 OpAsmParser::UnresolvedOperand operand;
904912 ParseResult parsedOperand = parser.parseOperand (operand);
905913 if (failed (parsedOperand))
906- return parser.emitError (parser.getCurrentLocation ())
907- << " expected a valid attribute or operand as value associated "
908- << " to key '" << key << " '" ;
914+ return failure (); // NB: Attempted parse should've output error message.
909915 // To make use of the operand, we need to store it in the options dict.
910916 // As SSA-values cannot occur in attributes, what we do instead is store
911917 // an attribute in its place that contains the index of the param-operand,
@@ -924,7 +930,30 @@ static ParseResult parseApplyRegisteredPassOptions(
924930 << " in the generic print format" ;
925931 }
926932
933+ return success ();
934+ };
935+
936+ // Helper for `key = value`-pair parsing where `key` is a bare identifier or a
937+ // string and `value` looks like either an attribute or an operand-in-an-attr.
938+ std::function<ParseResult ()> parseKeyValuePair = [&]() -> ParseResult {
939+ std::string key;
940+ Attribute valueAttr;
941+
942+ if (failed (parser.parseOptionalKeywordOrString (&key)) || key.empty ())
943+ return parser.emitError (parser.getCurrentLocation ())
944+ << " expected key to either be an identifier or a string" ;
945+
946+ if (failed (parser.parseEqual ()))
947+ return parser.emitError (parser.getCurrentLocation ())
948+ << " expected '=' after key in key-value pair" ;
949+
950+ if (failed (parseValue (valueAttr)))
951+ return parser.emitError (parser.getCurrentLocation ())
952+ << " expected a valid attribute or operand as value associated "
953+ << " to key '" << key << " '" ;
954+
927955 keyValuePairs.push_back (NamedAttribute (key, valueAttr));
956+
928957 return success ();
929958 };
930959
@@ -951,16 +980,27 @@ static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
951980 if (options.empty ())
952981 return ;
953982
954- printer << " {" ;
955- llvm::interleaveComma (options, printer, [&](NamedAttribute namedAttribute) {
956- printer << namedAttribute.getName () << " = " ;
957- Attribute value = namedAttribute.getValue ();
958- if (auto indexAttr = dyn_cast<transform::ParamOperandAttr>(value)) {
983+ std::function<void (Attribute)> printOptionValue = [&](Attribute valueAttr) {
984+ if (auto paramOperandAttr =
985+ dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
959986 // Resolve index of param-operand to its actual SSA-value and print that.
960- printer.printOperand (dynamicOptions[indexAttr.getIndex ().getInt ()]);
987+ printer.printOperand (
988+ dynamicOptions[paramOperandAttr.getIndex ().getInt ()]);
989+ } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
990+ // This case is so that ArrayAttr-contained operands are pretty-printed.
991+ printer << " [" ;
992+ llvm::interleaveComma (arrayAttr, printer, printOptionValue);
993+ printer << " ]" ;
961994 } else {
962- printer.printAttribute (value );
995+ printer.printAttribute (valueAttr );
963996 }
997+ };
998+
999+ printer << " {" ;
1000+ llvm::interleaveComma (options, printer, [&](NamedAttribute namedAttribute) {
1001+ printer << namedAttribute.getName ();
1002+ printer << " = " ;
1003+ printOptionValue (namedAttribute.getValue ());
9641004 });
9651005 printer << " }" ;
9661006}
@@ -970,9 +1010,11 @@ LogicalResult transform::ApplyRegisteredPassOp::verify() {
9701010 // and references to dynamic options in the options dictionary.
9711011
9721012 auto dynamicOptions = SmallVector<Value>(getDynamicOptions ());
973- for (NamedAttribute namedAttr : getOptions ())
974- if (auto paramOperand =
975- dyn_cast<transform::ParamOperandAttr>(namedAttr.getValue ())) {
1013+
1014+ // Helper for option values to mark seen operands as having been seen (once).
1015+ std::function<LogicalResult (Attribute)> checkOptionValue =
1016+ [&](Attribute valueAttr) -> LogicalResult {
1017+ if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
9761018 size_t dynamicOptionIdx = paramOperand.getIndex ().getInt ();
9771019 if (dynamicOptionIdx < 0 || dynamicOptionIdx >= dynamicOptions.size ())
9781020 return emitOpError ()
@@ -983,8 +1025,20 @@ LogicalResult transform::ApplyRegisteredPassOp::verify() {
9831025 return emitOpError () << " dynamic option index " << dynamicOptionIdx
9841026 << " is already used in options" ;
9851027 dynamicOptions[dynamicOptionIdx] = nullptr ; // Mark this option as used.
1028+ } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
1029+ // Recurse into ArrayAttrs as they may contain references to operands.
1030+ for (auto eltAttr : arrayAttr)
1031+ if (failed (checkOptionValue (eltAttr)))
1032+ return failure ();
9861033 }
1034+ return success ();
1035+ };
1036+
1037+ for (NamedAttribute namedAttr : getOptions ())
1038+ if (failed (checkOptionValue (namedAttr.getValue ())))
1039+ return failure ();
9871040
1041+ // All dynamicOptions-params seen in the dict will have been set to null.
9881042 for (Value dynamicOption : dynamicOptions)
9891043 if (dynamicOption)
9901044 return emitOpError () << " a param operand does not have a corresponding "
0 commit comments