From 1012bf39696678453d48e48cd26ad424809a7d88 Mon Sep 17 00:00:00 2001 From: Xiaolei Shi Date: Wed, 18 Dec 2024 00:14:48 -0800 Subject: [PATCH 1/3] add getResultOfType --- mlir/include/mlir/IR/OpDefinition.h | 9 +++++++-- .../Transforms/RuntimeOpVerification.cpp | 2 +- .../Dialect/Mesh/Transforms/Spmdization.cpp | 20 +++++++++++-------- 3 files changed, 20 insertions(+), 11 deletions(-) diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index 59f094d669099..ae28e1251bd95 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -694,11 +694,16 @@ class OneTypedResult { class Impl : public TraitBase::Impl> { public: - mlir::TypedValue getResult() { - return cast>( + template + mlir::TypedValue getResultOfType() { + return mlir::cast>( this->getOperation()->getResult(0)); } + mlir::TypedValue getResult() { + return getResultOfType(); + } + /// If the operation returns a single value, then the Op can be implicitly /// converted to a Value. This yields the value of the only result. operator mlir::TypedValue() { return getResult(); } diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp index 450bfa0cec0c7..6d5a68ef4d0ad 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp @@ -213,7 +213,7 @@ struct ReinterpretCastOpInterface auto reinterpretCast = cast(op); auto baseMemref = reinterpretCast.getSource(); auto resultMemref = - cast>(reinterpretCast.getResult()); + reinterpretCast.getResultOfType(); builder.setInsertionPointAfter(op); diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp index 327ea0991e4e1..5c268c06db08e 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp @@ -85,13 +85,13 @@ handlePartialAxesDuringResharding(OpBuilder &builder, } builder.setInsertionPointAfterValue(sourceShard); - TypedValue resultValue = cast>( + TypedValue resultValue = builder .create(sourceShard.getLoc(), sourceShard.getType(), sourceSharding.getMeshAttr().getLeafReference(), allReduceMeshAxes, sourceShard, sourceSharding.getPartialType()) - .getResult()); + .getResultOfType(); llvm::SmallVector remainingPartialAxes; llvm::copy_if(sourceShardingPartialAxesSet, @@ -133,12 +133,12 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshSharding sourceSharding, TypedValue sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis) { - TypedValue targetShard = cast>( + TypedValue targetShard = builder .create(sourceShard, mesh, ArrayRef(splitMeshAxis), splitTensorAxis) - .getResult()); + .getResultOfType(); MeshSharding targetSharding = targetShardingInSplitLastAxis( builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis); return {targetShard, targetSharding}; @@ -274,8 +274,10 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, APInt(64, splitTensorAxis)); ShapedType targetShape = shardShapedType(sourceUnshardedShape, mesh, targetSharding); - TypedValue targetShard = cast>( - builder.create(targetShape, allGatherResult).getResult()); + TypedValue targetShard = + builder + .create(targetShape, allGatherResult) + .getResultOfType(); return {targetShard, targetSharding}; } @@ -407,8 +409,10 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis)); ShapedType targetShape = shardShapedType(sourceUnshardedShape, mesh, targetSharding); - TypedValue targetShard = cast>( - builder.create(targetShape, allToAllResult).getResult()); + TypedValue targetShard = + builder + .create(targetShape, allToAllResult) + .getResultOfType(); return {targetShard, targetSharding}; } From c556755da0e41cb611b5d349ff83d121337d980e Mon Sep 17 00:00:00 2001 From: Xiaolei Shi Date: Wed, 18 Dec 2024 00:39:39 -0800 Subject: [PATCH 2/3] fix format issue --- .../lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp | 3 +-- mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp | 6 ++---- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp index 6d5a68ef4d0ad..5ca7108f79a92 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp @@ -212,8 +212,7 @@ struct ReinterpretCastOpInterface Location loc) const { auto reinterpretCast = cast(op); auto baseMemref = reinterpretCast.getSource(); - auto resultMemref = - reinterpretCast.getResultOfType(); + auto resultMemref = reinterpretCast.getResultOfType(); builder.setInsertionPointAfter(op); diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp index 5c268c06db08e..6c41ca8edc093 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp @@ -275,8 +275,7 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, ShapedType targetShape = shardShapedType(sourceUnshardedShape, mesh, targetSharding); TypedValue targetShard = - builder - .create(targetShape, allGatherResult) + builder.create(targetShape, allGatherResult) .getResultOfType(); return {targetShard, targetSharding}; } @@ -410,8 +409,7 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, ShapedType targetShape = shardShapedType(sourceUnshardedShape, mesh, targetSharding); TypedValue targetShard = - builder - .create(targetShape, allToAllResult) + builder.create(targetShape, allToAllResult) .getResultOfType(); return {targetShard, targetSharding}; } From 55bc56b8cf537c9c549ee5e03c6c2bc429545d49 Mon Sep 17 00:00:00 2001 From: Xiaolei Shi Date: Wed, 18 Dec 2024 18:17:21 -0800 Subject: [PATCH 3/3] rename getResultOfType to getResultAs --- mlir/include/mlir/IR/OpDefinition.h | 4 ++-- .../Dialect/MemRef/Transforms/RuntimeOpVerification.cpp | 2 +- mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp | 8 ++++---- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index ae28e1251bd95..827274f09b4b1 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -695,13 +695,13 @@ class OneTypedResult { : public TraitBase::Impl> { public: template - mlir::TypedValue getResultOfType() { + mlir::TypedValue getResultAs() { return mlir::cast>( this->getOperation()->getResult(0)); } mlir::TypedValue getResult() { - return getResultOfType(); + return getResultAs(); } /// If the operation returns a single value, then the Op can be implicitly diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp index 5ca7108f79a92..1a852ed05096a 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp @@ -212,7 +212,7 @@ struct ReinterpretCastOpInterface Location loc) const { auto reinterpretCast = cast(op); auto baseMemref = reinterpretCast.getSource(); - auto resultMemref = reinterpretCast.getResultOfType(); + auto resultMemref = reinterpretCast.getResultAs(); builder.setInsertionPointAfter(op); diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp index 6c41ca8edc093..2f1003766dabd 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp @@ -91,7 +91,7 @@ handlePartialAxesDuringResharding(OpBuilder &builder, sourceSharding.getMeshAttr().getLeafReference(), allReduceMeshAxes, sourceShard, sourceSharding.getPartialType()) - .getResultOfType(); + .getResultAs(); llvm::SmallVector remainingPartialAxes; llvm::copy_if(sourceShardingPartialAxesSet, @@ -138,7 +138,7 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder, .create(sourceShard, mesh, ArrayRef(splitMeshAxis), splitTensorAxis) - .getResultOfType(); + .getResultAs(); MeshSharding targetSharding = targetShardingInSplitLastAxis( builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis); return {targetShard, targetSharding}; @@ -276,7 +276,7 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, shardShapedType(sourceUnshardedShape, mesh, targetSharding); TypedValue targetShard = builder.create(targetShape, allGatherResult) - .getResultOfType(); + .getResultAs(); return {targetShard, targetSharding}; } @@ -410,7 +410,7 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, shardShapedType(sourceUnshardedShape, mesh, targetSharding); TypedValue targetShard = builder.create(targetShape, allToAllResult) - .getResultOfType(); + .getResultAs(); return {targetShard, targetSharding}; }