@@ -49,7 +49,7 @@ def Shard_GridOp : Shard_Op<"grid", [Symbol, Pure]> {
 Example:
 ```
 // A device grid with 3 axes, the total device number is 4 * 8 * 12
-// The dimension sizes are 4, 8, 12 
+// The dimension sizes are 4, 8, 12
 shard.grid @grid0(shape = 4x8x12)
 
 // A device grid with 2 axes, the total device number is unknown
@@ -173,8 +173,8 @@ def Shard_NeighborsLinearIndicesOp : Shard_Op<"neighbors_linear_indices", [
 %idx = shard.neighbors_linear_indices on @grid[%c1, %c2, %c3] split_axes = [1] : index
 ```
 The above returns two indices, `633` and `693`, which correspond to the
-index of the previous process `(1, 1, 3)`, and the next process
-`(1, 3, 3) along the split axis `1`.
+index of the previous process `(1, 1, 3)`, and the next process
+`(1, 3, 3)` along the split axis `1`.
 
 A negative value is returned if there is no neighbor in the respective
 direction along the given `split_axes`.
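
The two linear indices quoted above follow row-major ordering over the device grid's shape. As a cross-check, here is a minimal Python sketch (illustrative only, not dialect API), assuming a grid shape such as `10x20x30`, which is consistent with the values `633` and `693` but is not shown in this hunk:

```python
# Illustrative only: row-major linearization of a grid multi-index.
# The shape (10, 20, 30) is an assumption consistent with the indices
# 633 and 693 quoted in the description above.
def linear_index(index, shape):
    lin = 0
    for coord, dim in zip(index, shape):
        lin = lin * dim + coord
    return lin

shape = (10, 20, 30)
print(linear_index((1, 1, 3), shape))  # 633: previous neighbor along split axis 1
print(linear_index((1, 3, 3), shape))  # 693: next neighbor along split axis 1
```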
@@ -222,20 +222,20 @@ def Shard_ShardingOp : Shard_Op<"sharding", [
 size is 2 at its end. `halo_sizes = [1, 2, 2, 3]` defines halos for the first 2
 sharded dimensions e.g. the first sharded dimension gets `[1,2]` halos and the
 second gets `[2,3]` halos. `?` indicates dynamic halo sizes.
- 
+
 4. [Optional] Offsets for each shard and sharded tensor dimension.
 `sharded_dims_offsets` is provided as a flattened 1d array of i64s. For each
 sharded tensor dimension the offsets (starting index) of all shards in that
 dimension and an additional value for the end of the last shard are provided.
 For a 1d sharding this means that position `i` has the exclusive prefix sum for
 shard `i`, and since only contiguous sharding is supported, its inclusive prefix
 sum is at position `i+1`.
- 
+
 Assuming a 3d-tensor of shape 32x32x32 with the first 2 dimensions being sharded,
 `sharded_dims_offsets` = [0, 24, 32, 0, 20, 32] means that the first device of
 the device-grid will get a shard of shape 24x20x32 and the second device will get
 a shard of shape 8x12x32. `?` indicates dynamic shard dimensions.
- 
+
 `halo_sizes` and `sharded_dims_offsets` are mutually exclusive.
 
 Examples:
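
The flattened `sharded_dims_offsets` layout described above is a concatenation of per-dimension exclusive prefix sums, so the per-shard sizes fall out as differences of adjacent offsets. A small Python sketch (illustrative only; the helper name `shard_sizes` is made up) cross-checks the 32x32x32 example above and the `@grid1d_4` example that appears further below:

```python
# Illustrative only: recover per-dimension shard sizes from the flattened
# sharded_dims_offsets array. For each sharded dimension with n shards the
# array holds n+1 offsets; adjacent differences are the shard sizes.
def shard_sizes(flat_offsets, shards_per_dim):
    sizes, pos = [], 0
    for n in shards_per_dim:
        offs = flat_offsets[pos:pos + n + 1]
        sizes.append([offs[i + 1] - offs[i] for i in range(n)])
        pos += n + 1
    return sizes

# 32x32x32 tensor, first two dimensions sharded into 2 shards each:
print(shard_sizes([0, 24, 32, 0, 20, 32], [2, 2]))  # [[24, 8], [20, 12]]
# -> the 24x20x32 and 8x12x32 shards described above.

# 4x14 tensor, second dimension sharded into 4 shards (see @grid1d_4 below):
print(shard_sizes([0, 2, 5, 9, 14], [4]))  # [[2, 3, 4, 5]]
```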
@@ -259,15 +259,15 @@ def Shard_ShardingOp : Shard_Op<"sharding", [
 // and it has halo-sizes of 1 and 2 on the sharded dim.
 %halo_sharding = shard.sharding @grid0 split_axes = [[0]] halo_sizes = [1, 2]
 %sharded1 = shard.shard %arg0 to %halo_sharding : tensor<4x8xf32>
- 
+
 // The tensor is sharded on its second dimension along axis 0 of @grid1d_4
 // and it has pre-defined shard sizes. The shards of the devices will have
 // the following shapes: [4x2, 4x3, 4x4, 4x5]
 %sharding4 = shard.sharding @grid1d_4 split_axes = [[], [0]] sharded_dims_offsets = [0, 2, 5, 9, 14]
 %sharded2 = shard.shard %arg0 to %sharding4 : tensor<4x14xf32>
 ```
 }];
- 
+
 let arguments = (ins
   FlatSymbolRefAttr:$grid,
   Shard_GridAxesArrayAttr:$split_axes,
@@ -389,7 +389,7 @@ def Shard_ShardOp : Shard_Op<"shard", [
   %0 = shard.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
   ...
 }
- 
+
 func.func @two_operands_annotated(%arg0 : tensor<4x8xf32>, %arg1 : tensor<16x8xf32>) -> () {
   %sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
   %0 = shard.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
@@ -589,7 +589,7 @@ def Shard_AllSliceOp : Shard_CollectiveCommunicationOpBase<"all_slice", [
 This operation can be thought of as the inverse of all-gather.
 Technically, it is not required that all processes have the same input tensor.
 Each process will slice a piece of its local tensor based on its in-group device index.
-The operation does not communicate data between devices. 
+The operation does not communicate data between devices.
 
 Example:
 ```mlir
@@ -706,7 +706,7 @@ def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [
 The operation broadcasts along grid axes `grid_axes`.
 The `root` device specifies the in-group multi-index that is broadcast to
 all other devices in the group.
- 
+
 Example:
 ```
 shard.grid @grid0(shape = 2x2)
@@ -716,13 +716,13 @@ def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [
   root = [0]
   : (tensor<2xi8>) -> tensor<2xi8>
 ```
- 
+
 Input:
 ```
                  +-------+-------+                  | broadcast
 device (0, 0) -> |  1  2 |  3  4 | <- device (0, 1) | along axis 0
                  +-------+-------+                  ↓
-device (1, 0) -> |       |       | <- device (1, 1) 
+device (1, 0) -> |       |       | <- device (1, 1)
                  +-------+-------+
 ```
 
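
For the `2x2` grid in this example, the devices that take part in one broadcast group are exactly those whose coordinates agree on every axis not listed in `grid_axes`. A hedged Python sketch (illustrative only, not dialect API) of that grouping and of how `root = [0]` selects the source within each group:

```python
# Illustrative only: form broadcast groups on a 2x2 device grid for
# grid_axes = [0]. Devices that agree on all axes *not* in grid_axes are
# in the same group; root = [0] names the in-group source along axis 0.
from itertools import product

shape, grid_axes, root = (2, 2), [0], [0]

groups = {}
for dev in product(*(range(d) for d in shape)):
    key = tuple(c for ax, c in enumerate(dev) if ax not in grid_axes)
    groups.setdefault(key, []).append(dev)

for members in groups.values():
    src = next(d for d in members if [d[ax] for ax in grid_axes] == root)
    print(f"group {members}: broadcast from {src}")
# group [(0, 0), (1, 0)]: broadcast from (0, 0)
# group [(0, 1), (1, 1)]: broadcast from (0, 1)
```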
@@ -978,15 +978,15 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
                           device
                           (1, 1)
 ```
- 
+
 Result:
 ```
                           device
                           (0, 1)
                              ↓
                  +-------+-------+
 device (0, 0) -> |  1  2 |  5  6 |
-                 +-------+-------+ 
+                 +-------+-------+
 device (1, 0) -> |  3  4 |  7  8 |
                  +-------+-------+
                              ↑