@@ -494,7 +494,9 @@ def Shard_AllGatherOp : Shard_CollectiveCommunicationOpBase<"all_gather", [
494494 ]> {
495495 let summary = "All-gather over a device grid.";
496496 let description = [{
497- Gathers along the `gather_axis` tensor axis.
497+ Concatenates all tensor slices from a device group defined by `grid_axes` along
498+ the tensor dimension `gather_axis` and replicates the result across all devices
499+ in the group.
498500
499501 Example:
500502 ```mlir
@@ -546,10 +548,13 @@ def Shard_AllReduceOp : Shard_CollectiveCommunicationOpBase<"all_reduce", [
546548 SameOperandsAndResultShape]> {
547549 let summary = "All-reduce over a device grid.";
548550 let description = [{
549- The accumulation element type is specified by the result type and
550- it does not need to match the input element type.
551- The input element is converted to the result element type before
552- performing the reduction.
551+ Reduces the input tensor across all devices within the groups defined by
552+ `grid_axes`, using the specified reduction method. The operation performs an
553+ element-wise reduction over the tensor slices from all devices in each group.
554+ Each device in a group receives a replicated copy of the reduction result.
555+ The accumulation element type is determined by the result type and does not
556+ need to match the input element type. Before performing the reduction, each
557+ input element is converted to the result element type.
553558
554559 Attributes:
555560 `reduction`: Indicates the reduction method.
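A minimal usage sketch, for reference; it assumes a grid `@grid0` declared as
in the neighboring examples and mirrors the assembly format of the
`reduce_scatter` example further down:
```mlir
shard.grid @grid0(shape = 2x2)
...
// Element-wise max across the groups spanned by grid axes [1, 0]; the
// accumulation type f64 need not match the input element type f32.
%1 = shard.all_reduce %0 on @grid0 grid_axes = [1, 0] reduction = <max>
    : tensor<3x4xf32> -> tensor<3x4xf64>
```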
@@ -583,13 +588,15 @@ def Shard_AllSliceOp : Shard_CollectiveCommunicationOpBase<"all_slice", [
583588 SameOperandsAndResultElementType,
584589 SameOperandsAndResultRank
585590 ]> {
586- let summary = "All-slice over a device grid. This is the inverse of all-gather. ";
591+ let summary = "All-slice over a device grid.";
587592 let description = [{
588- Slice along the `slice_axis` tensor axis.
589- This operation can be thought of as the inverse of all-gather.
590- Technically, it is not required that all processes have the same input tensor.
591- Each process will slice a piece of its local tensor based on its in-group device index.
592- The operation does not communicate data between devices.
593+ Within each device group defined by `grid_axes`, slices the input tensor along
594+ the `slice_axis` dimension. It can be viewed as the inverse of an all-gather if
595+ the input data is replicated along the `slice_axis`.
596+ Each process simply crops its local data to the slice corresponding to its
597+ in-group device index.
598+ Note that `AllSliceOp` does not involve any communication between devices,
599+ and the input data on devices within a group need not be replicated.
593600
594601 Example:
595602 ```mlir
@@ -610,7 +617,7 @@ def Shard_AllSliceOp : Shard_CollectiveCommunicationOpBase<"all_slice", [
610617 ```
611618 Result:
612619 ```
613- gather tensor
620+ slice tensor
614621 axis 1
615622 ------------>
616623 +-------+-------+
@@ -646,8 +653,10 @@ def Shard_AllToAllOp : Shard_CollectiveCommunicationOpBase<"all_to_all", [
646653 SameOperandsAndResultRank]> {
647654 let summary = "All-to-all over a device grid.";
648655 let description = [{
649- Performs an all-to-all on tensor pieces split along `split_axis`.
650- The resulting pieces are concatenated along `concat_axis` on ech device.
656+ Each participant logically splits its input tensor along `split_axis`,
657+ then scatters the resulting pieces across the group defined by `grid_axes`.
658+ After receiving data pieces from the other participants' scatters,
659+ it concatenates them along `concat_axis` to produce the final result.
651660
652661 Example:
653662 ```
@@ -702,10 +711,9 @@ def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [
702711 ]> {
703712 let summary = "Broadcast over a device grid.";
704713 let description = [{
705- Broadcast the tensor on `root` to all devices in each respective group.
706- The operation broadcasts along grid axes `grid_axes`.
707- The `root` device specifies the in-group multi-index that is broadcast to
708- all other devices in the group.
714+ Copies the input tensor on `root` to all devices in each group defined by
715+ `grid_axes`. The `root` device is defined by its in-group multi-index.
716+ The contents of input tensors on non-root devices are ignored.
709717
710718 Example:
711719 ```
@@ -722,7 +730,7 @@ def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [
722730 +-------+-------+ | broadcast
723731 device (0, 0) -> | 1 2 | 3 4 | <- device (0, 1) | along axis 0
724732 +-------+-------+ ↓
725- device (1, 0) -> | | | <- device (1, 1)
733+ device (1, 0) -> | * * | * * | <- device (1, 1)
726734 +-------+-------+
727735 ```
728736
@@ -758,11 +766,10 @@ def Shard_GatherOp : Shard_CollectiveCommunicationOpBase<"gather", [
758766 ]> {
759767 let summary = "Gather over a device grid.";
760768 let description = [{
761- Gathers on device `root` along the `gather_axis` tensor axis.
762- `root` specifies the coordinates of a device along `grid_axes`.
763- It uniquely identifies the root device for each device group.
764- The result tensor on non-root devices is undefined.
765- Using it will result in undefined behavior.
769+ Concatenates all tensor slices from a device group defined by `grid_axes` along
770+ the tensor dimension `gather_axis` and returns the result on the `root`
771+ device of each group. The result on all other (non-root) devices is undefined.
772+ The `root` device is defined by its in-group multi-index.
766773
767774 Example:
768775 ```mlir
@@ -821,7 +828,9 @@ def Shard_RecvOp : Shard_CollectiveCommunicationOpBase<"recv", [
821828 ]> {
822- let summary = "Send over a device grid.";
829+ let summary = "Receive over a device grid.";
823830 let description = [{
824- Receive from a device within a device group.
831+ Receives a tensor from device `source`, which is defined by its in-group
832+ multi-index. The groups are defined by `grid_axes`.
833+ The content of the input tensor is ignored.
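A possible usage sketch; the assembly format is assumed to mirror the other
point-to-point and `root`-carrying ops in this file:
```mlir
// Receive from the in-group device with index 0 along grid axis 1.
%1 = shard.recv %0 on @grid0 grid_axes = [1] source = [0]
    : (tensor<2xi8>) -> tensor<2xi8>
```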
825834 }];
826835 let arguments = !con(commonArgs, (ins
827836 AnyNon0RankedTensor:$input,
@@ -845,13 +854,15 @@ def Shard_ReduceOp : Shard_CollectiveCommunicationOpBase<"reduce", [
845854 ]> {
846855 let summary = "Reduce over a device grid.";
847856 let description = [{
848- Reduces on device `root` within each device group.
849- `root` specifies the coordinates of a device along `grid_axes`.
850- It uniquely identifies the root device within its device group.
851- The accumulation element type is specified by the result type and
852- it does not need to match the input element type.
853- The input element is converted to the result element type before
854- performing the reduction.
857+ Reduces the input tensor across all devices within the groups defined by
858+ `grid_axes`, using the specified reduction method. The operation performs an
859+ element-wise reduction over the tensor slices from all devices in each group.
860+ The reduction result is returned on the `root` device of each group.
861+ It is undefined on all other (non-root) devices.
862+ The `root` device is defined by its in-group multi-index.
863+ The accumulation element type is determined by the result type and does not
864+ need to match the input element type. Before performing the reduction, each
865+ input element is converted to the result element type.
855866
856867 Attributes:
857868 `reduction`: Indicates the reduction method.
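A minimal usage sketch, assuming a 2x2 grid `@grid0` as in the neighboring
examples; the functional type follows the other `root`-carrying ops:
```mlir
shard.grid @grid0(shape = 2x2)
...
// Element-wise max within each group along grid axis 1; only the in-group
// device with index 0 holds a defined result afterwards.
%1 = shard.reduce %0 on @grid0 grid_axes = [1]
    reduction = <max> root = [0]
    : (tensor<2x2xf32>) -> tensor<2x2xf64>
```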
@@ -886,16 +897,18 @@ def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter"
886897 SameOperandsAndResultRank]> {
887898 let summary = "Reduce-scatter over a device grid.";
888899 let description = [{
889- After the reduction, the result is scattered within each device group.
890- The tensor is split along `scatter_axis` and the pieces distributed
891- across the device group.
900+ Reduces the input tensor across all devices within the groups defined by
901+ `grid_axes` using the specified reduction method. The reduction is performed
902+ element-wise across the tensor pieces from all devices in the group.
903+ After the reduction, the result is scattered (split and distributed)
904+ across the device group along `scatter_axis`.
892905 Example:
893906 ```
894907 shard.grid @grid0(shape = 2x2)
895908 ...
896909 %1 = shard.reduce_scatter %0 on @grid0 grid_axes = [1]
897910 reduction = <max> scatter_axis = 0
898- : tensor<3x4xf32 > -> tensor<1x4xf64 >
911+ : tensor<2x2xf32> -> tensor<1x2xf64>
899912 ```
900913 Input:
901914 ```
@@ -916,13 +929,13 @@ def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter"
916929 Result:
917930 ```
918931 +-------+
919- | 6 8 | <- devices (0, 0)
932+ | 5 6 | <- devices (0, 0)
920933 +-------+
921- | 10 12 | <- devices (0, 1)
934+ | 7 8 | <- devices (0, 1)
922935 +-------+
923- | 22 24 | <- devices (1, 0)
936+ | 13 14 | <- devices (1, 0)
924937 +-------+
925- | 26 28 | <- devices (1, 1)
938+ | 15 16 | <- devices (1, 1)
926939 +-------+
927940 ```
928941 }];
@@ -950,8 +963,10 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
950963 ]> {
951964 let summary = "Scatter over a device grid.";
952965 let description = [{
953- For each device group split the input tensor on the `root` device along
954- axis `scatter_axis` and scatter the parts across the group devices.
966+ For each device group defined by `grid_axes`, the input tensor on the `root`
967+ device is split along axis `scatter_axis` and distributed across the group.
968+ The content of the input on all other (non-root) devices is ignored.
969+ The `root` device is defined by its in-group multi-index.
955970
956971 Example:
957972 ```
@@ -968,8 +983,8 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
968983 (0, 1)
969984 ↓
970985 +-------+-------+ | scatter tensor
971- device (0, 0) -> | | | | axis 0
972- | | | ↓
986+ device (0, 0) -> | * * | * * | | axis 0
987+ | * * | * * | ↓
973988 +-------+-------+
974989 device (1, 0) -> | 1 2 | 5 6 |
975990 | 3 4 | 7 8 |
@@ -1018,7 +1033,8 @@ def Shard_SendOp : Shard_CollectiveCommunicationOpBase<"send", [
10181033 ]> {
10191034 let summary = "Send over a device grid.";
10201035 let description = [{
1021- Send from one device to another within a device group.
1036+ Sends the input tensor to device `destination`, which is defined by its
1037+ in-group multi-index. The groups are defined by `grid_axes`.
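A possible usage sketch; the assembly format is assumed to mirror `recv`:
```mlir
// Send the local tensor to the in-group device with index 1 along grid axis 1.
%1 = shard.send %0 on @grid0 grid_axes = [1] destination = [1]
    : (tensor<2xi8>) -> tensor<2xi8>
```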
10221038 }];
10231039 let arguments = !con(commonArgs, (ins
10241040 AnyNon0RankedTensor:$input,
@@ -1043,12 +1059,11 @@ def Shard_ShiftOp : Shard_CollectiveCommunicationOpBase<"shift", [
10431059 ]> {
10441060 let summary = "Shift over a device grid.";
10451061 let description = [{
1046- Within each device group shift along grid axis `shift_axis` by an offset
1047- `offset`.
1048- The result on devices that do not have a corresponding source is undefined.
1049- `shift_axis` must be one of `grid_axes`.
1050- If the `rotate` attribute is present,
1051- instead of a shift a rotation is done.
1062+ Within each device group defined by `grid_axes`, shifts input tensors along the
1063+ device grid's axis `shift_axis` by the specified offset. The `shift_axis` must
1064+ be one of the `grid_axes`. If the `rotate` attribute is set, the shift is circular.
1065+ That is, the offset wraps around according to the group size along `shift_axis`.
1066+ Otherwise, the results on devices without a corresponding source are undefined.
10521067
10531068 Example:
10541069 ```