
Commit d74c976

[NFC][MLIR][shard] improving shard docs (llvm#163782)
This PR seeks to improve the clarity of the Shard dialect documentation, in particular the descriptions of the communication operations.
1 parent 5d0a4a1 commit d74c976

2 files changed (+69 -54 lines)

mlir/docs/Dialects/Shard.md

Lines changed: 3 additions & 3 deletions
@@ -27,9 +27,9 @@ the tensor is sharded - not specified manually.
 
 ### Device Groups
 
-Each collective operation runs within a group of devices. You define groups
-using the `grid` and `grid_axes` attributes, which describe how to slice the
-full device grid into smaller groups.
+Collective operations run within groups of devices, which are defined
+using the `grid` and `grid_axes` attributes. These describe
+how the full device grid is sliced into smaller groups.
 
 Devices that have the same coordinates *outside* the listed `grid_axes` belong
 to the same group.
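
For orientation, a minimal sketch of how groups are formed; the grid shape and the op used are assumptions for illustration:

```mlir
// A 2x4 device grid: 8 devices with coordinates (i, j), 0 <= i < 2, 0 <= j < 4.
shard.grid @grid0(shape = 2x4)
// grid_axes = [1] slices the grid along axis 1, so devices sharing their
// axis-0 coordinate form a group: {(0,0), (0,1), (0,2), (0,3)} and
// {(1,0), (1,1), (1,2), (1,3)}. Each group of 4 devices reduces independently.
%sum = shard.all_reduce %arg on @grid0 grid_axes = [1]
    : tensor<4xf32> -> tensor<4xf32>
```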

mlir/include/mlir/Dialect/Shard/IR/ShardOps.td

Lines changed: 66 additions & 51 deletions
@@ -494,7 +494,9 @@ def Shard_AllGatherOp : Shard_CollectiveCommunicationOpBase<"all_gather", [
 ]> {
 let summary = "All-gather over a device grid.";
 let description = [{
-Gathers along the `gather_axis` tensor axis.
+Concatenates all tensor slices from a device group defined by `grid_axes` along
+the tensor dimension `gather_axis` and replicates the result across all devices
+in the group.
 
 Example:
 ```mlir
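
A minimal usage sketch of the behavior described above, with assumed shapes and a 2x2 grid:

```mlir
// grid_axes = [1] forms groups of 2 devices. Each device contributes its
// local tensor<2x2xi8>; the two slices are concatenated along
// gather_axis = 1, and every device in the group gets the tensor<2x4xi8>.
%1 = shard.all_gather %0 on @grid0 grid_axes = [1] gather_axis = 1
    : tensor<2x2xi8> -> tensor<2x4xi8>
```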
@@ -546,10 +548,13 @@ def Shard_AllReduceOp : Shard_CollectiveCommunicationOpBase<"all_reduce", [
 SameOperandsAndResultShape]> {
 let summary = "All-reduce over a device grid.";
 let description = [{
-The accumulation element type is specified by the result type and
-it does not need to match the input element type.
-The input element is converted to the result element type before
-performing the reduction.
+Reduces the input tensor across all devices within the groups defined by
+`grid_axes`, using the specified reduction method. The operation performs an
+element-wise reduction over the tensor slices from all devices in each group.
+Each device in a group receives a replicated copy of the reduction result.
+The accumulation element type is determined by the result type and does not
+need to match the input element type. Before performing the reduction, each
+input element is converted to the result element type.
 
 Attributes:
 `reduction`: Indicates the reduction method.
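
A sketch of the element-type conversion described above; the reduction kind and types are assumed:

```mlir
// Each f32 input element is converted to f64 before the element-wise <min>
// reduction; every device in each group receives the same tensor<3x4xf64>.
%1 = shard.all_reduce %0 on @grid0 grid_axes = [1, 0] reduction = <min>
    : tensor<3x4xf32> -> tensor<3x4xf64>
```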
@@ -583,13 +588,15 @@ def Shard_AllSliceOp : Shard_CollectiveCommunicationOpBase<"all_slice", [
 SameOperandsAndResultElementType,
 SameOperandsAndResultRank
 ]> {
-let summary = "All-slice over a device grid. This is the inverse of all-gather.";
+let summary = "All-slice over a device grid.";
 let description = [{
-Slice along the `slice_axis` tensor axis.
-This operation can be thought of as the inverse of all-gather.
-Technically, it is not required that all processes have the same input tensor.
-Each process will slice a piece of its local tensor based on its in-group device index.
-The operation does not communicate data between devices.
+Within each device group defined by `grid_axes`, slices the input tensor along
+the `slice_axis` dimension. It can be viewed as the inverse of an all-gather if
+the input data is replicated along the `slice_axis`.
+Each process simply crops its local data to the slice corresponding to its
+in-group device index.
+Note: `AllSliceOp` does not involve any communication between devices, and
+devices within a group need not have replicated input data.
 
 Example:
 ```mlir
@@ -610,7 +617,7 @@ def Shard_AllSliceOp : Shard_CollectiveCommunicationOpBase<"all_slice", [
 ```
 Result:
 ```
-gather tensor
+slice tensor
 axis 1
 ------------>
 +-------+-------+
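
A sketch matching the diagram above, assuming a 2x2 grid:

```mlir
// The inverse of the all_gather sketch: each device keeps only the
// tensor<2x2xi8> piece selected by its in-group index along slice_axis = 1.
// No data is exchanged between devices.
%1 = shard.all_slice %0 on @grid0 grid_axes = [1] slice_axis = 1
    : tensor<2x4xi8> -> tensor<2x2xi8>
```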
@@ -646,8 +653,10 @@ def Shard_AllToAllOp : Shard_CollectiveCommunicationOpBase<"all_to_all", [
 SameOperandsAndResultRank]> {
 let summary = "All-to-all over a device grid.";
 let description = [{
-Performs an all-to-all on tensor pieces split along `split_axis`.
-The resulting pieces are concatenated along `concat_axis` on ech device.
+Each participant logically splits its input along `split_axis`,
+then scatters the resulting pieces across the group defined by `grid_axes`.
+After receiving data pieces from the other participants' scatters,
+it concatenates them along `concat_axis` to produce the final result.
 
 Example:
 ```
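
A shape-level sketch of the split/scatter/concatenate behavior, assuming groups of 2 devices:

```mlir
// Each device splits its tensor<4x2xi8> into 2 pieces along split_axis = 0,
// sends one piece to each group member, and concatenates the received
// pieces along concat_axis = 1, producing a tensor<2x4xi8>.
%1 = shard.all_to_all %0 on @grid0 grid_axes = [0]
    split_axis = 0 concat_axis = 1
    : tensor<4x2xi8> -> tensor<2x4xi8>
```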
@@ -702,10 +711,9 @@ def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [
 ]> {
 let summary = "Broadcast over a device grid.";
 let description = [{
-Broadcast the tensor on `root` to all devices in each respective group.
-The operation broadcasts along grid axes `grid_axes`.
-The `root` device specifies the in-group multi-index that is broadcast to
-all other devices in the group.
+Copies the input tensor on `root` to all devices in each group defined by
+`grid_axes`. The `root` device is defined by its in-group multi-index.
+The contents of input tensors on non-root devices are ignored.
 
 Example:
 ```
@@ -722,7 +730,7 @@ def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [
 +-------+-------+ | broadcast
 device (0, 0) -> | 1 2 | 3 4 | <- device (0, 1) | along axis 0
 +-------+-------+ ↓
-device (1, 0) -> | | | <- device (1, 1)
+device (1, 0) -> | * * | * * | <- device (1, 1)
 +-------+-------+
 ```
 
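A sketch matching the diagram above; the root index and functional-type syntax are assumptions:

```mlir
// grid_axes = [0] groups the 2x2 grid by column; the device with in-group
// index 0 is the root. Non-root inputs (the `*` entries above) are ignored.
%1 = shard.broadcast %0 on @grid0 grid_axes = [0] root = [0]
    : (tensor<2x2xi8>) -> tensor<2x2xi8>
```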
@@ -758,11 +766,10 @@ def Shard_GatherOp : Shard_CollectiveCommunicationOpBase<"gather", [
 ]> {
 let summary = "Gather over a device grid.";
 let description = [{
-Gathers on device `root` along the `gather_axis` tensor axis.
-`root` specifies the coordinates of a device along `grid_axes`.
-It uniquely identifies the root device for each device group.
-The result tensor on non-root devices is undefined.
-Using it will result in undefined behavior.
+Concatenates all tensor slices from a device group defined by `grid_axes` along
+the tensor dimension `gather_axis` and returns the resulting tensor on each
+`root` device. The result on all other (non-root) devices is undefined.
+The `root` device is defined by its in-group multi-index.
 
 Example:
 ```mlir
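
A usage sketch with assumed shapes and root index:

```mlir
// Within each group of 2 devices (grid_axes = [1]), the slices are
// concatenated along gather_axis = 1 and the tensor<2x4xi8> result lands on
// the in-group root [1]; results on non-root devices are undefined.
%1 = shard.gather %0 on @grid0 grid_axes = [1] gather_axis = 1 root = [1]
    : (tensor<2x2xi8>) -> tensor<2x4xi8>
```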
@@ -821,7 +828,9 @@ def Shard_RecvOp : Shard_CollectiveCommunicationOpBase<"recv", [
 ]> {
 let summary = "Send over a device grid.";
 let description = [{
-Receive from a device within a device group.
+Receive a tensor from the device `source`, which is defined by its in-group
+multi-index. The groups are defined by `grid_axes`.
+The content of the input tensor is ignored.
 }];
 let arguments = !con(commonArgs, (ins
 AnyNon0RankedTensor:$input,
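
A usage sketch with an assumed source index and shape:

```mlir
// Receive a tensor<2xi8> from the in-group device [0]. The %0 operand only
// supplies the shape and element type; its contents are ignored.
%1 = shard.recv %0 on @grid0 grid_axes = [1] source = [0]
    : (tensor<2xi8>) -> tensor<2xi8>
```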
@@ -845,13 +854,15 @@ def Shard_ReduceOp : Shard_CollectiveCommunicationOpBase<"reduce", [
 ]> {
 let summary = "Reduce over a device grid.";
 let description = [{
-Reduces on device `root` within each device group.
-`root` specifies the coordinates of a device along `grid_axes`.
-It uniquely identifies the root device within its device group.
-The accumulation element type is specified by the result type and
-it does not need to match the input element type.
-The input element is converted to the result element type before
-performing the reduction.
+Reduces the input tensor across all devices within the groups defined by
+`grid_axes`, using the specified reduction method. The operation performs an
+element-wise reduction over the tensor slices from all devices in each group.
+The reduction result will be returned on the `root` device of each group.
+It is undefined on all other (non-root) devices.
+The `root` device is defined by its in-group multi-index.
+The accumulation element type is determined by the result type and does not
+need to match the input element type. Before performing the reduction, each
+input element is converted to the result element type.
 
 Attributes:
 `reduction`: Indicates the reduction method.
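
A usage sketch; the reduction kind, root index, and element types are assumed:

```mlir
// Element-wise <max> across each group; f32 inputs are converted to f64
// first. Only the in-group root [1] receives a defined result.
%1 = shard.reduce %0 on @grid0 grid_axes = [0] reduction = <max> root = [1]
    : (tensor<3x4xf32>) -> tensor<3x4xf64>
```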
@@ -886,16 +897,18 @@ def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter"
 SameOperandsAndResultRank]> {
 let summary = "Reduce-scatter over a device grid.";
 let description = [{
-After the reduction, the result is scattered within each device group.
-The tensor is split along `scatter_axis` and the pieces distributed
-across the device group.
+Reduces the input tensor across all devices within the groups defined by
+`grid_axes` using the specified reduction method. The reduction is performed
+element-wise across the tensor pieces from all devices in the group.
+After reduction, the reduction result is scattered (split and distributed)
+across the device group along `scatter_axis`.
 Example:
 ```
 shard.grid @grid0(shape = 2x2)
 ...
 %1 = shard.reduce_scatter %0 on @grid0 grid_axes = [1]
 reduction = <max> scatter_axis = 0
-: tensor<3x4xf32> -> tensor<1x4xf64>
+: tensor<2x2xf32> -> tensor<1x2xf64>
 ```
 Input:
 ```
@@ -916,13 +929,13 @@ def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter"
 Result:
 ```
 +-------+
-| 6 8 | <- devices (0, 0)
+| 5 6 | <- devices (0, 0)
 +-------+
-| 10 12 | <- devices (0, 1)
+| 7 8 | <- devices (0, 1)
 +-------+
-| 22 24 | <- devices (1, 0)
+| 13 14 | <- devices (1, 0)
 +-------+
-| 26 28 | <- devices (1, 1)
+| 15 16 | <- devices (1, 1)
 +-------+
 ```
 }];
@@ -950,8 +963,10 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
 ]> {
 let summary = "Scatter over a device grid.";
 let description = [{
-For each device group split the input tensor on the `root` device along
-axis `scatter_axis` and scatter the parts across the group devices.
+For each device group defined by `grid_axes`, the input tensor on the `root`
+device is split along axis `scatter_axis` and distributed across the group.
+The content of the input on all other (non-root) devices is ignored.
+The `root` device is defined by its in-group multi-index.
 
 Example:
 ```
@@ -968,8 +983,8 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
 (0, 1)
 
 +-------+-------+ | scatter tensor
-device (0, 0) -> | | | | axis 0
-| | | ↓
+device (0, 0) -> | * * | * * | | axis 0
+| * * | * * | ↓
 +-------+-------+
 device (1, 0) -> | 1 2 | 5 6 |
 | 3 4 | 7 8 |
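
A sketch matching the diagram above, with assumed shapes:

```mlir
// grid_axes = [0] groups the 2x2 grid by column; the root's tensor<2x2xi8>
// is split along scatter_axis = 0 and each group member receives a
// tensor<1x2xi8>. Inputs on non-root devices (the `*` entries) are ignored.
%1 = shard.scatter %0 on @grid0 grid_axes = [0] scatter_axis = 0 root = [1]
    : (tensor<2x2xi8>) -> tensor<1x2xi8>
```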
@@ -1018,7 +1033,8 @@ def Shard_SendOp : Shard_CollectiveCommunicationOpBase<"send", [
 ]> {
 let summary = "Send over a device grid.";
 let description = [{
-Send from one device to another within a device group.
+Send the input tensor to the device `destination`, which is defined by its
+in-group multi-index. The groups are defined by `grid_axes`.
 }];
 let arguments = !con(commonArgs, (ins
 AnyNon0RankedTensor:$input,
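
A usage sketch mirroring the `recv` example; the destination index and shape are assumed:

```mlir
// Send the local tensor<2xi8> to the in-group device [1].
%1 = shard.send %0 on @grid0 grid_axes = [1] destination = [1]
    : (tensor<2xi8>) -> tensor<2xi8>
```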
@@ -1043,12 +1059,11 @@ def Shard_ShiftOp : Shard_CollectiveCommunicationOpBase<"shift", [
 ]> {
 let summary = "Shift over a device grid.";
 let description = [{
-Within each device group shift along grid axis `shift_axis` by an offset
-`offset`.
-The result on devices that do not have a corresponding source is undefined.
-`shift_axis` must be one of `grid_axes`.
-If the `rotate` attribute is present,
-instead of a shift a rotation is done.
+Within each device group defined by `grid_axes`, shifts input tensors along the
+device grid's axis `shift_axis` by the specified offset. The `shift_axis` must
+be one of the `grid_axes`. If the `rotate` attribute is set, the shift is circular.
+That is, the offset wraps around according to the group size along `shift_axis`.
+Otherwise, the results on devices without a corresponding source are undefined.
 
 Example:
 ```
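
A usage sketch of a circular shift; the operand position and attribute order in the assembly are assumptions:

```mlir
// Within each group along grid axis 1, device j receives the tensor held by
// device (j - 2) mod N, where N is the group size; `rotate` makes the
// offset wrap around instead of leaving boundary devices undefined.
%1 = shard.shift on @grid0 grid_axes = [1]
    shift_axis = 1 offset = 2 rotate
    %0 : tensor<2xi8> -> tensor<2xi8>
```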
