@@ -122,7 +122,7 @@ class PatternEmitter {
122122
123123 // Emits C++ statements for matching the `argIndex`-th argument of the given
124124 // DAG `tree` as an attribute.
125- void emitAttributeMatch (DagNode tree, StringRef opName , int argIndex,
125+ void emitAttributeMatch (DagNode tree, StringRef castedName , int argIndex,
126126 int depth);
127127
128128 // Emits C++ for checking a match with a corresponding match failure
@@ -664,7 +664,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
664664 /* variadicSubIndex=*/ std::nullopt );
665665 ++nextOperand;
666666 } else if (isa<NamedAttribute *>(opArg)) {
667- emitAttributeMatch (tree, opName , opArgIdx, depth);
667+ emitAttributeMatch (tree, castedName , opArgIdx, depth);
668668 } else {
669669 PrintFatalError (loc, " unhandled case when matching op" );
670670 }
@@ -864,16 +864,22 @@ void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
864864 os.unindent () << " }\n " ;
865865}
866866
867- void PatternEmitter::emitAttributeMatch (DagNode tree, StringRef opName ,
867+ void PatternEmitter::emitAttributeMatch (DagNode tree, StringRef castedName ,
868868 int argIndex, int depth) {
869869 Operator &op = tree.getDialectOp (opMap);
870870 auto *namedAttr = cast<NamedAttribute *>(op.getArg (argIndex));
871871 const auto &attr = namedAttr->attr ;
872872
873873 os << " {\n " ;
874- os.indent () << formatv (" auto tblgen_attr = {0}->getAttrOfType<{1}>(\" {2}\" );"
875- " (void)tblgen_attr;\n " ,
876- opName, attr.getStorageType (), namedAttr->name );
874+ if (op.getDialect ().usePropertiesForAttributes ()) {
875+ os.indent () << formatv (" auto tblgen_attr = {0}.getProperties().{1}();\n " ,
876+ castedName, op.getGetterName (namedAttr->name ));
877+ } else {
878+ os.indent () << formatv (
879+ " auto tblgen_attr = {0}->getAttrOfType<{1}>(\" {2}\" );"
880+ " (void)tblgen_attr;\n " ,
881+ castedName, attr.getStorageType (), namedAttr->name );
882+ }
877883
878884 // TODO: This should use getter method to avoid duplication.
879885 if (attr.hasDefaultValue ()) {
@@ -887,7 +893,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
887893 // That is precisely what getDiscardableAttr() returns on missing
888894 // attributes.
889895 } else {
890- emitMatchCheck (opName , tgfmt (" tblgen_attr" , &fmtCtx),
896+ emitMatchCheck (castedName , tgfmt (" tblgen_attr" , &fmtCtx),
891897 formatv (" \" expected op '{0}' to have attribute '{1}' "
892898 " of type '{2}'\" " ,
893899 op.getOperationName (), namedAttr->name ,
@@ -918,7 +924,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
918924 }
919925 }
920926 emitStaticVerifierCall (
921- verifier, opName , " tblgen_attr" ,
927+ verifier, castedName , " tblgen_attr" ,
922928 formatv (" \" op '{0}' attribute '{1}' failed to satisfy constraint: "
923929 " '{2}'\" " ,
924930 op.getOperationName (), namedAttr->name ,
@@ -1532,6 +1538,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
15321538 LLVM_DEBUG (llvm::dbgs () << ' \n ' );
15331539
15341540 Operator &resultOp = tree.getDialectOp (opMap);
1541+ bool useProperties = resultOp.getDialect ().usePropertiesForAttributes ();
15351542 auto numOpArgs = resultOp.getNumArgs ();
15361543 auto numPatArgs = tree.getNumArgs ();
15371544
@@ -1623,9 +1630,10 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
16231630 createAggregateLocalVarsForOpArgs (tree, childNodeNames, depth);
16241631
16251632 // Then create the op.
1626- os.scope (" " , " \n }\n " ).os << formatv (
1627- " {0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);" ,
1628- valuePackName, resultOp.getQualCppClassName (), locToUse);
1633+ os.scope (" " , " \n }\n " ).os
1634+ << formatv (" {0} = rewriter.create<{1}>({2}, tblgen_values, {3});" ,
1635+ valuePackName, resultOp.getQualCppClassName (), locToUse,
1636+ useProperties ? " tblgen_props" : " tblgen_attrs" );
16291637 return resultValue;
16301638 }
16311639
@@ -1682,8 +1690,9 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
16821690 }
16831691 }
16841692 os << formatv (" {0} = rewriter.create<{1}>({2}, tblgen_types, "
1685- " tblgen_values, tblgen_attrs);\n " ,
1686- valuePackName, resultOp.getQualCppClassName (), locToUse);
1693+ " tblgen_values, {3});\n " ,
1694+ valuePackName, resultOp.getQualCppClassName (), locToUse,
1695+ useProperties ? " tblgen_props" : " tblgen_attrs" );
16871696 os.unindent () << " }\n " ;
16881697 return resultValue;
16891698}
@@ -1791,12 +1800,21 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
17911800 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
17921801 Operator &resultOp = node.getDialectOp (opMap);
17931802
1803+ bool useProperties = resultOp.getDialect ().usePropertiesForAttributes ();
17941804 auto scope = os.scope ();
17951805 os << formatv (" ::llvm::SmallVector<::mlir::Value, 4> "
17961806 " tblgen_values; (void)tblgen_values;\n " );
1797- os << formatv (" ::llvm::SmallVector<::mlir::NamedAttribute, 4> "
1798- " tblgen_attrs; (void)tblgen_attrs;\n " );
1807+ if (useProperties) {
1808+ os << formatv (" {0}::Properties tblgen_props; (void)tblgen_props;\n " ,
1809+ resultOp.getQualCppClassName ());
1810+ } else {
1811+ os << formatv (" ::llvm::SmallVector<::mlir::NamedAttribute, 4> "
1812+ " tblgen_attrs; (void)tblgen_attrs;\n " );
1813+ }
17991814
1815+ const char *setPropCmd =
1816+ " tblgen_props.{0} = "
1817+ " ::llvm::dyn_cast_if_present<decltype(tblgen_props.{0})>({1});\n " ;
18001818 const char *addAttrCmd =
18011819 " if (auto tmpAttr = {1}) {\n "
18021820 " tblgen_attrs.emplace_back(rewriter.getStringAttr(\" {0}\" ), "
@@ -1814,13 +1832,23 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
18141832 if (!subTree.isNativeCodeCall ())
18151833 PrintFatalError (loc, " only NativeCodeCall allowed in nested dag node "
18161834 " for creating attribute" );
1817- os << formatv (addAttrCmd, opArgName, childNodeNames.lookup (argIndex));
1835+
1836+ if (useProperties) {
1837+ os << formatv (setPropCmd, opArgName, childNodeNames.lookup (argIndex));
1838+ } else {
1839+ os << formatv (addAttrCmd, opArgName, childNodeNames.lookup (argIndex));
1840+ }
18181841 } else {
18191842 auto leaf = node.getArgAsLeaf (argIndex);
18201843 // The argument in the result DAG pattern.
18211844 auto patArgName = node.getArgName (argIndex);
1822- os << formatv (addAttrCmd, opArgName,
1823- handleOpArgument (leaf, patArgName));
1845+ if (useProperties) {
1846+ os << formatv (setPropCmd, opArgName,
1847+ handleOpArgument (leaf, patArgName));
1848+ } else {
1849+ os << formatv (addAttrCmd, opArgName,
1850+ handleOpArgument (leaf, patArgName));
1851+ }
18241852 }
18251853 continue ;
18261854 }
@@ -1876,11 +1904,18 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
18761904 const auto *sameVariadicSize =
18771905 resultOp.getTrait (" ::mlir::OpTrait::SameVariadicOperandSize" );
18781906 if (!sameVariadicSize) {
1879- const char *setSizes = R"(
1880- tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"),
1881- rewriter.getDenseI32ArrayAttr({{ {0} }));
1882- )" ;
1883- os.printReindented (formatv (setSizes, llvm::join (sizes, " , " )).str ());
1907+ if (useProperties) {
1908+ const char *setSizes = R"(
1909+ tblgen_props.operandSegmentSizes = {{ {0} };
1910+ )" ;
1911+ os.printReindented (formatv (setSizes, llvm::join (sizes, " , " )).str ());
1912+ } else {
1913+ const char *setSizes = R"(
1914+ tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"),
1915+ rewriter.getDenseI32ArrayAttr({{ {0} }));
1916+ )" ;
1917+ os.printReindented (formatv (setSizes, llvm::join (sizes, " , " )).str ());
1918+ }
18841919 }
18851920 }
18861921}
0 commit comments