Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions mlir/test/mlir-tblgen/rewriter-attributes-properties.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s

include "mlir/IR/OpBase.td"
include "mlir/IR/PatternBase.td"

def Test_Dialect : Dialect {
let name = "test";
}
class NS_Op<string mnemonic, list<Trait> traits> :
Op<Test_Dialect, mnemonic, traits>;

def AOp : NS_Op<"a_op", []> {
let arguments = (ins
I32:$x,
I32Attr:$y
);

let results = (outs I32:$z);
}

def BOp : NS_Op<"b_op", []> {
let arguments = (ins
I32Attr:$y
);

let results = (outs I32:$z);
}

def test1 : Pat<(AOp (BOp:$x $y), $_), (AOp $x, $y)>;
// CHECK-LABEL: struct test1
// CHECK: ::llvm::LogicalResult matchAndRewrite
// CHECK-DAG: ::mlir::IntegerAttr y;
// CHECK-DAG: test::BOp x;
// CHECK-DAG: ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;
// CHECK: tblgen_ops.push_back(op0);
// CHECK: x = castedOp1;
// CHECK: tblgen_attr = castedOp1.getProperties().getY();
// CHECK: if (!(tblgen_attr))
// CHECK: y = tblgen_attr;
// CHECK: tblgen_ops.push_back(op1);

// CHECK: test::AOp tblgen_AOp_0;
// CHECK: ::llvm::SmallVector<::mlir::Value, 4> tblgen_values;
// CHECK: test::AOp::Properties tblgen_props;
// CHECK: tblgen_values.push_back((*x.getODSResults(0).begin()));
// CHECK: tblgen_props.y = ::llvm::dyn_cast_if_present<decltype(tblgen_props.y)>(y);
// CHECK: tblgen_AOp_0 = rewriter.create<test::AOp>(odsLoc, tblgen_types, tblgen_values, tblgen_props);
73 changes: 50 additions & 23 deletions mlir/tools/mlir-tblgen/RewriterGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class PatternEmitter {

// Emits C++ statements for matching the `argIndex`-th argument of the given
// DAG `tree` as an attribute.
void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex,
void emitAttributeMatch(DagNode tree, StringRef castedName, int argIndex,
int depth);

// Emits C++ for checking a match with a corresponding match failure
Expand Down Expand Up @@ -664,7 +664,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
/*variadicSubIndex=*/std::nullopt);
++nextOperand;
} else if (isa<NamedAttribute *>(opArg)) {
emitAttributeMatch(tree, opName, opArgIdx, depth);
emitAttributeMatch(tree, castedName, opArgIdx, depth);
} else {
PrintFatalError(loc, "unhandled case when matching op");
}
Expand Down Expand Up @@ -864,16 +864,22 @@ void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
os.unindent() << "}\n";
}

void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef castedName,
int argIndex, int depth) {
Operator &op = tree.getDialectOp(opMap);
auto *namedAttr = cast<NamedAttribute *>(op.getArg(argIndex));
const auto &attr = namedAttr->attr;

os << "{\n";
os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");"
"(void)tblgen_attr;\n",
opName, attr.getStorageType(), namedAttr->name);
if (op.getDialect().usePropertiesForAttributes()) {
os.indent() << formatv("auto tblgen_attr = {0}.getProperties().{1}();\n",
castedName, op.getGetterName(namedAttr->name));
} else {
os.indent() << formatv(
"auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");"
"(void)tblgen_attr;\n",
castedName, attr.getStorageType(), namedAttr->name);
}

// TODO: This should use getter method to avoid duplication.
if (attr.hasDefaultValue()) {
Expand All @@ -887,7 +893,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
// That is precisely what getDiscardableAttr() returns on missing
// attributes.
} else {
emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx),
emitMatchCheck(castedName, tgfmt("tblgen_attr", &fmtCtx),
formatv("\"expected op '{0}' to have attribute '{1}' "
"of type '{2}'\"",
op.getOperationName(), namedAttr->name,
Expand Down Expand Up @@ -918,7 +924,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
}
}
emitStaticVerifierCall(
verifier, opName, "tblgen_attr",
verifier, castedName, "tblgen_attr",
formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
"'{2}'\"",
op.getOperationName(), namedAttr->name,
Expand Down Expand Up @@ -1532,6 +1538,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
LLVM_DEBUG(llvm::dbgs() << '\n');

Operator &resultOp = tree.getDialectOp(opMap);
bool useProperties = resultOp.getDialect().usePropertiesForAttributes();
auto numOpArgs = resultOp.getNumArgs();
auto numPatArgs = tree.getNumArgs();

Expand Down Expand Up @@ -1623,9 +1630,10 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);

