Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Arith/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ mlir::inferExpandShapeOutputShape(OpBuilder &b, Location loc,
int64_t inputIndex = it.index();
// Call get<Value>() under the assumption that we're not casting
// dynamism.
Value indexGroupSize = inputShape[inputIndex].get<Value>();
Value indexGroupSize = cast<Value>(inputShape[inputIndex]);
Value indexGroupStaticSizesProduct =
b.create<arith::ConstantIndexOp>(loc, indexGroupStaticSizesProductInt);
Value dynamicDimSize = b.createOrFold<arith::DivUIOp>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
resultDims[llvm::cast<OpResult>(shapedValue).getResultNumber()];
for (const auto &dim : enumerate(tensorType.getShape()))
if (ShapedType::isDynamic(dim.value()))
dynamicSizes.push_back(shape[dim.index()].get<Value>());
dynamicSizes.push_back(cast<Value>(shape[dim.index()]));
}
}

Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/DLTI/DLTI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ combineOneSpec(DataLayoutSpecInterface spec,
continue;
}

Type typeSample = kvp.second.front().getKey().get<Type>();
Type typeSample = cast<Type>(kvp.second.front().getKey());
assert(&typeSample.getDialect() !=
typeSample.getContext()->getLoadedDialect<BuiltinDialect>() &&
"unexpected data layout entry for built-in type");
Expand All @@ -325,7 +325,7 @@ combineOneSpec(DataLayoutSpecInterface spec,
}

