Skip to content

Commit 70fb9a5

Browse files
committed
Adding sharding extraction operation and op tests
and handling GetShardingOp in ShardingPropagation
1 parent 1d86186 commit 70fb9a5

File tree

6 files changed

+71
-7
lines changed

6 files changed

+71
-7
lines changed

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,12 +318,33 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
318318
"ArrayRef<MeshAxesAttr>":$split_axes,
319319
"::mlir::ArrayRef<::mlir::OpFoldResult>":$halo_sizes,
320320
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_offsets)>,
321+
OpBuilder<(ins "llvm::StringRef":$mesh,
322+
"ArrayRef<MeshAxesAttr>":$split_axes,
323+
CArg<"ArrayRef<int64_t>", "{}">:$static_halo_sizes,
324+
CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_offsets
325+
)>,
321326
OpBuilder<(ins "mlir::mesh::MeshSharding":$from)>
322327
];
323328
let hasVerifier = 1;
324329
let hasCanonicalizer = 1;
325330
}
326331

332+
def Mesh_GetShardingOp : Mesh_Op<"get_sharding", [Pure]> {
333+
let summary = "Get the sharding of the given tensor.";
334+
let description = [{
335+
This operation returns the sharding of the given tensor as a MeshSharding.
336+
}];
337+
let arguments = (ins
338+
AnyRankedTensor:$source
339+
);
340+
let results = (outs
341+
Mesh_Sharding:$result
342+
);
343+
let assemblyFormat = [{
344+
$source attr-dict `:` type($source) `->` type($result)
345+
}];
346+
}
347+
327348
def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> {
328349
let summary = "Get the shard shape of a given process/device.";
329350
let description = [{

mlir/lib/Dialect/Mesh/IR/MeshOps.cpp

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -454,16 +454,14 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
454454
ArrayRef<MeshAxesAttr> split_axes,
455455
ArrayRef<MeshAxis> partial_axes,
456456
mesh::ReductionKind partial_type,
457-
ArrayRef<int64_t> static_halo_sizes,
458-
ArrayRef<int64_t> static_sharded_dims_offsets) {
457+
ArrayRef<int64_t> static_halos,
458+
ArrayRef<int64_t> static_offsets) {
459459
return build(
460460
b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
461461
::mlir::DenseI16ArrayAttr::get(b.getContext(), partial_axes),
462462
::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type),
463-
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halo_sizes), {},
464-
::mlir::DenseI64ArrayAttr::get(b.getContext(),
465-
static_sharded_dims_offsets),
466-
{});
463+
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
464+
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
467465
}
468466

469467
void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
@@ -475,6 +473,18 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
475473
{}, {}, {}, {});
476474
}
477475

476+
void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
477+
llvm::StringRef mesh, ArrayRef<MeshAxesAttr> split_axes,
478+
ArrayRef<int64_t> static_halos,
479+
ArrayRef<int64_t> static_offsets) {
480+
return build(
481+
b, odsState, FlatSymbolRefAttr::get(b.getContext(), mesh),
482+
MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
483+
::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
484+
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
485+
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
486+
}
487+
478488
void ShardingOp::build(
479489
::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
480490
FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes,

mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
285285
ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op);
286286
if (op->hasTrait<OpTrait::IsTerminator>() ||
287287
(op->hasTrait<OpTrait::ConstantLike>() && !shardingOp) ||
288-
llvm::isa<mesh::ShardOp, mesh::ShardingOp>(op))
288+
llvm::isa<mesh::ShardOp, mesh::ShardingOp, mesh::GetShardingOp>(op))
289289
return success();
290290

291291
if (!shardingOp) {

mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -738,6 +738,15 @@ spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
738738
if (isa<ShardingOp>(op)) {
739739
return success();
740740
}
741+
if (auto getShardingOp = dyn_cast<GetShardingOp>(op)) {
742+
auto shardOp = getShardingOp.getSource().getDefiningOp<ShardOp>();
743+
if (!shardOp) {
744+
return op.emitError("expected a shard op as source of get_sharding");
745+
}
746+
auto newSharding = builder.clone(*shardOp.getSharding().getDefiningOp());
747+
spmdizationMap.map(op.getResult(0), newSharding->getResult(0));
748+
return success();
749+
}
741750

742751
ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
743752
if (shardOp) {

mlir/test/Dialect/Mesh/ops.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,16 @@ func.func @mesh_shard_shape() {
164164
return
165165
}
166166

167+
// CHECK-LABEL: func @mesh_get_sharding
168+
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
169+
func.func @mesh_get_sharding(%arg0 : tensor<4x8xf32>) -> !mesh.sharding {
170+
// CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh1 split_axes = {{\[\[}}], [0]] : !mesh.sharding
171+
%s = mesh.sharding @mesh1 split_axes = [[], [0]] : !mesh.sharding
172+
// CHECK-NEXT: mesh.get_sharding %[[ARG]] : tensor<4x8xf32> -> !mesh.sharding
173+
%0 = mesh.get_sharding %arg0 : tensor<4x8xf32> -> !mesh.sharding
174+
return %0 : !mesh.sharding
175+
}
176+
167177
// CHECK-LABEL: func @mesh_shape
168178
func.func @mesh_shape() -> (index, index) {
169179
// CHECK: %[[RES:.*]]:2 = mesh.mesh_shape @mesh0 axes = [0, 1] : index, index

mlir/test/Dialect/Mesh/spmdization.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,20 @@
44

55
mesh.mesh @mesh_1d(shape = 2)
66

7+
// CHECK-LABEL: func @return_sharding
8+
func.func @return_sharding(
9+
// CHECK-SAME: [[ARG:%.*]]: tensor<1xf32>
10+
%arg0: tensor<2xf32>
11+
// CHECK-SAME: ) -> (tensor<1xf32>, !mesh.sharding) {
12+
) -> (tensor<2xf32>, !mesh.sharding) {
13+
%ssharding_annotated = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
14+
%sharding_annotated = mesh.shard %arg0 to %ssharding_annotated : tensor<2xf32>
15+
// CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}0]] : !mesh.sharding
16+
%r = mesh.get_sharding %sharding_annotated : tensor<2xf32> -> !mesh.sharding
17+
// CHECK-NEXT: return [[ARG]], [[vsharding]] : tensor<1xf32>, !mesh.sharding
18+
return %sharding_annotated, %r : tensor<2xf32>, !mesh.sharding
19+
}
20+
721
// CHECK-LABEL: func @full_replication
822
func.func @full_replication(
923
// CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>

0 commit comments

Comments
 (0)