diff --git a/mlir/lib/Interfaces/CallInterfaces.cpp b/mlir/lib/Interfaces/CallInterfaces.cpp index 9e5bc159dc890..da0ca0e24630f 100644 --- a/mlir/lib/Interfaces/CallInterfaces.cpp +++ b/mlir/lib/Interfaces/CallInterfaces.cpp @@ -22,7 +22,7 @@ call_interface_impl::resolveCallable(CallOpInterface call, return symbolVal.getDefiningOp(); // If the callable isn't a value, lookup the symbol reference. - auto symbolRef = callable.get(); + auto symbolRef = cast(callable); if (symbolTable) return symbolTable->lookupNearestSymbolFrom(call.getOperation(), symbolRef); return SymbolTable::lookupNearestSymbolFrom(call.getOperation(), symbolRef); diff --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp index 1c661e3beea48..049d7f123cec8 100644 --- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp +++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp @@ -95,7 +95,7 @@ findEntryForIntegerType(IntegerType intType, std::map sortedParams; for (DataLayoutEntryInterface entry : params) { sortedParams.insert(std::make_pair( - entry.getKey().get().getIntOrFloatBitWidth(), entry)); + cast(entry.getKey()).getIntOrFloatBitWidth(), entry)); } auto iter = sortedParams.lower_bound(intType.getWidth()); if (iter == sortedParams.end()) @@ -315,9 +315,9 @@ DataLayoutEntryInterface mlir::detail::filterEntryForIdentifier(DataLayoutEntryListRef entries, StringAttr id) { const auto *it = llvm::find_if(entries, [id](DataLayoutEntryInterface entry) { - if (!entry.getKey().is()) - return false; - return entry.getKey().get() == id; + if (auto attr = dyn_cast(entry.getKey())) + return attr == id; + return false; }); return it == entries.end() ? DataLayoutEntryInterface() : *it; } @@ -691,7 +691,7 @@ void DataLayoutSpecInterface::bucketEntriesByType( if (auto type = llvm::dyn_cast_if_present(entry.getKey())) types[type.getTypeID()].push_back(entry); else - ids[entry.getKey().get()] = entry; + ids[llvm::cast(entry.getKey())] = entry; } } @@ -709,7 +709,7 @@ LogicalResult mlir::detail::verifyDataLayoutSpec(DataLayoutSpecInterface spec, spec.bucketEntriesByType(types, ids); for (const auto &kvp : types) { - auto sampleType = kvp.second.front().getKey().get(); + auto sampleType = cast(kvp.second.front().getKey()); if (isa(sampleType)) { assert(kvp.second.size() == 1 && "expected one data layout entry for non-parametric 'index' type"); @@ -763,7 +763,7 @@ LogicalResult mlir::detail::verifyDataLayoutSpec(DataLayoutSpecInterface spec, } for (const auto &kvp : ids) { - StringAttr identifier = kvp.second.getKey().get(); + StringAttr identifier = cast(kvp.second.getKey()); Dialect *dialect = identifier.getReferencedDialect(); // Ignore attributes that belong to an unknown dialect, the dialect may @@ -816,7 +816,7 @@ mlir::detail::verifyTargetSystemSpec(TargetSystemSpecInterface spec, // targetDeviceSpec does not support Type as a key. return failure(); } else { - deviceDescKeys[entry.getKey().get()] = entry; + deviceDescKeys[cast(entry.getKey())] = entry; } } } diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp index 8cc4206dae6ed..3eb401c449980 100644 --- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp +++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp @@ -53,7 +53,7 @@ mlir::reifyResultShapes(OpBuilder &b, Operation *op, // * Attribute for static dimensions // * Value for dynamic dimensions assert(shapedType.isDynamicDim(dim) == - reifiedReturnShapes[resultIdx][dim].is() && + isa(reifiedReturnShapes[resultIdx][dim]) && "incorrect implementation of ReifyRankedShapedTypeOpInterface"); } ++resultIdx; @@ -70,9 +70,9 @@ bool ShapeAdaptor::hasRank() const { return false; if (auto t = llvm::dyn_cast_if_present(val)) return cast(t).hasRank(); - if (val.is()) + if (isa(val)) return true; - return val.get()->hasRank(); + return cast(val)->hasRank(); } Type ShapeAdaptor::getElementType() const { @@ -80,9 +80,9 @@ Type ShapeAdaptor::getElementType() const { return nullptr; if (auto t = llvm::dyn_cast_if_present(val)) return cast(t).getElementType(); - if (val.is()) + if (isa(val)) return nullptr; - return val.get()->getElementType(); + return cast(val)->getElementType(); } void ShapeAdaptor::getDims(SmallVectorImpl &res) const { @@ -97,7 +97,7 @@ void ShapeAdaptor::getDims(SmallVectorImpl &res) const { for (auto it : dattr.getValues()) res.push_back(it.getSExtValue()); } else { - auto vals = val.get()->getDims(); + auto vals = cast(val)->getDims(); res.assign(vals.begin(), vals.end()); } } @@ -116,7 +116,7 @@ int64_t ShapeAdaptor::getDimSize(int index) const { return cast(attr) .getValues()[index] .getSExtValue(); - auto *stc = val.get(); + auto *stc = cast(val); return stc->getDims()[index]; } @@ -126,7 +126,7 @@ int64_t ShapeAdaptor::getRank() const { return cast(t).getRank(); if (auto attr = llvm::dyn_cast_if_present(val)) return cast(attr).size(); - return val.get()->getDims().size(); + return cast(val)->getDims().size(); } bool ShapeAdaptor::hasStaticShape() const { @@ -142,7 +142,7 @@ bool ShapeAdaptor::hasStaticShape() const { return false; return true; } - auto *stc = val.get(); + auto *stc = cast(val); return llvm::none_of(stc->getDims(), ShapedType::isDynamic); } @@ -162,7 +162,7 @@ int64_t ShapeAdaptor::getNumElements() const { return num; } - auto *stc = val.get(); + auto *stc = cast(val); int64_t num = 1; for (int64_t dim : stc->getDims()) { num *= dim;