-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[mlir] Migrate away from PointerUnion::dyn_cast (NFC) #123693
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir] Migrate away from PointerUnion::dyn_cast (NFC) #123693
Conversation
Note that PointerUnion::dyn_cast has been soft deprecated in PointerUnion.h: // FIXME: Replace the uses of is(), get() and dyn_cast() with // isa<T>, cast<T> and the llvm::dyn_cast<T>
|
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Kazu Hirata (kazutakahirata) ChangesNote that PointerUnion::dyn_cast has been soft deprecated in // FIXME: Replace the uses of is(), get() and dyn_cast() with Full diff: https://github.com/llvm/llvm-project/pull/123693.diff 9 Files Affected:
diff --git a/mlir/examples/transform-opt/mlir-transform-opt.cpp b/mlir/examples/transform-opt/mlir-transform-opt.cpp
index 10e16096211ad7..73cb0319bfd087 100644
--- a/mlir/examples/transform-opt/mlir-transform-opt.cpp
+++ b/mlir/examples/transform-opt/mlir-transform-opt.cpp
@@ -120,7 +120,8 @@ class DiagnosticHandlerWrapper {
/// Verifies the captured "expected-*" diagnostics if required.
llvm::LogicalResult verify() const {
if (auto *ptr =
- handler.dyn_cast<mlir::SourceMgrDiagnosticVerifierHandler *>()) {
+ dyn_cast_if_present<mlir::SourceMgrDiagnosticVerifierHandler *>(
+ handler)) {
return ptr->verify();
}
return mlir::success();
@@ -128,7 +129,8 @@ class DiagnosticHandlerWrapper {
/// Destructs the object of the same type as allocated.
~DiagnosticHandlerWrapper() {
- if (auto *ptr = handler.dyn_cast<mlir::SourceMgrDiagnosticHandler *>()) {
+ if (auto *ptr =
+ dyn_cast_if_present<mlir::SourceMgrDiagnosticHandler *>(handler)) {
delete ptr;
} else {
delete cast<mlir::SourceMgrDiagnosticVerifierHandler *>(handler);
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index d688d8e2ab6588..7bd6201d4608cf 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -119,7 +119,7 @@ static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
/// an LLVM constant op.
static Value getAsLLVMValue(OpBuilder &builder, Location loc,
OpFoldResult foldResult) {
- if (auto attr = foldResult.dyn_cast<Attribute>()) {
+ if (auto attr = dyn_cast_if_present<Attribute>(foldResult)) {
auto intAttr = cast<IntegerAttr>(attr);
return builder.create<LLVM::ConstantOp>(loc, intAttr).getResult();
}
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
index 232c9c96dd09fc..4868ab8e49178f 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
@@ -210,7 +210,7 @@ LogicalResult transform::applyTransformNamedSequence(
<< "expected one payload to be bound to the first argument, got "
<< bindings.at(0).size();
}
- auto *payloadRoot = bindings.at(0).front().dyn_cast<Operation *>();
+ auto *payloadRoot = dyn_cast_if_present<Operation *>(bindings.at(0).front());
if (!payloadRoot) {
return transformRoot->emitError() << "expected the object bound to the "
"first argument to be an operation";
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 696d1e0f9b1e68..c04ddd1922127c 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -340,7 +340,7 @@ SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc,
SmallVector<Value> values;
llvm::transform(foldResults, std::back_inserter(values),
[&](OpFoldResult foldResult) {
- if (auto attr = foldResult.dyn_cast<Attribute>())
+ if (auto attr = dyn_cast_if_present<Attribute>(foldResult))
return builder
.create<arith::ConstantIndexOp>(
loc, cast<IntegerAttr>(attr).getInt())
@@ -2880,7 +2880,7 @@ LogicalResult InsertOp::verify() {
return emitOpError(
"expected position attribute rank to match the dest vector rank");
for (auto [idx, pos] : llvm::enumerate(position)) {
- if (auto attr = pos.dyn_cast<Attribute>()) {
+ if (auto attr = dyn_cast_if_present<Attribute>(pos)) {
int64_t constIdx = cast<IntegerAttr>(attr).getInt();
if (constIdx < 0 || constIdx >= destVectorType.getDimSize(idx)) {
return emitOpError("expected position attribute #")
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 95064083b21d44..2481b3e44e7a2e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -242,9 +242,10 @@ static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
int64_t numElementsToExtract) {
for (int i = 0; i < numElementsToExtract; ++i) {
Value extractLoc =
- (i == 0) ? offset.dyn_cast<Value>()
+ (i == 0) ? dyn_cast_if_present<Value>(offset)
: rewriter.create<arith::AddIOp>(
- loc, rewriter.getIndexType(), offset.dyn_cast<Value>(),
+ loc, rewriter.getIndexType(),
+ dyn_cast_if_present<Value>(offset),
rewriter.create<arith::ConstantIndexOp>(loc, i));
auto extractOp =
rewriter.create<vector::ExtractOp>(loc, source, extractLoc);
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 8e8a433f331df5..f9cbaa9d26740b 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -748,7 +748,7 @@ AffineMap mlir::foldAttributesIntoMap(Builder &b, AffineMap map,
SmallVector<AffineExpr> dimReplacements, symReplacements;
int64_t numDims = 0;
for (int64_t i = 0; i < map.getNumDims(); ++i) {
- if (auto attr = operands[i].dyn_cast<Attribute>()) {
+ if (auto attr = dyn_cast_if_present<Attribute>(operands[i])) {
dimReplacements.push_back(
b.getAffineConstantExpr(cast<IntegerAttr>(attr).getInt()));
} else {
@@ -758,7 +758,8 @@ AffineMap mlir::foldAttributesIntoMap(Builder &b, AffineMap map,
}
int64_t numSymbols = 0;
for (int64_t i = 0; i < map.getNumSymbols(); ++i) {
- if (auto attr = operands[i + map.getNumDims()].dyn_cast<Attribute>()) {
+ if (auto attr =
+ dyn_cast_if_present<Attribute>(operands[i + map.getNumDims()])) {
symReplacements.push_back(
b.getAffineConstantExpr(cast<IntegerAttr>(attr).getInt()));
} else {
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 969c560c99ab7c..e7620d93697afc 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -515,7 +515,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
bool materializationSucceeded = true;
for (auto [ofr, resultType] :
llvm::zip_equal(foldResults, op->getResultTypes())) {
- if (auto value = ofr.dyn_cast<Value>()) {
+ if (auto value = dyn_cast_if_present<Value>(ofr)) {
assert(value.getType() == resultType &&
"folder produced value of incorrect type");
replacements.push_back(value);
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index a970cbc5cacebe..39735cd5646a14 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1781,7 +1781,7 @@ void OpEmitter::genPropertiesSupportForBytecode(
writePropertiesMethod << tgfmt(writeBytecodeSegmentSizeLegacy, &fmtCtxt);
}
if (const auto *namedProperty =
- attrOrProp.dyn_cast<const NamedProperty *>()) {
+ dyn_cast_if_present<const NamedProperty *>(attrOrProp)) {
StringRef name = namedProperty->name;
readPropertiesMethod << formatv(
R"(
@@ -1807,7 +1807,8 @@ void OpEmitter::genPropertiesSupportForBytecode(
name, tgfmt(namedProperty->prop.getWriteToMlirBytecodeCall(), &fctx));
continue;
}
- const auto *namedAttr = attrOrProp.dyn_cast<const AttributeMetadata *>();
+ const auto *namedAttr =
+ dyn_cast_if_present<const AttributeMetadata *>(attrOrProp);
StringRef name = namedAttr->attrName;
if (namedAttr->isRequired) {
readPropertiesMethod << formatv(R"(
diff --git a/mlir/unittests/IR/SymbolTableTest.cpp b/mlir/unittests/IR/SymbolTableTest.cpp
index 5dcec749f0f425..192d2e273c2d01 100644
--- a/mlir/unittests/IR/SymbolTableTest.cpp
+++ b/mlir/unittests/IR/SymbolTableTest.cpp
@@ -49,9 +49,9 @@ class ReplaceAllSymbolUsesTest : public ::testing::Test {
// Check that it got renamed.
bool calleeFound = false;
fooOp->walk([&](CallOpInterface callOp) {
- StringAttr callee = callOp.getCallableForCallee()
- .dyn_cast<SymbolRefAttr>()
- .getLeafReference();
+ StringAttr callee =
+ dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee())
+ .getLeafReference();
EXPECT_EQ(callee, "baz");
calleeFound = true;
});
|
hanhanW
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// FIXME: Replace the uses of is(), get() and dyn_cast() with
// isa, cast and the llvm::dyn_cast
In this case, why do we replace it with dyn_cast_if_present, but not dyn_cast?
That's because llvm-project/llvm/include/llvm/ADT/PointerUnion.h Lines 166 to 170 in 97d691b
|
hanhanW
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// FIXME: Replace the uses of is(), get() and dyn_cast() with
// isa, cast and the llvm::dyn_castIn this case, why do we replace it with
dyn_cast_if_present, but notdyn_cast?That's because
PointerUnion::dyn_castis implemented withdyn_cast_if_presentlike so:llvm-project/llvm/include/llvm/ADT/PointerUnion.h
Lines 166 to 170 in 97d691b
/// Returns the current pointer if it is of the specified pointer type, /// otherwise returns null. template <typename T> inline T dyn_cast() const { return llvm::dyn_cast_if_present<T>(*this); }
I see, thanks! I tracked back to the previous commit and found that it was replaced with dyn_cast_if_present because it seems like (based on the call sites) the semantics of the member dyn_cast are actually dyn_cast_if_present.
In this case, it looks good to me.
This is a very helpful context, thank you! @kazutakahirata , please add some justification for using |
nikic
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Our general porting approach for this should be to always use dyn_cast (or cast) rather than dyn_cast_if_present if it's possible. Existing uses of PointerUnion::dyn_cast use it as the only available option, not as an explicit statement that they want dyn_cast_if_present behavior. Porting this to dyn_cast_if_present where null is not actually possible makes for confusing code, because now you explicitly make it look like null is a possible value that needs to be handled. For most (if not all) of the uses in this PR, null is not possible because there is a cast<> on the else path.
If we just blindly migrate dyn_cast -> dyn_cast_if_present, we're making code quality worse, not better.
Note that PointerUnion::dyn_cast has been soft deprecated in
PointerUnion.h:
// FIXME: Replace the uses of is(), get() and dyn_cast() with
// isa, cast and the llvm::dyn_cast