diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index 59f094d669099..827274f09b4b1 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -694,11 +694,16 @@ class OneTypedResult { class Impl : public TraitBase::Impl> { public: - mlir::TypedValue getResult() { - return cast>( + template + mlir::TypedValue getResultAs() { + return mlir::cast>( this->getOperation()->getResult(0)); } + mlir::TypedValue getResult() { + return getResultAs(); + } + /// If the operation returns a single value, then the Op can be implicitly /// converted to a Value. This yields the value of the only result. operator mlir::TypedValue() { return getResult(); } diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp index 450bfa0cec0c7..1a852ed05096a 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp @@ -212,8 +212,7 @@ struct ReinterpretCastOpInterface Location loc) const { auto reinterpretCast = cast(op); auto baseMemref = reinterpretCast.getSource(); - auto resultMemref = - cast>(reinterpretCast.getResult()); + auto resultMemref = reinterpretCast.getResultAs(); builder.setInsertionPointAfter(op); diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp index 327ea0991e4e1..2f1003766dabd 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp @@ -85,13 +85,13 @@ handlePartialAxesDuringResharding(OpBuilder &builder, } builder.setInsertionPointAfterValue(sourceShard); - TypedValue resultValue = cast>( + TypedValue resultValue = builder .create(sourceShard.getLoc(), sourceShard.getType(), sourceSharding.getMeshAttr().getLeafReference(), allReduceMeshAxes, sourceShard, sourceSharding.getPartialType()) - .getResult()); + .getResultAs(); llvm::SmallVector remainingPartialAxes; llvm::copy_if(sourceShardingPartialAxesSet, @@ -133,12 +133,12 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshSharding sourceSharding, TypedValue sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis) { - TypedValue targetShard = cast>( + TypedValue targetShard = builder .create(sourceShard, mesh, ArrayRef(splitMeshAxis), splitTensorAxis) - .getResult()); + .getResultAs(); MeshSharding targetSharding = targetShardingInSplitLastAxis( builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis); return {targetShard, targetSharding}; @@ -274,8 +274,9 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, APInt(64, splitTensorAxis)); ShapedType targetShape = shardShapedType(sourceUnshardedShape, mesh, targetSharding); - TypedValue targetShard = cast>( - builder.create(targetShape, allGatherResult).getResult()); + TypedValue targetShard = + builder.create(targetShape, allGatherResult) + .getResultAs(); return {targetShard, targetSharding}; } @@ -407,8 +408,9 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis)); ShapedType targetShape = shardShapedType(sourceUnshardedShape, mesh, targetSharding); - TypedValue targetShard = cast>( - builder.create(targetShape, allToAllResult).getResult()); + TypedValue targetShard = + builder.create(targetShape, allToAllResult) + .getResultAs(); return {targetShard, targetSharding}; }