Skip to content

Commit 76a9730

Browse files
Sair Teamcopybara-github
authored andcommitted
Use llvm::cast/dyn_cast/isa since alternatives are deprecated in llvm/llvm-project#135556
PiperOrigin-RevId: 748055275
1 parent df948cf commit 76a9730

14 files changed

+185
-161
lines changed

mapped_domain.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "mapped_domain.h"
1616

17+
#include "llvm/Support/Casting.h"
1718
#include "loop_nest.h"
1819

1920
namespace sair {
@@ -31,7 +32,7 @@ DomainShapeAttr MappedDomain::DomainShape() const {
3132
llvm::SmallVector<DomainShapeDim> shape_dims;
3233
shape_dims.reserve(domain_.size());
3334
for (const ValueAccessInstance &access : domain_) {
34-
auto type = access.value.GetType().cast<DimensionType>();
35+
auto type = llvm::cast<DimensionType>(access.value.GetType());
3536
shape_dims.emplace_back(type, access.mapping);
3637
}
3738
return DomainShapeAttr::get(context(), shape_dims);
@@ -64,7 +65,7 @@ mlir::LogicalResult MappedDomain::ResolveUnification(
6465
constraint = MappingDimExpr::get(domain_.size(), context());
6566
assert(dimension.mapping.IsSurjective());
6667
domain_.push_back(dimension);
67-
} else if (auto dim_expr = constraint.dyn_cast<MappingDimExpr>()) {
68+
} else if (auto dim_expr = llvm::dyn_cast<MappingDimExpr>(constraint)) {
6869
// If the dimension must be unified with an existing dimension, ensure that
6970
// they match.
7071
const ValueAccessInstance &old_dimension = domain_[dim_expr.dimension()];

sair_attributes.cc

Lines changed: 54 additions & 52 deletions
Large diffs are not rendered by default.

sair_dialect.cc

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "llvm/ADT/SmallBitVector.h"
2323
#include "llvm/ADT/SmallVector.h"
2424
#include "llvm/ADT/TypeSwitch.h"
25+
#include "llvm/Support/Casting.h"
2526
#include "llvm/Support/MathExtras.h"
2627
#include "llvm/Support/SMLoc.h"
2728
#include "llvm/Support/raw_ostream.h"
@@ -289,15 +290,15 @@ void PrintMappingExpr(MappingExpr expr, llvm::raw_ostream &os) {
289290
os << MappingNoneExpr::kAttrName;
290291
} else if (expr.isa<MappingUnknownExpr>()) {
291292
os << MappingUnknownExpr::kAttrName;
292-
} else if (auto dim_expr = expr.dyn_cast<MappingDimExpr>()) {
293+
} else if (auto dim_expr = llvm::dyn_cast<MappingDimExpr>(expr)) {
293294
os << "d" << dim_expr.dimension();
294-
} else if (auto stripe_expr = expr.dyn_cast<MappingStripeExpr>()) {
295+
} else if (auto stripe_expr = llvm::dyn_cast<MappingStripeExpr>(expr)) {
295296
os << MappingStripeExpr::kAttrName << "(";
296297
PrintMappingExpr(stripe_expr.operand(), os);
297298
os << ", [";
298299
llvm::interleaveComma(stripe_expr.factors(), os);
299300
os << "])";
300-
} else if (auto unstripe_expr = expr.dyn_cast<MappingUnStripeExpr>()) {
301+
} else if (auto unstripe_expr = llvm::dyn_cast<MappingUnStripeExpr>(expr)) {
301302
os << MappingUnStripeExpr::kAttrName << "(";
302303
for (auto operand : unstripe_expr.operands()) {
303304
PrintMappingExpr(operand, os);
@@ -322,7 +323,7 @@ void Print(StaticRangeType type, mlir::DialectAsmPrinter &os) {
322323

323324
void PrintDomainShapeDim(const DomainShapeDim &dimension,
324325
mlir::DialectAsmPrinter &os) {
325-
if (auto static_range = dimension.type().dyn_cast<StaticRangeType>()) {
326+
if (auto static_range = llvm::dyn_cast<StaticRangeType>(dimension.type())) {
326327
Print(static_range, os);
327328
} else if (dimension.type().isa<DynRangeType>()) {
328329
os << DynRangeType::Name();
@@ -399,13 +400,13 @@ void PrintMapping(MappingAttr mapping, llvm::raw_ostream &os) {
399400
// Prints the Sair type using MLIR printing facilities.
400401
void SairDialect::printType(mlir::Type type,
401402
mlir::DialectAsmPrinter &os) const {
402-
if (auto range_type = type.dyn_cast<DynRangeType>()) {
403+
if (auto range_type = llvm::dyn_cast<DynRangeType>(type)) {
403404
return Print(range_type, os);
404-
} else if (auto static_range_type = type.dyn_cast<StaticRangeType>()) {
405+
} else if (auto static_range_type = llvm::dyn_cast<StaticRangeType>(type)) {
405406
return Print(static_range_type, os);
406407
}
407408

408-
Print(type.cast<ValueType>(), &os);
409+
Print(llvm::cast<ValueType>(type), &os);
409410
}
410411

411412
// Prints the Sair attribute using MLIR printing facilities.

sair_op_interfaces.cc

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "llvm/ADT/SmallBitVector.h"
2525
#include "llvm/ADT/SmallVector.h"
2626
#include "llvm/ADT/StringRef.h"
27+
#include "llvm/Support/Casting.h"
2728
#include "mlir/IR/Attributes.h"
2829
#include "mlir/IR/BuiltinAttributes.h"
2930
#include "mlir/IR/BuiltinTypes.h"
@@ -42,7 +43,7 @@
4243
namespace sair {
4344

4445
mlir::Type ValueAccess::ElementType() const {
45-
return value.getType().cast<ValueType>().ElementType();
46+
return llvm::cast<ValueType>(value.getType()).ElementType();
4647
}
4748

4849
bool operator==(const ValueAccess &lhs, const ValueAccess &rhs) {
@@ -60,10 +61,8 @@ ValueOperand::ValueOperand(mlir::OpOperand *operand) : operand_(operand) {
6061
}
6162

6263
MappingAttr ValueOperand::Mapping() const {
63-
return cast<SairOp>(operand_->getOwner())
64-
.getMappingArray()
65-
.getValue()[index_]
66-
.template cast<::sair::MappingAttr>();
64+
return llvm::cast<MappingAttr>(
65+
cast<SairOp>(operand_->getOwner()).getMappingArray().getValue()[index_]);
6766
}
6867

6968
void ValueOperand::SubstituteValue(ValueAccess new_value) {
@@ -138,7 +137,7 @@ static mlir::LogicalResult VerifyDecisionsWellFormed(mlir::Location loc,
138137
}
139138
loop_names.reserve(loop_nest.size());
140139
for (mlir::Attribute attr : loop_nest.getValue()) {
141-
loop_names.insert(attr.cast<LoopAttr>().name());
140+
loop_names.insert(llvm::cast<LoopAttr>(attr).name());
142141
}
143142
}
144143

@@ -170,7 +169,7 @@ static mlir::LogicalResult VerifyInstancesAttr(SairOp op) {
170169
// Ignore incorrect types here, they will be caught by the op verifier.
171170
mlir::Attribute decision_attr =
172171
op.getInstances()->getValue()[decision_index];
173-
DecisionsAttr decisions = decision_attr.dyn_cast<DecisionsAttr>();
172+
DecisionsAttr decisions = llvm::dyn_cast<DecisionsAttr>(decision_attr);
174173
if (!decisions) continue;
175174

176175
if (decisions.copy_of() != nullptr) {
@@ -194,7 +193,7 @@ static mlir::LogicalResult VerifyInstancesAttr(SairOp op) {
194193
for (auto en : llvm::enumerate(decisions.operands().getValue())) {
195194
mlir::Attribute operand_instance = en.value();
196195
if (operand_instance.isa<mlir::UnitAttr>()) continue;
197-
if (auto copy = operand_instance.dyn_cast<CopyAttr>()) {
196+
if (auto copy = llvm::dyn_cast<CopyAttr>(operand_instance)) {
198197
Value operand = op->getOperand(en.index());
199198
auto defining_op = operand.getDefiningOp<ValueProducerOp>();
200199
if (!defining_op) {
@@ -204,7 +203,9 @@ static mlir::LogicalResult VerifyInstancesAttr(SairOp op) {
204203
"cannot have copies";
205204
}
206205
if (copy.getValue() >=
207-
defining_op.GetCopies(operand.cast<OpResult>().getResultNumber())
206+
defining_op
207+
.GetCopies(
208+
llvm::cast<mlir::OpResult>(operand).getResultNumber())
208209
.size()) {
209210
return op->emitError()
210211
<< "operand #" << en.index() << " of instance #"
@@ -215,7 +216,7 @@ static mlir::LogicalResult VerifyInstancesAttr(SairOp op) {
215216

216217
// Ignore incorrect attribute types here, they will be caught by the op
217218
// verifier later.
218-
auto instance = operand_instance.dyn_cast<InstanceAttr>();
219+
auto instance = llvm::dyn_cast<InstanceAttr>(operand_instance);
219220
if (!instance) continue;
220221

221222
// There may be no defining op for operands of some non-compute ops.
@@ -236,7 +237,7 @@ static mlir::LogicalResult VerifyInstancesAttr(SairOp op) {
236237
if (isa<ComputeOp>(op.getOperation())) return mlir::success();
237238

238239
for (mlir::Attribute attr : op.getInstances()->getValue()) {
239-
DecisionsAttr decisions = attr.dyn_cast<DecisionsAttr>();
240+
DecisionsAttr decisions = llvm::dyn_cast<DecisionsAttr>(attr);
240241
if (!decisions) continue;
241242
if (decisions.sequence() != nullptr || decisions.loop_nest() != nullptr ||
242243
decisions.storage() != nullptr || decisions.expansion() != nullptr) {
@@ -290,7 +291,7 @@ mlir::LogicalResult VerifySairOp(Operation *op) {
290291
<< "missing " << SairOp::kMappingAttrName << " attribute";
291292
}
292293
for (mlir::Attribute attr : sair_op.getMappingArray()) {
293-
MappingAttr mapping = attr.cast<MappingAttr>();
294+
MappingAttr mapping = llvm::cast<MappingAttr>(attr);
294295
if (mapping.HasNoneExprs() || mapping.HasUnknownExprs()) {
295296
return mlir::emitError(op->getLoc())
296297
<< "all dimensions of the accessed domain must be mapped";
@@ -319,8 +320,7 @@ mlir::LogicalResult VerifySairOp(Operation *op) {
319320
}
320321

321322
auto expected_shape = sair_op.getShape().AccessedShape(v.Mapping());
322-
auto given_shape =
323-
v.value().getType().template cast<::sair::ValueType>().Shape();
323+
auto given_shape = llvm::cast<ValueType>(v.value().getType()).Shape();
324324
if (expected_shape != given_shape) {
325325
return op->emitError() << "invalid operand shape: expected "
326326
<< expected_shape << ", got " << given_shape;
@@ -337,7 +337,7 @@ mlir::LogicalResult VerifySairOp(Operation *op) {
337337
::sair::DomainShapeAttr results_shape =
338338
sair_op.getShape().Prefix(sair_op.results_rank());
339339
for (mlir::Value result : op->getResults()) {
340-
auto type = result.getType().cast<ShapedType>();
340+
auto type = llvm::cast<ShapedType>(result.getType());
341341
if (type.Shape() != results_shape) {
342342
return op->emitError() << "unexpected shape: expected " << results_shape
343343
<< ", got " << type.Shape();
@@ -361,7 +361,7 @@ mlir::LogicalResult VerifyValueProducerOp(mlir::Operation *operation) {
361361
DomainShapeAttr shape = sair_op.getShape().Prefix(sair_op.results_rank());
362362
for (int i = 0, e = op->getNumResults(); i < e; ++i) {
363363
for (mlir::Attribute attr : op.GetCopies(i)) {
364-
auto decisions = attr.cast<DecisionsAttr>();
364+
auto decisions = llvm::cast<DecisionsAttr>(attr);
365365
if (mlir::failed(VerifyDecisionsWellFormed(
366366
op.getLoc(), shape, {op->getResultTypes()[i]}, decisions))) {
367367
return mlir::failure();
@@ -373,12 +373,12 @@ mlir::LogicalResult VerifyValueProducerOp(mlir::Operation *operation) {
373373
decisions.copy_of().isa<mlir::UnitAttr>()) {
374374
continue;
375375
}
376-
if (auto copy = decisions.copy_of().dyn_cast<CopyAttr>()) {
376+
if (auto copy = llvm::dyn_cast<CopyAttr>(decisions.copy_of())) {
377377
if (copy.getValue() >= op.GetCopies(i).size()) {
378378
return op.emitError() << "'copy_of' refers to non-existent copy";
379379
}
380380
}
381-
if (auto instance = decisions.copy_of().dyn_cast<InstanceAttr>()) {
381+
if (auto instance = llvm::dyn_cast<InstanceAttr>(decisions.copy_of())) {
382382
std::optional<mlir::ArrayAttr> instances = sair_op.getInstances();
383383
if (instances && instance.getValue() >= instances->size()) {
384384
return op.emitError() << "'copy_of' refers to non-existent instance";
@@ -504,7 +504,8 @@ ResultInstance OpInstance::domain(int i) const {
504504
}
505505
mlir::Value dim = op.getDomain()[i];
506506
OpInstance dim_op(llvm::cast<SairOp>(dim.getDefiningOp()));
507-
return ResultInstance(dim_op, dim.cast<OpResult>().getResultNumber());
507+
return ResultInstance(dim_op,
508+
llvm::cast<mlir::OpResult>(dim).getResultNumber());
508509
}
509510

510511
ValueRange OpInstance::GetDomainValues() const {
@@ -569,7 +570,7 @@ DecisionsAttr ComputeOpInstance::GetDecisions() const {
569570
}
570571
llvm::ArrayRef<mlir::Attribute> copies =
571572
GetValueProducer().GetCopies(result());
572-
return copies[index()].cast<DecisionsAttr>();
573+
return llvm::cast<DecisionsAttr>(copies[index()]);
573574
}
574575

575576
void ComputeOpInstance::SetDecisions(DecisionsAttr decisions) {
@@ -599,7 +600,7 @@ BufferAttr ComputeOpInstance::Storage(int result) const {
599600
decisions.storage()[result].isa<mlir::UnitAttr>()) {
600601
return nullptr;
601602
}
602-
return decisions.storage()[result].cast<BufferAttr>();
603+
return llvm::cast<BufferAttr>(decisions.storage()[result]);
603604
}
604605

605606
void ComputeOpInstance::SetStorage(int result, BufferAttr storage) {
@@ -624,13 +625,13 @@ ComputeOp ComputeOpInstance::GetComputeOp() const {
624625
}
625626

626627
ResultInstance ResultInstance::Unique(mlir::Value value) {
627-
OpResult result = value.cast<OpResult>();
628+
OpResult result = llvm::cast<mlir::OpResult>(value);
628629
OpInstance producer = OpInstance::Unique(cast<SairOp>(result.getOwner()));
629630
return ResultInstance(producer, result.getResultNumber());
630631
}
631632

632633
ShapedType ResultInstance::GetType() const {
633-
return GetValue().getType().cast<ShapedType>();
634+
return llvm::cast<ShapedType>(GetValue().getType());
634635
}
635636

636637
mlir::Value ResultInstance::GetValue() const {
@@ -698,7 +699,7 @@ std::optional<ResultInstance> OperandInstance::GetValue() const {
698699
value = owner.ValueOperands()[operand_position_].value();
699700
}
700701

701-
auto result = value.cast<OpResult>();
702+
auto result = llvm::cast<mlir::OpResult>(value);
702703
mlir::Operation *defining_op = result.getOwner();
703704

704705
// TODO(ulysse): allow specifying the instance use in operands. For now, we

0 commit comments

Comments
 (0)