// Then create the op.
os.scope("", "\n}\n").os << formatv(
"{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);",
valuePackName, resultOp.getQualCppClassName(), locToUse);
os.scope("", "\n}\n").os
<< formatv("{0} = rewriter.create<{1}>({2}, tblgen_values, {3});",
valuePackName, resultOp.getQualCppClassName(), locToUse,
useProperties ? "tblgen_props" : "tblgen_attrs");
return resultValue;
}

Expand Down Expand Up @@ -1682,8 +1690,9 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
}
}
os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
"tblgen_values, tblgen_attrs);\n",
valuePackName, resultOp.getQualCppClassName(), locToUse);
"tblgen_values, {3});\n",
valuePackName, resultOp.getQualCppClassName(), locToUse,
useProperties ? "tblgen_props" : "tblgen_attrs");
os.unindent() << "}\n";
return resultValue;
}
Expand Down Expand Up @@ -1791,16 +1800,27 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
Operator &resultOp = node.getDialectOp(opMap);

bool useProperties = resultOp.getDialect().usePropertiesForAttributes();
auto scope = os.scope();
os << formatv("::llvm::SmallVector<::mlir::Value, 4> "
"tblgen_values; (void)tblgen_values;\n");
os << formatv("::llvm::SmallVector<::mlir::NamedAttribute, 4> "
"tblgen_attrs; (void)tblgen_attrs;\n");
if (useProperties) {
os << formatv("{0}::Properties tblgen_props; (void)tblgen_props;\n",
resultOp.getQualCppClassName());
} else {
os << formatv("::llvm::SmallVector<::mlir::NamedAttribute, 4> "
"tblgen_attrs; (void)tblgen_attrs;\n");
}

const char *setPropCmd =
"tblgen_props.{0} = "
"::llvm::dyn_cast_if_present<decltype(tblgen_props.{0})>({1});\n";
const char *addAttrCmd =
"if (auto tmpAttr = {1}) {\n"
" tblgen_attrs.emplace_back(rewriter.getStringAttr(\"{0}\"), "
"tmpAttr);\n}\n";
const char *setterCmd = (useProperties) ? setPropCmd : addAttrCmd;

int numVariadic = 0;
bool hasOperandSegmentSizes = false;
std::vector<std::string> sizes;
Expand All @@ -1814,13 +1834,13 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
if (!subTree.isNativeCodeCall())
PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
"for creating attribute");
os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex));

os << formatv(setterCmd, opArgName, childNodeNames.lookup(argIndex));
} else {
auto leaf = node.getArgAsLeaf(argIndex);
// The argument in the result DAG pattern.
auto patArgName = node.getArgName(argIndex);
os << formatv(addAttrCmd, opArgName,
handleOpArgument(leaf, patArgName));
os << formatv(setterCmd, opArgName, handleOpArgument(leaf, patArgName));
}
continue;
}
Expand Down Expand Up @@ -1876,11 +1896,18 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
const auto *sameVariadicSize =
resultOp.getTrait("::mlir::OpTrait::SameVariadicOperandSize");
if (!sameVariadicSize) {
const char *setSizes = R"(
tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"),
rewriter.getDenseI32ArrayAttr({{ {0} }));
)";
os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str());
if (useProperties) {
const char *setSizes = R"(
tblgen_props.operandSegmentSizes = {{ {0} };
)";
os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str());
} else {
const char *setSizes = R"(
tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"),
rewriter.getDenseI32ArrayAttr({{ {0} }));
)";
os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str());
}
}
}
}
Expand Down