-
Notifications
You must be signed in to change notification settings - Fork 15.3k
Add OneTypedResult::getResultAs to simplify the result type casting logic #120381
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-memref Author: None (xiaoleis-nv) ChangesDescriptionThis PR adds a ExamplesBefore this PR: auto targetShard = cast<TypedValue<ShapedType>>(
builder.create<AllSliceOp>(sourceShard, mesh,
ArrayRef<MeshAxis>(splitMeshAxis),
splitTensorAxis)
.getResult());With this PR: auto targetShard = builder.create<AllSliceOp>(sourceShard, mesh,
ArrayRef<MeshAxis>(splitMeshAxis),
splitTensorAxis)
.getResultOfType<ShapedType>();Full diff: https://github.com/llvm/llvm-project/pull/120381.diff 3 Files Affected:
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 59f094d6690991..ae28e1251bd954 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -694,11 +694,16 @@ class OneTypedResult {
class Impl
: public TraitBase<ConcreteType, OneTypedResult<ResultType>::Impl> {
public:
- mlir::TypedValue<ResultType> getResult() {
- return cast<mlir::TypedValue<ResultType>>(
+ template <typename ValTy>
+ mlir::TypedValue<ValTy> getResultOfType() {
+ return mlir::cast<mlir::TypedValue<ValTy>>(
this->getOperation()->getResult(0));
}
+ mlir::TypedValue<ResultType> getResult() {
+ return getResultOfType<ResultType>();
+ }
+
/// If the operation returns a single value, then the Op can be implicitly
/// converted to a Value. This yields the value of the only result.
operator mlir::TypedValue<ResultType>() { return getResult(); }
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 450bfa0cec0c7f..6d5a68ef4d0add 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -213,7 +213,7 @@ struct ReinterpretCastOpInterface
auto reinterpretCast = cast<ReinterpretCastOp>(op);
auto baseMemref = reinterpretCast.getSource();
auto resultMemref =
- cast<TypedValue<BaseMemRefType>>(reinterpretCast.getResult());
+ reinterpretCast.getResultOfType<BaseMemRefType>();
builder.setInsertionPointAfter(op);
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 327ea0991e4e1e..5c268c06db08e6 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -85,13 +85,13 @@ handlePartialAxesDuringResharding(OpBuilder &builder,
}
builder.setInsertionPointAfterValue(sourceShard);
- TypedValue<ShapedType> resultValue = cast<TypedValue<ShapedType>>(
+ TypedValue<ShapedType> resultValue =
builder
.create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
sourceSharding.getMeshAttr().getLeafReference(),
allReduceMeshAxes, sourceShard,
sourceSharding.getPartialType())
- .getResult());
+ .getResultOfType<ShapedType>();
llvm::SmallVector<MeshAxis> remainingPartialAxes;
llvm::copy_if(sourceShardingPartialAxesSet,
@@ -133,12 +133,12 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
MeshSharding sourceSharding,
TypedValue<ShapedType> sourceShard, MeshOp mesh,
int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
- TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
+ TypedValue<ShapedType> targetShard =
builder
.create<AllSliceOp>(sourceShard, mesh,
ArrayRef<MeshAxis>(splitMeshAxis),
splitTensorAxis)
- .getResult());
+ .getResultOfType<ShapedType>();
MeshSharding targetSharding = targetShardingInSplitLastAxis(
builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
return {targetShard, targetSharding};
@@ -274,8 +274,10 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
APInt(64, splitTensorAxis));
ShapedType targetShape =
shardShapedType(sourceUnshardedShape, mesh, targetSharding);
- TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
- builder.create<tensor::CastOp>(targetShape, allGatherResult).getResult());
+ TypedValue<ShapedType> targetShard =
+ builder
+ .create<tensor::CastOp>(targetShape, allGatherResult)
+ .getResultOfType<ShapedType>();
return {targetShard, targetSharding};
}
@@ -407,8 +409,10 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
ShapedType targetShape =
shardShapedType(sourceUnshardedShape, mesh, targetSharding);
- TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
- builder.create<tensor::CastOp>(targetShape, allToAllResult).getResult());
+ TypedValue<ShapedType> targetShard =
+ builder
+ .create<tensor::CastOp>(targetShape, allToAllResult)
+ .getResultOfType<ShapedType>();
return {targetShard, targetSharding};
}
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
| mlir::TypedValue<ResultType> getResult() { | ||
| return getResultOfType<ResultType>(); | ||
| } | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would this work by templating getResult?
template<typename ValTy = ResultType>
mlir::TypedValue<ValTy> getResult() {
return mlir::cast<mlir::TypedValue<ValTy>>(
this->getOperation()->getResult(0));
}
If that does not work, I would name it getResultAs instead, because other methods like getParentOfType are not behaving the same: they are not just a cast but a "find" instead.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would this work by templating getResult?
This is a nice suggestion. I have tried it, but simply templating the getResult would cause a compilation error, as shown below.
llvm-project/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp:94:32: error: expected primary-expression before '>' token
94 | .getResult<ShapedType>();
| ^
llvm-project/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp:94:34: error: expected primary-expression before ')' token
94 | .getResult<ShapedType>();
| ^
I would name it getResultAs instead, because other methods like getParentOfType are not behaving the same
Thanks for this note. I have updated it to getResultAs.
Description
This PR adds a
getResultAsmember to theOneTypedResultclass to simplify the result type casting logic.Casting the result type is necessary when converting between its concrete type and interface type.
Without this member, one typically needs to call the
getResultmethod followed by an explicit cast, which makes the code tedious. Introducing thegetResultAsmember simplifies this process.Examples
Before this PR:
auto targetShard = cast<TypedValue<ShapedType>>( builder.create<AllSliceOp>(sourceShard, mesh, ArrayRef<MeshAxis>(splitMeshAxis), splitTensorAxis) .getResult());With this PR:
auto targetShard = builder.create<AllSliceOp>(sourceShard, mesh, ArrayRef<MeshAxis>(splitMeshAxis), splitTensorAxis) .getResultAs<ShapedType>();