diff --git a/mlir/test/mlir-tblgen/rewriter-attributes-properties.td b/mlir/test/mlir-tblgen/rewriter-attributes-properties.td new file mode 100644 index 0000000000000..fc36a51789ec2 --- /dev/null +++ b/mlir/test/mlir-tblgen/rewriter-attributes-properties.td @@ -0,0 +1,47 @@ +// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s + +include "mlir/IR/OpBase.td" +include "mlir/IR/PatternBase.td" + +def Test_Dialect : Dialect { + let name = "test"; +} +class NS_Op traits> : + Op; + +def AOp : NS_Op<"a_op", []> { + let arguments = (ins + I32:$x, + I32Attr:$y + ); + + let results = (outs I32:$z); +} + +def BOp : NS_Op<"b_op", []> { + let arguments = (ins + I32Attr:$y + ); + + let results = (outs I32:$z); +} + +def test1 : Pat<(AOp (BOp:$x $y), $_), (AOp $x, $y)>; +// CHECK-LABEL: struct test1 +// CHECK: ::llvm::LogicalResult matchAndRewrite +// CHECK-DAG: ::mlir::IntegerAttr y; +// CHECK-DAG: test::BOp x; +// CHECK-DAG: ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops; +// CHECK: tblgen_ops.push_back(op0); +// CHECK: x = castedOp1; +// CHECK: tblgen_attr = castedOp1.getProperties().getY(); +// CHECK: if (!(tblgen_attr)) +// CHECK: y = tblgen_attr; +// CHECK: tblgen_ops.push_back(op1); + +// CHECK: test::AOp tblgen_AOp_0; +// CHECK: ::llvm::SmallVector<::mlir::Value, 4> tblgen_values; +// CHECK: test::AOp::Properties tblgen_props; +// CHECK: tblgen_values.push_back((*x.getODSResults(0).begin())); +// CHECK: tblgen_props.y = ::llvm::dyn_cast_if_present(y); +// CHECK: tblgen_AOp_0 = rewriter.create(odsLoc, tblgen_types, tblgen_values, tblgen_props); diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index f6eb5bdfe568e..f921788abdd71 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -122,7 +122,7 @@ class PatternEmitter { // Emits C++ statements for matching the `argIndex`-th argument of the given // DAG `tree` as an attribute. - void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex, + void emitAttributeMatch(DagNode tree, StringRef castedName, int argIndex, int depth); // Emits C++ for checking a match with a corresponding match failure @@ -664,7 +664,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) { /*variadicSubIndex=*/std::nullopt); ++nextOperand; } else if (isa(opArg)) { - emitAttributeMatch(tree, opName, opArgIdx, depth); + emitAttributeMatch(tree, castedName, opArgIdx, depth); } else { PrintFatalError(loc, "unhandled case when matching op"); } @@ -864,16 +864,22 @@ void PatternEmitter::emitVariadicOperandMatch(DagNode tree, os.unindent() << "}\n"; } -void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName, +void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef castedName, int argIndex, int depth) { Operator &op = tree.getDialectOp(opMap); auto *namedAttr = cast(op.getArg(argIndex)); const auto &attr = namedAttr->attr; os << "{\n"; - os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");" - "(void)tblgen_attr;\n", - opName, attr.getStorageType(), namedAttr->name); + if (op.getDialect().usePropertiesForAttributes()) { + os.indent() << formatv("auto tblgen_attr = {0}.getProperties().{1}();\n", + castedName, op.getGetterName(namedAttr->name)); + } else { + os.indent() << formatv( + "auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");" + "(void)tblgen_attr;\n", + castedName, attr.getStorageType(), namedAttr->name); + } // TODO: This should use getter method to avoid duplication. if (attr.hasDefaultValue()) { @@ -887,7 +893,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName, // That is precisely what getDiscardableAttr() returns on missing // attributes. } else { - emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx), + emitMatchCheck(castedName, tgfmt("tblgen_attr", &fmtCtx), formatv("\"expected op '{0}' to have attribute '{1}' " "of type '{2}'\"", op.getOperationName(), namedAttr->name, @@ -918,7 +924,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName, } } emitStaticVerifierCall( - verifier, opName, "tblgen_attr", + verifier, castedName, "tblgen_attr", formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: " "'{2}'\"", op.getOperationName(), namedAttr->name, @@ -1532,6 +1538,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, LLVM_DEBUG(llvm::dbgs() << '\n'); Operator &resultOp = tree.getDialectOp(opMap); + bool useProperties = resultOp.getDialect().usePropertiesForAttributes(); auto numOpArgs = resultOp.getNumArgs(); auto numPatArgs = tree.getNumArgs(); @@ -1623,9 +1630,10 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth); // Then create the op. - os.scope("", "\n}\n").os << formatv( - "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);", - valuePackName, resultOp.getQualCppClassName(), locToUse); + os.scope("", "\n}\n").os + << formatv("{0} = rewriter.create<{1}>({2}, tblgen_values, {3});", + valuePackName, resultOp.getQualCppClassName(), locToUse, + useProperties ? "tblgen_props" : "tblgen_attrs"); return resultValue; } @@ -1682,8 +1690,9 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, } } os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, " - "tblgen_values, tblgen_attrs);\n", - valuePackName, resultOp.getQualCppClassName(), locToUse); + "tblgen_values, {3});\n", + valuePackName, resultOp.getQualCppClassName(), locToUse, + useProperties ? "tblgen_props" : "tblgen_attrs"); os.unindent() << "}\n"; return resultValue; } @@ -1791,16 +1800,27 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs( DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) { Operator &resultOp = node.getDialectOp(opMap); + bool useProperties = resultOp.getDialect().usePropertiesForAttributes(); auto scope = os.scope(); os << formatv("::llvm::SmallVector<::mlir::Value, 4> " "tblgen_values; (void)tblgen_values;\n"); - os << formatv("::llvm::SmallVector<::mlir::NamedAttribute, 4> " - "tblgen_attrs; (void)tblgen_attrs;\n"); + if (useProperties) { + os << formatv("{0}::Properties tblgen_props; (void)tblgen_props;\n", + resultOp.getQualCppClassName()); + } else { + os << formatv("::llvm::SmallVector<::mlir::NamedAttribute, 4> " + "tblgen_attrs; (void)tblgen_attrs;\n"); + } + const char *setPropCmd = + "tblgen_props.{0} = " + "::llvm::dyn_cast_if_present({1});\n"; const char *addAttrCmd = "if (auto tmpAttr = {1}) {\n" " tblgen_attrs.emplace_back(rewriter.getStringAttr(\"{0}\"), " "tmpAttr);\n}\n"; + const char *setterCmd = (useProperties) ? setPropCmd : addAttrCmd; + int numVariadic = 0; bool hasOperandSegmentSizes = false; std::vector sizes; @@ -1814,13 +1834,13 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs( if (!subTree.isNativeCodeCall()) PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " "for creating attribute"); - os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex)); + + os << formatv(setterCmd, opArgName, childNodeNames.lookup(argIndex)); } else { auto leaf = node.getArgAsLeaf(argIndex); // The argument in the result DAG pattern. auto patArgName = node.getArgName(argIndex); - os << formatv(addAttrCmd, opArgName, - handleOpArgument(leaf, patArgName)); + os << formatv(setterCmd, opArgName, handleOpArgument(leaf, patArgName)); } continue; } @@ -1876,11 +1896,18 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs( const auto *sameVariadicSize = resultOp.getTrait("::mlir::OpTrait::SameVariadicOperandSize"); if (!sameVariadicSize) { - const char *setSizes = R"( - tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"), - rewriter.getDenseI32ArrayAttr({{ {0} })); - )"; - os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str()); + if (useProperties) { + const char *setSizes = R"( + tblgen_props.operandSegmentSizes = {{ {0} }; + )"; + os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str()); + } else { + const char *setSizes = R"( + tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"), + rewriter.getDenseI32ArrayAttr({{ {0} })); + )"; + os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str()); + } } } }