1111
1212include "mlir/Dialect/Mesh/IR/MeshBase.td"
1313include "mlir/Dialect/Shape/IR/ShapeBase.td"
14+ include "mlir/Interfaces/DestinationStyleOpInterface.td"
1415include "mlir/Interfaces/InferTypeOpInterface.td"
1516include "mlir/Interfaces/SideEffectInterfaces.td"
1617include "mlir/IR/BuiltinTypes.td"
@@ -189,23 +190,27 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
189190 `generic`: is not an allowed value inside a shard attribute.
190191
191192 5. [Optional] Sizes of halos to be added for each sharded tensor dimension.
192- `halo_sizes`is provided as a flattened 1d array of i64s, 2 values for each sharded dimension.
193- `halo_sizes` = [1, 2] means that the first sharded dimension gets an additional
194- halo of size 1 at the start of the first dimension and a halo size is 2 at its end.
195- `halo_sizes` = [1, 2, 2, 3] defines halos for the first 2 sharded dimensions
196- e.g. the first sharded dimension gets [1,2] halos and the seconds gets [2,3] halos.
197- `?` indicates dynamic halo sizes.
193+ `halo_sizes` is provided as a flattened 1d array of i64s, 2 values for each
194+ sharded dimension. `halo_sizes = [1, 2]` means that the first sharded dimension
195+ gets an additional halo of size 1 at the start of the first dimension and a halo
196+ size is 2 at its end. `halo_sizes = [1, 2, 2, 3]` defines halos for the first 2
197+ sharded dimensions e.g. the first sharded dimension gets `[1,2]` halos and the
198+ seconds gets `[2,3]` halos. `?` indicates dynamic halo sizes.
199+
200+ 6. [Optional] Offsets for each shard and sharded tensor dimension.
201+ `sharded_dims_offsets` is provided as a flattened 1d array of i64s. For each
202+ sharded tensor dimension the offsets (starting index) of all shards in that
203+ dimension and an additional value for the end of the last shard are provided.
204+ For a 1d sharding this means that position `i` has the exclusive prefix sum for
205+ shard `i`, and since only contiguous sharding is supported, its inclusive prefix
206+ sum is at position 'i+1'.
198207
199- 6. [Optional] Sizes of sharded dimensions of each shard.
200- `sharded_dims_sizes`is provided as a flattened 1d array of i64s: for each device of the
201- device-mesh one value for each sharded tensor dimension.
202208 Assuming a 3d-tensor of shape 32x32x32 with the first 2 dimensions being sharded,
203- `sharded_dims_sizes` = [16, 8, 16, 24] means that the first device of
204- the device-mesh will get a shard of shape 16x8x32 and the second device will get a
205- shard of shape 16x24x32.
206- `?` indicates dynamic shard dimensions.
209+ `sharded_dims_offsets` = [0, 24, 32, 0, 20, 32] means that the first device of
210+ the device-mesh will get a shard of shape 24x20x32 and the second device will get
211+ a shard of shape 8x12x32. `?` indicates dynamic shard dimensions.
207212
208- `halo_sizes` and `sharded_dims_sizes ` are mutually exclusive.
213+ `halo_sizes` and `sharded_dims_offsets ` are mutually exclusive.
209214
210215 Examples:
211216
@@ -240,7 +245,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
240245 // The tensor is sharded on its second dimension along axis 0 of @mesh1d_4
241246 // and it has pre-defined shard sizes. The shards of the devices will have
242247 // the following shapes: [4x2, 4x3, 4x4, 4x5]
243- %sharding4 = mesh.sharding @mesh1d_4 split_axes = [[] split_axes = [0]] sharded_dims_sizes = [2, 3, 4, 5 ]
248+ %sharding4 = mesh.sharding @mesh1d_4 split_axes = [[], [0]] sharded_dims_offsets = [0, 2, 5, 9, 14 ]
244249 %sharded2 = mesh.shard %arg0 to %sharding4 : tensor<4x14xf32>
245250 ```
246251 }];
@@ -250,8 +255,8 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
250255 Mesh_MeshAxesArrayAttr:$split_axes,
251256 OptionalAttr<Mesh_MeshAxesAttr>:$partial_axes,
252257 OptionalAttr<Mesh_ReductionKindAttr>:$partial_type,
253- DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_sharded_dims_sizes ,
254- Variadic<I64>:$dynamic_sharded_dims_sizes ,
258+ DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_sharded_dims_offsets ,
259+ Variadic<I64>:$dynamic_sharded_dims_offsets ,
255260 DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes,
256261 Variadic<I64>:$dynamic_halo_sizes
257262 );
@@ -263,7 +268,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
263268 `split_axes` `=` $split_axes
264269 (`partial` `=` $partial_type $partial_axes^)?
265270 (`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
266- (`sharded_dims_sizes ` `=` custom<DynamicIndexList>($dynamic_sharded_dims_sizes , $static_sharded_dims_sizes )^)?
271+ (`sharded_dims_offsets ` `=` custom<DynamicIndexList>($dynamic_sharded_dims_offsets , $static_sharded_dims_offsets )^)?
267272 attr-dict `:` type($result)
268273 }];
269274 let builders = [
@@ -272,16 +277,17 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
272277 "ArrayRef<MeshAxis>":$partial_axes,
273278 "mesh::ReductionKind":$partial_type,
274279 CArg<"ArrayRef<int64_t>", "{}">:$static_halo_sizes,
275- CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_sizes )>,
280+ CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_offsets )>,
276281 OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
277282 "ArrayRef<MeshAxesAttr>":$split_axes)>,
278283 OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
279284 "ArrayRef<MeshAxesAttr>":$split_axes,
280285 "::mlir::ArrayRef<::mlir::OpFoldResult>":$halo_sizes,
281- "::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_sizes )>,
286+ "::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_offsets )>,
282287 OpBuilder<(ins "mlir::mesh::MeshSharding":$from)>
283288 ];
284289 let hasVerifier = 1;
290+ let hasCanonicalizer = 1;
285291}
286292
287293def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> {
@@ -1052,37 +1058,54 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
10521058}
10531059
10541060def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
1055- DeclareOpInterfaceMethods<SymbolUserOpInterface>
1061+ DestinationStyleOpInterface,
1062+ TypesMatchWith<
1063+ "result has same type as destination",
1064+ "result", "destination", "$_self">,
1065+ DeclareOpInterfaceMethods<SymbolUserOpInterface>,
1066+ AttrSizedOperandSegments
10561067]> {
10571068 let summary = "Update halo data.";
10581069 let description = [{
10591070 This operation updates halo regions of shards, e.g. if their sharding
1060- specified halos and the actual tensor data might have changed
1071+ specified halos and the actual tensor/memref data might have changed
10611072 on the remote devices. Changes might be caused by mutating operations
10621073 and/or if the new halo regions are larger than the existing ones.
10631074
1075+ Source and destination might have different halo sizes.
1076+
10641077 Assumes all devices hold tensors with same-sized halo data as specified
1065- by `dynamic/static_halo_sizes`.
1078+ by `source_halo_sizes/static_source_halo_sizes` and
1079+ `destination_halo_sizes/static_destination_halo_sizes` in source shard
1080+ and destination/result shard.
10661081
10671082 `split_axes` specifies for each tensor axis along which mesh axes its halo
10681083 data is updated.
10691084
1070- Optionally resizes to new halo sizes `target_halo_sizes`.
10711085 }];
10721086 let arguments = (ins
1073- AnyNon0RankedMemRef:$input,
1087+ AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$source,
1088+ AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$destination,
10741089 FlatSymbolRefAttr:$mesh,
10751090 Mesh_MeshAxesArrayAttr:$split_axes,
1076- Variadic<I64>:$dynamic_halo_sizes,
1077- DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes,
1078- DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$target_halo_sizes
1091+ Variadic<I64>:$source_halo_sizes,
1092+ DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_source_halo_sizes,
1093+ Variadic<I64>:$destination_halo_sizes,
1094+ DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_destination_halo_sizes
1095+ );
1096+ let results = (outs
1097+ AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$result
10791098 );
10801099 let assemblyFormat = [{
1081- $input `on` $mesh
1100+ $source `into` $destination
1101+ `on` $mesh
10821102 `split_axes` `=` $split_axes
1083- (`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
1084- (`target_halo_sizes` `=` $target_halo_sizes^)?
1085- attr-dict `:` type($input)
1103+ (`source_halo_sizes` `=` custom<DynamicIndexList>($source_halo_sizes, $static_source_halo_sizes)^)?
1104+ (`destination_halo_sizes` `=` custom<DynamicIndexList>($destination_halo_sizes, $static_destination_halo_sizes)^)?
1105+ attr-dict `:` type($source) `->` type($result)
1106+ }];
1107+ let extraClassDeclaration = [{
1108+ MutableOperandRange getDpsInitsMutable() { return getDestinationMutable(); }
10861109 }];
10871110}
10881111#endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
0 commit comments