Skip to content

Conversation

@xiaoleis-nv
Copy link
Contributor

@xiaoleis-nv xiaoleis-nv commented Dec 18, 2024

Description

This PR adds a getResultAs member to the OneTypedResult class to simplify the result type casting logic.
Casting the result type is necessary when converting between its concrete type and interface type.
Without this member, one typically needs to call the getResult method followed by an explicit cast, which makes the code tedious. Introducing the getResultAs member simplifies this process.

Examples

Before this PR:

auto targetShard = cast<TypedValue<ShapedType>>(
    builder.create<AllSliceOp>(sourceShard, mesh,
                               ArrayRef<MeshAxis>(splitMeshAxis),
                               splitTensorAxis)
        .getResult());

With this PR:

auto targetShard = builder.create<AllSliceOp>(sourceShard, mesh,
                                             ArrayRef<MeshAxis>(splitMeshAxis),
                                             splitTensorAxis)
                       .getResultAs<ShapedType>();

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir mlir:memref labels Dec 18, 2024
@llvmbot
Copy link
Member

llvmbot commented Dec 18, 2024

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir-memref

Author: None (xiaoleis-nv)

Changes

Description

This PR adds a getResultOfType member to the OneTypedResult class to simplify the result type casting logic.
Casting the result type is necessary when converting between its concrete type and interface type.
Without this member, one typically needs to call the getResult method followed by an explicit cast, which makes the code tedious. Introducing the getResultOfType member simplifies this process.

Examples

Before this PR:

auto targetShard = cast&lt;TypedValue&lt;ShapedType&gt;&gt;(
    builder.create&lt;AllSliceOp&gt;(sourceShard, mesh,
                               ArrayRef&lt;MeshAxis&gt;(splitMeshAxis),
                               splitTensorAxis)
        .getResult());

With this PR:

auto targetShard = builder.create&lt;AllSliceOp&gt;(sourceShard, mesh,
                                             ArrayRef&lt;MeshAxis&gt;(splitMeshAxis),
                                             splitTensorAxis)
                       .getResultOfType&lt;ShapedType&gt;();

Full diff: https://github.com/llvm/llvm-project/pull/120381.diff

3 Files Affected:

  • (modified) mlir/include/mlir/IR/OpDefinition.h (+7-2)
  • (modified) mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (+12-8)
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 59f094d6690991..ae28e1251bd954 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -694,11 +694,16 @@ class OneTypedResult {
   class Impl
       : public TraitBase<ConcreteType, OneTypedResult<ResultType>::Impl> {
   public:
-    mlir::TypedValue<ResultType> getResult() {
-      return cast<mlir::TypedValue<ResultType>>(
+    template <typename ValTy>
+    mlir::TypedValue<ValTy> getResultOfType() {
+      return mlir::cast<mlir::TypedValue<ValTy>>(
           this->getOperation()->getResult(0));
     }
 
+    mlir::TypedValue<ResultType> getResult() {
+      return getResultOfType<ResultType>();
+    }
+
     /// If the operation returns a single value, then the Op can be implicitly
     /// converted to a Value. This yields the value of the only result.
     operator mlir::TypedValue<ResultType>() { return getResult(); }
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 450bfa0cec0c7f..6d5a68ef4d0add 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -213,7 +213,7 @@ struct ReinterpretCastOpInterface
     auto reinterpretCast = cast<ReinterpretCastOp>(op);
     auto baseMemref = reinterpretCast.getSource();
     auto resultMemref =
-        cast<TypedValue<BaseMemRefType>>(reinterpretCast.getResult());
+        reinterpretCast.getResultOfType<BaseMemRefType>();
 
     builder.setInsertionPointAfter(op);
 
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 327ea0991e4e1e..5c268c06db08e6 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -85,13 +85,13 @@ handlePartialAxesDuringResharding(OpBuilder &builder,
   }
 
   builder.setInsertionPointAfterValue(sourceShard);
-  TypedValue<ShapedType> resultValue = cast<TypedValue<ShapedType>>(
+  TypedValue<ShapedType> resultValue =
       builder
           .create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
                                sourceSharding.getMeshAttr().getLeafReference(),
                                allReduceMeshAxes, sourceShard,
                                sourceSharding.getPartialType())
-          .getResult());
+          .getResultOfType<ShapedType>();
 
   llvm::SmallVector<MeshAxis> remainingPartialAxes;
   llvm::copy_if(sourceShardingPartialAxesSet,
@@ -133,12 +133,12 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
                           MeshSharding sourceSharding,
                           TypedValue<ShapedType> sourceShard, MeshOp mesh,
                           int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
-  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
+  TypedValue<ShapedType> targetShard =
       builder
           .create<AllSliceOp>(sourceShard, mesh,
                               ArrayRef<MeshAxis>(splitMeshAxis),
                               splitTensorAxis)
-          .getResult());
+          .getResultOfType<ShapedType>();
   MeshSharding targetSharding = targetShardingInSplitLastAxis(
       builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
   return {targetShard, targetSharding};
@@ -274,8 +274,10 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
       APInt(64, splitTensorAxis));
   ShapedType targetShape =
       shardShapedType(sourceUnshardedShape, mesh, targetSharding);
-  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
-      builder.create<tensor::CastOp>(targetShape, allGatherResult).getResult());
+  TypedValue<ShapedType> targetShard =
+      builder
+          .create<tensor::CastOp>(targetShape, allGatherResult)
+          .getResultOfType<ShapedType>();
   return {targetShard, targetSharding};
 }
 
@@ -407,8 +409,10 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
       APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
   ShapedType targetShape =
       shardShapedType(sourceUnshardedShape, mesh, targetSharding);
-  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
-      builder.create<tensor::CastOp>(targetShape, allToAllResult).getResult());
+  TypedValue<ShapedType> targetShard =
+      builder
+          .create<tensor::CastOp>(targetShape, allToAllResult)
+          .getResultOfType<ShapedType>();
   return {targetShard, targetSharding};
 }
 

@github-actions
Copy link

github-actions bot commented Dec 18, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

mlir::TypedValue<ResultType> getResult() {
return getResultOfType<ResultType>();
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this work by templating getResult?

    template<typename ValTy = ResultType>
    mlir::TypedValue<ValTy> getResult() {
       return mlir::cast<mlir::TypedValue<ValTy>>(
          this->getOperation()->getResult(0));
    }

If that does not work, I would name it getResultAs instead, because other methods like getParentOfType are not behaving the same: they are not just a cast but a "find" instead.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this work by templating getResult?

This is a nice suggestion. I have tried it, but simply templating the getResult would cause a compilation error, as shown below.

llvm-project/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp:94:32: error: expected primary-expression before '>' token
   94 |           .getResult<ShapedType>();
      |                                ^
llvm-project/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp:94:34: error: expected primary-expression before ')' token
   94 |           .getResult<ShapedType>();
      |                                  ^

I would name it getResultAs instead, because other methods like getParentOfType are not behaving the same

Thanks for this note. I have updated it to getResultAs.

@xiaoleis-nv xiaoleis-nv changed the title Add OneTypedResult::getResultOfType to simplify the result type casting logic Add OneTypedResult::getResultAs to simplify the result type casting logic Dec 19, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:core MLIR Core Infrastructure mlir:memref mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants