Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3162,7 +3162,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> &paramList,
isOptional);
continue;
}
const NamedAttribute &namedAttr = *arg.get<NamedAttribute *>();
const NamedAttribute &namedAttr = *cast<NamedAttribute *>(arg);
const Attribute &attr = namedAttr.attr;

// Inferred attributes don't need to be added to the param list.
Expand Down Expand Up @@ -3499,14 +3499,14 @@ void OpEmitter::genSideEffectInterfaceMethods() {
/// Attributes and Operands.
for (unsigned i = 0, operandIt = 0, e = op.getNumArgs(); i != e; ++i) {
Argument arg = op.getArg(i);
if (arg.is<NamedTypeConstraint *>()) {
if (isa<NamedTypeConstraint *>(arg)) {
resolveDecorators(op.getArgDecorators(i), operandIt, EffectKind::Operand);
++operandIt;
continue;
}
if (arg.is<NamedProperty *>())
if (isa<NamedProperty *>(arg))
continue;
const NamedAttribute *attr = arg.get<NamedAttribute *>();
const NamedAttribute *attr = cast<NamedAttribute *>(arg);
if (attr->attr.getBaseAttr().isSymbolRefAttr())
resolveDecorators(op.getArgDecorators(i), i, EffectKind::Symbol);
}
Expand Down Expand Up @@ -3547,7 +3547,7 @@ void OpEmitter::genSideEffectInterfaceMethods() {
.str();
} else if (location.kind == EffectKind::Symbol) {
// A symbol reference requires adding the proper attribute.
const auto *attr = op.getArg(location.index).get<NamedAttribute *>();
const auto *attr = cast<NamedAttribute *>(op.getArg(location.index));
std::string argName = op.getGetterName(attr->name);
if (attr->attr.isOptional()) {
body << " if (auto symbolRef = " << argName << "Attr())\n "
Expand Down Expand Up @@ -3648,7 +3648,7 @@ void OpEmitter::genTypeInterfaceMethods() {
// If this is an attribute, index into the attribute dictionary.
} else {
auto *attr =
op.getArg(arg.operandOrAttributeIndex()).get<NamedAttribute *>();
cast<NamedAttribute *>(op.getArg(arg.operandOrAttributeIndex()));
body << " ::mlir::TypedAttr odsInferredTypeAttr" << inferredTypeIdx
<< " = ";
if (op.getDialect().usePropertiesForAttributes()) {
Expand Down
2 changes: 1 addition & 1 deletion mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,7 @@ static void populateBuilderArgs(const Operator &op,
name = formatv("_gen_arg_{0}", i);
name = sanitizeName(name);
builderArgs.push_back(name);
if (!op.getArg(i).is<NamedAttribute *>())
if (!isa<NamedAttribute *>(op.getArg(i)))
operandNames.push_back(name);
}
}
Expand Down
18 changes: 9 additions & 9 deletions mlir/tools/mlir-tblgen/RewriterGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -655,15 +655,15 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
}

// Next handle DAG leaf: operand or attribute
if (opArg.is<NamedTypeConstraint *>()) {
if (isa<NamedTypeConstraint *>(opArg)) {
auto operandName =
formatv("{0}.getODSOperands({1})", castedName, nextOperand);
emitOperandMatch(tree, castedName, operandName.str(), opArgIdx,
/*operandMatcher=*/tree.getArgAsLeaf(i),
/*argName=*/tree.getArgName(i), opArgIdx,
/*variadicSubIndex=*/std::nullopt);
++nextOperand;
} else if (opArg.is<NamedAttribute *>()) {
} else if (isa<NamedAttribute *>(opArg)) {
emitAttributeMatch(tree, opName, opArgIdx, depth);
} else {
PrintFatalError(loc, "unhandled case when matching op");
Expand All @@ -680,7 +680,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
int argIndex,
std::optional<int> variadicSubIndex) {
Operator &op = tree.getDialectOp(opMap);
auto *operand = op.getArg(operandIndex).get<NamedTypeConstraint *>();
auto *operand = cast<NamedTypeConstraint *>(op.getArg(operandIndex));

// If a constraint is specified, we need to generate C++ statements to
// check the constraint.
Expand Down Expand Up @@ -770,7 +770,7 @@ void PatternEmitter::emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
// need to queue the operation only if the matching success. Thus we emit
// the code at the end.
tblgenOps << formatv("tblgen_ops.push_back({0});\n", argName);
} else if (op.getArg(argIndex).is<NamedTypeConstraint *>()) {
} else if (isa<NamedTypeConstraint *>(op.getArg(argIndex))) {
emitOperandMatch(tree, opName, /*operandName=*/formatv("v{0}", i).str(),
operandIndex,
/*operandMatcher=*/eitherArgTree.getArgAsLeaf(i),
Expand Down Expand Up @@ -851,7 +851,7 @@ void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
os << formatv("tblgen_ops.push_back({0});\n", argName);

os.unindent() << "}\n";
} else if (op.getArg(argIndex).is<NamedTypeConstraint *>()) {
} else if (isa<NamedTypeConstraint *>(op.getArg(argIndex))) {
auto operandName = formatv("variadic_operand_range.slice({0}, 1)", i);
emitOperandMatch(tree, opName, operandName.str(), operandIndex,
/*operandMatcher=*/variadicArgTree.getArgAsLeaf(i),
Expand All @@ -867,7 +867,7 @@ void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
int argIndex, int depth) {
Operator &op = tree.getDialectOp(opMap);
auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
auto *namedAttr = cast<NamedAttribute *>(op.getArg(argIndex));
const auto &attr = namedAttr->attr;

os << "{\n";
Expand Down Expand Up @@ -1775,7 +1775,7 @@ void PatternEmitter::supplyValuesForOpArgs(
auto patArgName = node.getArgName(argIndex);
if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
// TODO: Refactor out into map to avoid recomputing these.
if (!opArg.is<NamedAttribute *>())
if (!isa<NamedAttribute *>(opArg))
PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
if (!patArgName.empty())
os << "/*" << patArgName << "=*/";
Expand Down Expand Up @@ -1805,7 +1805,7 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
bool hasOperandSegmentSizes = false;
std::vector<std::string> sizes;
for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
if (resultOp.getArg(argIndex).is<NamedAttribute *>()) {
if (isa<NamedAttribute *>(resultOp.getArg(argIndex))) {
// The argument in the op definition.
auto opArgName = resultOp.getArgName(argIndex);
hasOperandSegmentSizes =
Expand All @@ -1826,7 +1826,7 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
}

const auto *operand =
resultOp.getArg(argIndex).get<NamedTypeConstraint *>();
cast<NamedTypeConstraint *>(resultOp.getArg(argIndex));
std::string varName;
if (operand->isVariadic()) {
++numVariadic;
Expand Down
12 changes: 7 additions & 5 deletions mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@
#include <optional>

using llvm::ArrayRef;
using llvm::cast;
using llvm::formatv;
using llvm::isa;
using llvm::raw_ostream;
using llvm::raw_string_ostream;
using llvm::Record;
Expand Down Expand Up @@ -607,11 +609,11 @@ static void emitArgumentSerialization(const Operator &op, ArrayRef<SMLoc> loc,
bool areOperandsAheadOfAttrs = true;
// Find the first attribute.
const Argument *it = llvm::find_if(op.getArgs(), [](const Argument &arg) {
return arg.is<NamedAttribute *>();
return isa<NamedAttribute *>(arg);
});
// Check whether all following arguments are attributes.
for (const Argument *ie = op.arg_end(); it != ie; ++it) {
if (!it->is<NamedAttribute *>()) {
if (!isa<NamedAttribute *>(*it)) {
areOperandsAheadOfAttrs = false;
break;
}
Expand Down Expand Up @@ -642,7 +644,7 @@ static void emitArgumentSerialization(const Operator &op, ArrayRef<SMLoc> loc,
for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
auto argument = op.getArg(i);
os << tabs << "{\n";
if (argument.is<NamedTypeConstraint *>()) {
if (isa<NamedTypeConstraint *>(argument)) {
os << tabs
<< formatv(" for (auto arg : {0}.getODSOperands({1})) {{\n", opVar,
operandNum);
Expand All @@ -657,7 +659,7 @@ static void emitArgumentSerialization(const Operator &op, ArrayRef<SMLoc> loc,
os << " }\n";
operandNum++;
} else {
NamedAttribute *attr = argument.get<NamedAttribute *>();
NamedAttribute *attr = cast<NamedAttribute *>(argument);
auto newtabs = tabs.str() + " ";
emitAttributeSerialization(
(attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),
Expand Down Expand Up @@ -962,7 +964,7 @@ static void emitOperandDeserialization(const Operator &op, ArrayRef<SMLoc> loc,
os << tabs << "}\n";
} else {
os << tabs << formatv("if ({0} < {1}.size()) {{\n", wordIndex, words);
auto *attr = argument.get<NamedAttribute *>();
auto *attr = cast<NamedAttribute *>(argument);
auto newtabs = tabs.str() + " ";
emitAttributeDeserialization(
(attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),
Expand Down
Loading