@@ -122,7 +122,7 @@ class PatternEmitter {
122122
123123 // Emits C++ statements for matching the `argIndex`-th argument of the given
124124 // DAG `tree` as an attribute.
125- void emitAttributeMatch (DagNode tree, StringRef opName , int argIndex,
125+ void emitAttributeMatch (DagNode tree, StringRef castedName , int argIndex,
126126 int depth);
127127
128128 // Emits C++ for checking a match with a corresponding match failure
@@ -664,7 +664,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
664664 /* variadicSubIndex=*/ std::nullopt );
665665 ++nextOperand;
666666 } else if (isa<NamedAttribute *>(opArg)) {
667- emitAttributeMatch (tree, opName , opArgIdx, depth);
667+ emitAttributeMatch (tree, castedName , opArgIdx, depth);
668668 } else {
669669 PrintFatalError (loc, " unhandled case when matching op" );
670670 }
@@ -864,16 +864,22 @@ void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
864864 os.unindent () << " }\n " ;
865865}
866866
867- void PatternEmitter::emitAttributeMatch (DagNode tree, StringRef opName ,
867+ void PatternEmitter::emitAttributeMatch (DagNode tree, StringRef castedName ,
868868 int argIndex, int depth) {
869869 Operator &op = tree.getDialectOp (opMap);
870870 auto *namedAttr = cast<NamedAttribute *>(op.getArg (argIndex));
871871 const auto &attr = namedAttr->attr ;
872872
873873 os << " {\n " ;
874- os.indent () << formatv (" auto tblgen_attr = {0}->getAttrOfType<{1}>(\" {2}\" );"
875- " (void)tblgen_attr;\n " ,
876- opName, attr.getStorageType (), namedAttr->name );
874+ if (op.getDialect ().usePropertiesForAttributes ()) {
875+ os.indent () << formatv (" auto tblgen_attr = {0}.getProperties().{1}();\n " ,
876+ castedName, op.getGetterName (namedAttr->name ));
877+ } else {
878+ os.indent () << formatv (
879+ " auto tblgen_attr = {0}->getAttrOfType<{1}>(\" {2}\" );"
880+ " (void)tblgen_attr;\n " ,
881+ castedName, attr.getStorageType (), namedAttr->name );
882+ }
877883
878884 // TODO: This should use getter method to avoid duplication.
879885 if (attr.hasDefaultValue ()) {
@@ -887,7 +893,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
887893 // That is precisely what getDiscardableAttr() returns on missing
888894 // attributes.
889895 } else {
890- emitMatchCheck (opName , tgfmt (" tblgen_attr" , &fmtCtx),
896+ emitMatchCheck (castedName , tgfmt (" tblgen_attr" , &fmtCtx),
891897 formatv (" \" expected op '{0}' to have attribute '{1}' "
892898 " of type '{2}'\" " ,
893899 op.getOperationName (), namedAttr->name ,
@@ -918,7 +924,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
918924 }
919925 }
920926 emitStaticVerifierCall (
921- verifier, opName , " tblgen_attr" ,
927+ verifier, castedName , " tblgen_attr" ,
922928 formatv (" \" op '{0}' attribute '{1}' failed to satisfy constraint: "
923929 " '{2}'\" " ,
924930 op.getOperationName (), namedAttr->name ,
@@ -1532,6 +1538,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
15321538 LLVM_DEBUG (llvm::dbgs () << ' \n ' );
15331539
15341540 Operator &resultOp = tree.getDialectOp (opMap);
1541+ bool useProperties = resultOp.getDialect ().usePropertiesForAttributes ();
15351542 auto numOpArgs = resultOp.getNumArgs ();
15361543 auto numPatArgs = tree.getNumArgs ();
15371544
@@ -1623,9 +1630,10 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
16231630 createAggregateLocalVarsForOpArgs (tree, childNodeNames, depth);
16241631
16251632 // Then create the op.
1626- os.scope (" " , " \n }\n " ).os << formatv (
1627- " {0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);" ,
1628- valuePackName, resultOp.getQualCppClassName (), locToUse);
1633+ os.scope (" " , " \n }\n " ).os
1634+ << formatv (" {0} = rewriter.create<{1}>({2}, tblgen_values, {3});" ,
1635+ valuePackName, resultOp.getQualCppClassName (), locToUse,
1636+ useProperties ? " tblgen_props" : " tblgen_attrs" );
16291637 return resultValue;
16301638 }
16311639
@@ -1682,8 +1690,9 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
16821690 }
16831691 }
16841692 os << formatv (" {0} = rewriter.create<{1}>({2}, tblgen_types, "
1685- " tblgen_values, tblgen_attrs);\n " ,
1686- valuePackName, resultOp.getQualCppClassName (), locToUse);
1693+ " tblgen_values, {3});\n " ,
1694+ valuePackName, resultOp.getQualCppClassName (), locToUse,
1695+ useProperties ? " tblgen_props" : " tblgen_attrs" );
16871696 os.unindent () << " }\n " ;
16881697 return resultValue;
16891698}
@@ -1791,16 +1800,27 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
17911800 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
17921801 Operator &resultOp = node.getDialectOp (opMap);
17931802
1803+ bool useProperties = resultOp.getDialect ().usePropertiesForAttributes ();
17941804 auto scope = os.scope ();
17951805 os << formatv (" ::llvm::SmallVector<::mlir::Value, 4> "
17961806 " tblgen_values; (void)tblgen_values;\n " );
1797- os << formatv (" ::llvm::SmallVector<::mlir::NamedAttribute, 4> "
1798- " tblgen_attrs; (void)tblgen_attrs;\n " );
1807+ if (useProperties) {
1808+ os << formatv (" {0}::Properties tblgen_props; (void)tblgen_props;\n " ,
1809+ resultOp.getQualCppClassName ());
1810+ } else {
1811+ os << formatv (" ::llvm::SmallVector<::mlir::NamedAttribute, 4> "
1812+ " tblgen_attrs; (void)tblgen_attrs;\n " );
1813+ }
17991814
1815+ const char *setPropCmd =
1816+ " tblgen_props.{0} = "
1817+ " ::llvm::dyn_cast_if_present<decltype(tblgen_props.{0})>({1});\n " ;
18001818 const char *addAttrCmd =
18011819 " if (auto tmpAttr = {1}) {\n "
18021820 " tblgen_attrs.emplace_back(rewriter.getStringAttr(\" {0}\" ), "
18031821 " tmpAttr);\n }\n " ;
1822+ const char *setterCmd = (useProperties) ? setPropCmd : addAttrCmd;
1823+
18041824 int numVariadic = 0 ;
18051825 bool hasOperandSegmentSizes = false ;
18061826 std::vector<std::string> sizes;
@@ -1814,13 +1834,13 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
18141834 if (!subTree.isNativeCodeCall ())
18151835 PrintFatalError (loc, " only NativeCodeCall allowed in nested dag node "
18161836 " for creating attribute" );
1817- os << formatv (addAttrCmd, opArgName, childNodeNames.lookup (argIndex));
1837+
1838+ os << formatv (setterCmd, opArgName, childNodeNames.lookup (argIndex));
18181839 } else {
18191840 auto leaf = node.getArgAsLeaf (argIndex);
18201841 // The argument in the result DAG pattern.
18211842 auto patArgName = node.getArgName (argIndex);
1822- os << formatv (addAttrCmd, opArgName,
1823- handleOpArgument (leaf, patArgName));
1843+ os << formatv (setterCmd, opArgName, handleOpArgument (leaf, patArgName));
18241844 }
18251845 continue ;
18261846 }
@@ -1876,11 +1896,18 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
18761896 const auto *sameVariadicSize =
18771897 resultOp.getTrait (" ::mlir::OpTrait::SameVariadicOperandSize" );
18781898 if (!sameVariadicSize) {
1879- const char *setSizes = R"(
1880- tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"),
1881- rewriter.getDenseI32ArrayAttr({{ {0} }));
1882- )" ;
1883- os.printReindented (formatv (setSizes, llvm::join (sizes, " , " )).str ());
1899+ if (useProperties) {
1900+ const char *setSizes = R"(
1901+ tblgen_props.operandSegmentSizes = {{ {0} };
1902+ )" ;
1903+ os.printReindented (formatv (setSizes, llvm::join (sizes, " , " )).str ());
1904+ } else {
1905+ const char *setSizes = R"(
1906+ tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"),
1907+ rewriter.getDenseI32ArrayAttr({{ {0} }));
1908+ )" ;
1909+ os.printReindented (formatv (setSizes, llvm::join (sizes, " , " )).str ());
1910+ }
18841911 }
18851912 }
18861913}
0 commit comments