Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 41 additions & 40 deletions mlir/lib/TableGen/AttrOrTypeDef.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@

using namespace mlir;
using namespace mlir::tblgen;
using llvm::DefInit;
using llvm::Init;
using llvm::ListInit;
using llvm::Record;
using llvm::RecordVal;
using llvm::StringInit;

//===----------------------------------------------------------------------===//
// AttrOrTypeBuilder
Expand All @@ -35,14 +41,13 @@ bool AttrOrTypeBuilder::hasInferredContextParameter() const {
// AttrOrTypeDef
//===----------------------------------------------------------------------===//

AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) {
AttrOrTypeDef::AttrOrTypeDef(const Record *def) : def(def) {
// Populate the builders.
auto *builderList =
dyn_cast_or_null<llvm::ListInit>(def->getValueInit("builders"));
const auto *builderList =
dyn_cast_or_null<ListInit>(def->getValueInit("builders"));
if (builderList && !builderList->empty()) {
for (const llvm::Init *init : builderList->getValues()) {
AttrOrTypeBuilder builder(cast<llvm::DefInit>(init)->getDef(),
def->getLoc());
for (const Init *init : builderList->getValues()) {
AttrOrTypeBuilder builder(cast<DefInit>(init)->getDef(), def->getLoc());

// Ensure that all parameters have names.
for (const AttrOrTypeBuilder::Parameter &param :
Expand All @@ -56,16 +61,16 @@ AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) {

// Populate the traits.
if (auto *traitList = def->getValueAsListInit("traits")) {
SmallPtrSet<const llvm::Init *, 32> traitSet;
SmallPtrSet<const Init *, 32> traitSet;
traits.reserve(traitSet.size());
llvm::unique_function<void(const llvm::ListInit *)> processTraitList =
[&](const llvm::ListInit *traitList) {
llvm::unique_function<void(const ListInit *)> processTraitList =
[&](const ListInit *traitList) {
for (auto *traitInit : *traitList) {
if (!traitSet.insert(traitInit).second)
continue;

// If this is an interface, add any bases to the trait list.
auto *traitDef = cast<llvm::DefInit>(traitInit)->getDef();
auto *traitDef = cast<DefInit>(traitInit)->getDef();
if (traitDef->isSubClassOf("Interface")) {
if (auto *bases = traitDef->getValueAsListInit("baseInterfaces"))
processTraitList(bases);
Expand Down Expand Up @@ -111,7 +116,7 @@ AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) {
}

Dialect AttrOrTypeDef::getDialect() const {
auto *dialect = dyn_cast<llvm::DefInit>(def->getValue("dialect")->getValue());
const auto *dialect = dyn_cast<DefInit>(def->getValue("dialect")->getValue());
return Dialect(dialect ? dialect->getDef() : nullptr);
}

Expand All @@ -126,17 +131,17 @@ StringRef AttrOrTypeDef::getCppBaseClassName() const {
}

bool AttrOrTypeDef::hasDescription() const {
const llvm::RecordVal *desc = def->getValue("description");
return desc && isa<llvm::StringInit>(desc->getValue());
const RecordVal *desc = def->getValue("description");
return desc && isa<StringInit>(desc->getValue());
}

StringRef AttrOrTypeDef::getDescription() const {
return def->getValueAsString("description");
}

bool AttrOrTypeDef::hasSummary() const {
const llvm::RecordVal *summary = def->getValue("summary");
return summary && isa<llvm::StringInit>(summary->getValue());
const RecordVal *summary = def->getValue("summary");
return summary && isa<StringInit>(summary->getValue());
}

StringRef AttrOrTypeDef::getSummary() const {
Expand Down Expand Up @@ -249,9 +254,9 @@ StringRef TypeDef::getTypeName() const {
template <typename InitT>
auto AttrOrTypeParameter::getDefValue(StringRef name) const {
std::optional<decltype(std::declval<InitT>().getValue())> result;
if (auto *param = dyn_cast<llvm::DefInit>(getDef()))
if (auto *init = param->getDef()->getValue(name))
if (auto *value = dyn_cast_or_null<InitT>(init->getValue()))
if (const auto *param = dyn_cast<DefInit>(getDef()))
if (const auto *init = param->getDef()->getValue(name))
if (const auto *value = dyn_cast_or_null<InitT>(init->getValue()))
result = value->getValue();
return result;
}
Expand All @@ -270,20 +275,20 @@ std::string AttrOrTypeParameter::getAccessorName() const {
}

std::optional<StringRef> AttrOrTypeParameter::getAllocator() const {
return getDefValue<llvm::StringInit>("allocator");
return getDefValue<StringInit>("allocator");
}

StringRef AttrOrTypeParameter::getComparator() const {
return getDefValue<llvm::StringInit>("comparator").value_or("$_lhs == $_rhs");
return getDefValue<StringInit>("comparator").value_or("$_lhs == $_rhs");
}

StringRef AttrOrTypeParameter::getCppType() const {
if (auto *stringType = dyn_cast<llvm::StringInit>(getDef()))
if (auto *stringType = dyn_cast<StringInit>(getDef()))
return stringType->getValue();
auto cppType = getDefValue<llvm::StringInit>("cppType");
auto cppType = getDefValue<StringInit>("cppType");
if (cppType)
return *cppType;
if (auto *init = dyn_cast<llvm::DefInit>(getDef()))
if (const auto *init = dyn_cast<DefInit>(getDef()))
llvm::PrintFatalError(
init->getDef()->getLoc(),
Twine("Missing `cppType` field in Attribute/Type parameter: ") +
Expand All @@ -295,52 +300,48 @@ StringRef AttrOrTypeParameter::getCppType() const {
}

StringRef AttrOrTypeParameter::getCppAccessorType() const {
return getDefValue<llvm::StringInit>("cppAccessorType")
.value_or(getCppType());
return getDefValue<StringInit>("cppAccessorType").value_or(getCppType());
}

StringRef AttrOrTypeParameter::getCppStorageType() const {
return getDefValue<llvm::StringInit>("cppStorageType").value_or(getCppType());
return getDefValue<StringInit>("cppStorageType").value_or(getCppType());
}

StringRef AttrOrTypeParameter::getConvertFromStorage() const {
return getDefValue<llvm::StringInit>("convertFromStorage").value_or("$_self");
return getDefValue<StringInit>("convertFromStorage").value_or("$_self");
}

std::optional<StringRef> AttrOrTypeParameter::getParser() const {
return getDefValue<llvm::StringInit>("parser");
return getDefValue<StringInit>("parser");
}

std::optional<StringRef> AttrOrTypeParameter::getPrinter() const {
return getDefValue<llvm::StringInit>("printer");
return getDefValue<StringInit>("printer");
}

std::optional<StringRef> AttrOrTypeParameter::getSummary() const {
return getDefValue<llvm::StringInit>("summary");
return getDefValue<StringInit>("summary");
}

StringRef AttrOrTypeParameter::getSyntax() const {
if (auto *stringType = dyn_cast<llvm::StringInit>(getDef()))
if (auto *stringType = dyn_cast<StringInit>(getDef()))
return stringType->getValue();
return getDefValue<llvm::StringInit>("syntax").value_or(getCppType());
return getDefValue<StringInit>("syntax").value_or(getCppType());
}

bool AttrOrTypeParameter::isOptional() const {
return getDefaultValue().has_value();
}

std::optional<StringRef> AttrOrTypeParameter::getDefaultValue() const {
std::optional<StringRef> result =
getDefValue<llvm::StringInit>("defaultValue");
std::optional<StringRef> result = getDefValue<StringInit>("defaultValue");
return result && !result->empty() ? result : std::nullopt;
}

const llvm::Init *AttrOrTypeParameter::getDef() const {
return def->getArg(index);
}
const Init *AttrOrTypeParameter::getDef() const { return def->getArg(index); }

std::optional<Constraint> AttrOrTypeParameter::getConstraint() const {
if (auto *param = dyn_cast<llvm::DefInit>(getDef()))
if (const auto *param = dyn_cast<DefInit>(getDef()))
if (param->getDef()->isSubClassOf("Constraint"))
return Constraint(param->getDef());
return std::nullopt;
Expand All @@ -351,8 +352,8 @@ std::optional<Constraint> AttrOrTypeParameter::getConstraint() const {
//===----------------------------------------------------------------------===//

bool AttributeSelfTypeParameter::classof(const AttrOrTypeParameter *param) {
const llvm::Init *paramDef = param->getDef();
if (auto *paramDefInit = dyn_cast<llvm::DefInit>(paramDef))
const Init *paramDef = param->getDef();
if (const auto *paramDefInit = dyn_cast<DefInit>(paramDef))
return paramDefInit->getDef()->isSubClassOf("AttributeSelfTypeParameter");
return false;
}
25 changes: 12 additions & 13 deletions mlir/lib/TableGen/Attribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ StringRef Attribute::getReturnType() const {
// Return the type constraint corresponding to the type of this attribute, or
// std::nullopt if this is not a TypedAttr.
std::optional<Type> Attribute::getValueType() const {
if (auto *defInit = dyn_cast<llvm::DefInit>(def->getValueInit("valueType")))
if (const auto *defInit = dyn_cast<DefInit>(def->getValueInit("valueType")))
return Type(defInit->getDef());
return std::nullopt;
}
Expand All @@ -92,8 +92,7 @@ StringRef Attribute::getConstBuilderTemplate() const {
}

Attribute Attribute::getBaseAttr() const {
if (const auto *defInit =
llvm::dyn_cast<llvm::DefInit>(def->getValueInit("baseAttr"))) {
if (const auto *defInit = dyn_cast<DefInit>(def->getValueInit("baseAttr"))) {
return Attribute(defInit).getBaseAttr();
}
return *this;
Expand Down Expand Up @@ -132,7 +131,7 @@ Dialect Attribute::getDialect() const {
return Dialect(nullptr);
}

const llvm::Record &Attribute::getDef() const { return *def; }
const Record &Attribute::getDef() const { return *def; }

ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) {
assert(def->isSubClassOf("ConstantAttr") &&
Expand All @@ -147,12 +146,12 @@ StringRef ConstantAttr::getConstantValue() const {
return def->getValueAsString("value");
}

EnumAttrCase::EnumAttrCase(const llvm::Record *record) : Attribute(record) {
EnumAttrCase::EnumAttrCase(const Record *record) : Attribute(record) {
assert(isSubClassOf("EnumAttrCaseInfo") &&
"must be subclass of TableGen 'EnumAttrInfo' class");
}

EnumAttrCase::EnumAttrCase(const llvm::DefInit *init)
EnumAttrCase::EnumAttrCase(const DefInit *init)
: EnumAttrCase(init->getDef()) {}

StringRef EnumAttrCase::getSymbol() const {
Expand All @@ -163,16 +162,16 @@ StringRef EnumAttrCase::getStr() const { return def->getValueAsString("str"); }

int64_t EnumAttrCase::getValue() const { return def->getValueAsInt("value"); }

const llvm::Record &EnumAttrCase::getDef() const { return *def; }
const Record &EnumAttrCase::getDef() const { return *def; }

EnumAttr::EnumAttr(const llvm::Record *record) : Attribute(record) {
EnumAttr::EnumAttr(const Record *record) : Attribute(record) {
assert(isSubClassOf("EnumAttrInfo") &&
"must be subclass of TableGen 'EnumAttr' class");
}

EnumAttr::EnumAttr(const llvm::Record &record) : Attribute(&record) {}
EnumAttr::EnumAttr(const Record &record) : Attribute(&record) {}

EnumAttr::EnumAttr(const llvm::DefInit *init) : EnumAttr(init->getDef()) {}
EnumAttr::EnumAttr(const DefInit *init) : EnumAttr(init->getDef()) {}

bool EnumAttr::classof(const Attribute *attr) {
return attr->isSubClassOf("EnumAttrInfo");
Expand Down Expand Up @@ -218,8 +217,8 @@ std::vector<EnumAttrCase> EnumAttr::getAllCases() const {
std::vector<EnumAttrCase> cases;
cases.reserve(inits->size());

for (const llvm::Init *init : *inits) {
cases.emplace_back(cast<llvm::DefInit>(init));
for (const Init *init : *inits) {
cases.emplace_back(cast<DefInit>(init));
}

return cases;
Expand All @@ -229,7 +228,7 @@ bool EnumAttr::genSpecializedAttr() const {
return def->getValueAsBit("genSpecializedAttr");
}

const llvm::Record *EnumAttr::getBaseAttrClass() const {
const Record *EnumAttr::getBaseAttrClass() const {
return def->getValueAsDef("baseAttrClass");
}

Expand Down
24 changes: 14 additions & 10 deletions mlir/lib/TableGen/Builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,21 @@

using namespace mlir;
using namespace mlir::tblgen;
using llvm::DagInit;
using llvm::DefInit;
using llvm::Init;
using llvm::Record;
using llvm::StringInit;

//===----------------------------------------------------------------------===//
// Builder::Parameter
//===----------------------------------------------------------------------===//

/// Return a string containing the C++ type of this parameter.
StringRef Builder::Parameter::getCppType() const {
if (const auto *stringInit = dyn_cast<llvm::StringInit>(def))
if (const auto *stringInit = dyn_cast<StringInit>(def))
return stringInit->getValue();
const llvm::Record *record = cast<llvm::DefInit>(def)->getDef();
const Record *record = cast<DefInit>(def)->getDef();
// Inlining the first part of `Record::getValueAsString` to give better
// error messages.
const llvm::RecordVal *type = record->getValue("type");
Expand All @@ -35,9 +40,9 @@ StringRef Builder::Parameter::getCppType() const {
/// Return an optional string containing the default value to use for this
/// parameter.
std::optional<StringRef> Builder::Parameter::getDefaultValue() const {
if (isa<llvm::StringInit>(def))
if (isa<StringInit>(def))
return std::nullopt;
const llvm::Record *record = cast<llvm::DefInit>(def)->getDef();
const Record *record = cast<DefInit>(def)->getDef();
std::optional<StringRef> value =
record->getValueAsOptionalString("defaultValue");
return value && !value->empty() ? value : std::nullopt;
Expand All @@ -47,18 +52,17 @@ std::optional<StringRef> Builder::Parameter::getDefaultValue() const {
// Builder
//===----------------------------------------------------------------------===//

Builder::Builder(const llvm::Record *record, ArrayRef<SMLoc> loc)
: def(record) {
Builder::Builder(const Record *record, ArrayRef<SMLoc> loc) : def(record) {
// Initialize the parameters of the builder.
const llvm::DagInit *dag = def->getValueAsDag("dagParams");
auto *defInit = dyn_cast<llvm::DefInit>(dag->getOperator());
const DagInit *dag = def->getValueAsDag("dagParams");
auto *defInit = dyn_cast<DefInit>(dag->getOperator());
if (!defInit || defInit->getDef()->getName() != "ins")
PrintFatalError(def->getLoc(), "expected 'ins' in builders");

bool seenDefaultValue = false;
for (unsigned i = 0, e = dag->getNumArgs(); i < e; ++i) {
const llvm::StringInit *paramName = dag->getArgName(i);
const llvm::Init *paramValue = dag->getArg(i);
const StringInit *paramName = dag->getArgName(i);
const Init *paramValue = dag->getArg(i);
Parameter param(paramName ? paramName->getValue()
: std::optional<StringRef>(),
paramValue);
Expand Down
Loading
Loading