for (const auto &kvp : newEntriesForID) {
StringAttr id = kvp.second.getKey().get<StringAttr>();
StringAttr id = cast<StringAttr>(kvp.second.getKey());
Dialect *dialect = id.getReferencedDialect();
if (!entriesForID.count(id)) {
entriesForID[id] = kvp.second;
Expand Down Expand Up @@ -574,7 +574,7 @@ class TargetDataLayoutInterface : public DataLayoutDialectInterface {

LogicalResult verifyEntry(DataLayoutEntryInterface entry,
Location loc) const final {
StringRef entryName = entry.getKey().get<StringAttr>().strref();
StringRef entryName = cast<StringAttr>(entry.getKey()).strref();
if (entryName == DLTIDialect::kDataLayoutEndiannessKey) {
auto value = dyn_cast<StringAttr>(entry.getValue());
if (value &&
Expand Down
26 changes: 13 additions & 13 deletions mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,16 +113,17 @@ static GpuIdBuilderFnType commonLinearIdBuilderFn(int64_t multiplicity = 1) {
// clang-format on

// Return n-D ids for indexing and 1-D size + id for predicate generation.
return IdBuilderResult{
/*mappingIdOps=*/ids,
/*availableMappingSizes=*/
SmallVector<int64_t>{computeProduct(originalBasis)},
// `forallMappingSizes` iterate in the scaled basis, they need to be
// scaled back into the original basis to provide tight
// activeMappingSizes quantities for predication.
/*activeMappingSizes=*/
SmallVector<int64_t>{computeProduct(forallMappingSizes) * multiplicity},
/*activeIdOps=*/SmallVector<Value>{linearId.get<Value>()}};
return IdBuilderResult{
/*mappingIdOps=*/ids,
/*availableMappingSizes=*/
SmallVector<int64_t>{computeProduct(originalBasis)},
// `forallMappingSizes` iterate in the scaled basis, they need to be
// scaled back into the original basis to provide tight
// activeMappingSizes quantities for predication.
/*activeMappingSizes=*/
SmallVector<int64_t>{computeProduct(forallMappingSizes) *
multiplicity},
/*activeIdOps=*/SmallVector<Value>{cast<Value>(linearId)}};
};

return res;
Expand All @@ -144,9 +145,8 @@ static GpuIdBuilderFnType common3DIdBuilderFn(int64_t multiplicity = 1) {
// In the 3-D mapping case, scale the first dimension by the multiplicity.
SmallVector<Value> scaledIds = ids;
AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
scaledIds[0] = affine::makeComposedFoldedAffineApply(
rewriter, loc, d0.floorDiv(multiplicity), {scaledIds[0]})
.get<Value>();
scaledIds[0] = cast<Value>(affine::makeComposedFoldedAffineApply(
rewriter, loc, d0.floorDiv(multiplicity), {scaledIds[0]}));
// In the 3-D mapping case, unscale the first dimension by the multiplicity.
SmallVector<int64_t> forallMappingSizeInOriginalBasis(forallMappingSizes);
forallMappingSizeInOriginalBasis[0] *= multiplicity;
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
return cast<TypedValue<IndexType>>(processInGroupLinearIndex.get<Value>());
return cast<TypedValue<IndexType>>(cast<Value>(processInGroupLinearIndex));
}

} // namespace mlir::mesh
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ static SpecAttr getPointerSpec(DataLayoutEntryListRef params, PtrType type) {
for (DataLayoutEntryInterface entry : params) {
if (!entry.isTypeEntry())
continue;
if (cast<PtrType>(entry.getKey().get<Type>()).getMemorySpace() ==
if (cast<PtrType>(cast<Type>(entry.getKey())).getMemorySpace() ==
type.getMemorySpace()) {
if (auto spec = dyn_cast<SpecAttr>(entry.getValue()))
return spec;
Expand All @@ -55,7 +55,7 @@ bool PtrType::areCompatible(DataLayoutEntryListRef oldLayout,
continue;
uint32_t size = kDefaultPointerSizeBits;
uint32_t abi = kDefaultPointerAlignment;
auto newType = llvm::cast<PtrType>(newEntry.getKey().get<Type>());
auto newType = llvm::cast<PtrType>(llvm::cast<Type>(newEntry.getKey()));
const auto *it =
llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
Expand Down Expand Up @@ -134,7 +134,7 @@ LogicalResult PtrType::verifyEntries(DataLayoutEntryListRef entries,
for (DataLayoutEntryInterface entry : entries) {
if (!entry.isTypeEntry())
continue;
auto key = entry.getKey().get<Type>();
auto key = llvm::cast<Type>(entry.getKey());
if (!llvm::isa<SpecAttr>(entry.getValue())) {
return emitError(loc) << "expected layout attribute for " << key
<< " to be a #ptr.spec attribute";
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ CallInterfaceCallable FunctionCallOp::getCallableForCallee() {
}

void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
(*this)->setAttr(getCalleeAttrName(), callee.get<SymbolRefAttr>());
(*this)->setAttr(getCalleeAttrName(), cast<SymbolRefAttr>(callee));
}

Operation::operand_range FunctionCallOp::getArgOperands() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ static Value unFoldOpIntResult(OpBuilder &builder, Location loc,
OpFoldResult ofr) {
if (std::optional<int64_t> i = getConstantIntValue(ofr); i.has_value())
return constantIndex(builder, loc, *i);
return ofr.get<Value>();
return cast<Value>(ofr);
}

static Value tryFoldTensors(Value t) {
Expand Down
8 changes: 4 additions & 4 deletions mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1475,19 +1475,19 @@ transform::detail::checkApplyToOne(Operation *transformOp,
if (ptr.isNull())
continue;
if (llvm::isa<TransformHandleTypeInterface>(res.getType()) &&
!ptr.is<Operation *>()) {
!isa<Operation *>(ptr)) {
return emitDiag() << "application of " << transformOpName
<< " expected to produce an Operation * for result #"
<< res.getResultNumber();
}
if (llvm::isa<TransformParamTypeInterface>(res.getType()) &&
!ptr.is<Attribute>()) {
!isa<Attribute>(ptr)) {
return emitDiag() << "application of " << transformOpName
<< " expected to produce an Attribute for result #"
<< res.getResultNumber();
}
if (llvm::isa<TransformValueHandleTypeInterface>(res.getType()) &&
!ptr.is<Value>()) {
!isa<Value>(ptr)) {
return emitDiag() << "application of " << transformOpName
<< " expected to produce a Value for result #"
<< res.getResultNumber();
Expand All @@ -1499,7 +1499,7 @@ transform::detail::checkApplyToOne(Operation *transformOp,
template <typename T>
static SmallVector<T> castVector(ArrayRef<transform::MappedValue> range) {
return llvm::to_vector(llvm::map_range(
range, [](transform::MappedValue value) { return value.get<T>(); }));
range, [](transform::MappedValue value) { return cast<T>(value); }));
}

void transform::detail::setApplyToOneResults(
Expand Down
12 changes: 6 additions & 6 deletions mlir/lib/Dialect/Utils/StaticValueUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ void dispatchIndexOpFoldResult(OpFoldResult ofr,
SmallVectorImpl<int64_t> &staticVec) {
auto v = llvm::dyn_cast_if_present<Value>(ofr);
if (!v) {
APInt apInt = cast<IntegerAttr>(ofr.get<Attribute>()).getValue();
APInt apInt = cast<IntegerAttr>(cast<Attribute>(ofr)).getValue();
staticVec.push_back(apInt.getSExtValue());
return;
}
Expand Down Expand Up @@ -212,11 +212,11 @@ decomposeMixedValues(const SmallVectorImpl<OpFoldResult> &mixedValues) {
SmallVector<int64_t> staticValues;
SmallVector<Value> dynamicValues;
for (const auto &it : mixedValues) {
if (it.is<Attribute>()) {
staticValues.push_back(cast<IntegerAttr>(it.get<Attribute>()).getInt());
if (auto attr = dyn_cast<Attribute>(it)) {
staticValues.push_back(cast<IntegerAttr>(attr).getInt());
} else {
staticValues.push_back(ShapedType::kDynamic);
dynamicValues.push_back(it.get<Value>());
dynamicValues.push_back(cast<Value>(it));
}
}
return {staticValues, dynamicValues};
Expand Down Expand Up @@ -294,10 +294,10 @@ LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
bool onlyNonNegative, bool onlyNonZero) {
bool valuesChanged = false;
for (OpFoldResult &ofr : ofrs) {
if (ofr.is<Attribute>())
if (isa<Attribute>(ofr))
continue;
Attribute attr;
if (matchPattern(ofr.get<Value>(), m_Constant(&attr))) {
if (matchPattern(cast<Value>(ofr), m_Constant(&attr))) {
// Note: All ofrs have index type.
if (onlyNonNegative && *getConstantIntValue(attr) < 0)
continue;
Expand Down
Loading