Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions mlir/include/mlir/IR/OpDefinition.h
Original file line number Diff line number Diff line change
Expand Up @@ -694,11 +694,16 @@ class OneTypedResult {
class Impl
: public TraitBase<ConcreteType, OneTypedResult<ResultType>::Impl> {
public:
mlir::TypedValue<ResultType> getResult() {
return cast<mlir::TypedValue<ResultType>>(
template <typename ValTy>
mlir::TypedValue<ValTy> getResultOfType() {
return mlir::cast<mlir::TypedValue<ValTy>>(
this->getOperation()->getResult(0));
}

mlir::TypedValue<ResultType> getResult() {
return getResultOfType<ResultType>();
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this work by templating getResult?

    template<typename ValTy = ResultType>
    mlir::TypedValue<ValTy> getResult() {
       return mlir::cast<mlir::TypedValue<ValTy>>(
          this->getOperation()->getResult(0));
    }

If that does not work, I would name it getResultAs instead, because other methods like getParentOfType are not behaving the same: they are not just a cast but a "find" instead.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this work by templating getResult?

This is a nice suggestion. I have tried it, but simply templating the getResult would cause a compilation error, as shown below.

llvm-project/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp:94:32: error: expected primary-expression before '>' token
   94 |           .getResult<ShapedType>();
      |                                ^
llvm-project/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp:94:34: error: expected primary-expression before ')' token
   94 |           .getResult<ShapedType>();
      |                                  ^

I would name it getResultAs instead, because other methods like getParentOfType are not behaving the same

Thanks for this note. I have updated it to getResultAs.

/// If the operation returns a single value, then the Op can be implicitly
/// converted to a Value. This yields the value of the only result.
operator mlir::TypedValue<ResultType>() { return getResult(); }
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,7 @@ struct ReinterpretCastOpInterface
Location loc) const {
auto reinterpretCast = cast<ReinterpretCastOp>(op);
auto baseMemref = reinterpretCast.getSource();
auto resultMemref =
cast<TypedValue<BaseMemRefType>>(reinterpretCast.getResult());
auto resultMemref = reinterpretCast.getResultOfType<BaseMemRefType>();

builder.setInsertionPointAfter(op);

Expand Down
18 changes: 10 additions & 8 deletions mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,13 @@ handlePartialAxesDuringResharding(OpBuilder &builder,
}

builder.setInsertionPointAfterValue(sourceShard);
TypedValue<ShapedType> resultValue = cast<TypedValue<ShapedType>>(
TypedValue<ShapedType> resultValue =
builder
.create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
sourceSharding.getMeshAttr().getLeafReference(),
allReduceMeshAxes, sourceShard,
sourceSharding.getPartialType())
.getResult());
.getResultOfType<ShapedType>();

llvm::SmallVector<MeshAxis> remainingPartialAxes;
llvm::copy_if(sourceShardingPartialAxesSet,
Expand Down Expand Up @@ -133,12 +133,12 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
MeshSharding sourceSharding,
TypedValue<ShapedType> sourceShard, MeshOp mesh,
int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
TypedValue<ShapedType> targetShard =
builder
.create<AllSliceOp>(sourceShard, mesh,
ArrayRef<MeshAxis>(splitMeshAxis),
splitTensorAxis)
.getResult());
.getResultOfType<ShapedType>();
MeshSharding targetSharding = targetShardingInSplitLastAxis(
builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
return {targetShard, targetSharding};
Expand Down Expand Up @@ -274,8 +274,9 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
APInt(64, splitTensorAxis));
ShapedType targetShape =
shardShapedType(sourceUnshardedShape, mesh, targetSharding);
TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
builder.create<tensor::CastOp>(targetShape, allGatherResult).getResult());
TypedValue<ShapedType> targetShard =
builder.create<tensor::CastOp>(targetShape, allGatherResult)
.getResultOfType<ShapedType>();
return {targetShard, targetSharding};
}

Expand Down Expand Up @@ -407,8 +408,9 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
ShapedType targetShape =
shardShapedType(sourceUnshardedShape, mesh, targetSharding);
TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
builder.create<tensor::CastOp>(targetShape, allToAllResult).getResult());
TypedValue<ShapedType> targetShard =
builder.create<tensor::CastOp>(targetShape, allToAllResult)
.getResultOfType<ShapedType>();
return {targetShard, targetSharding};
}

Expand Down
Loading