
@fschlimb

This PR seeks to improve the clarity of the Shard dialect documentation, in particular the descriptions of the communication operations.

Copilot AI left a comment

Pull Request Overview

This PR improves the clarity and consistency of documentation for the Shard dialect's collective communication operations. The changes focus on making operation descriptions more precise and easier to understand for users.

Key changes:

  • Standardized terminology for describing device groups, roots, and indices across all operations
  • Enhanced operation descriptions with clearer explanations of data flow and behavior
  • Fixed documentation examples to match actual operation semantics
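To make the standardized device-group terminology concrete, the grouping rule from the Shard.md change (devices that share the same coordinates outside `grid_axes` form one group) and the notion of an in-group multi-index (a device's coordinates along `grid_axes`) can be modeled in a few lines of Python. This is a hypothetical sketch, not part of MLIR: `device_groups` and `in_group_index` are names invented here, and listing group members in row-major order is an assumption made for illustration.

```python
from itertools import product


def device_groups(grid_shape, grid_axes):
    """Partition every device coordinate of a grid into device groups.

    Devices whose coordinates agree on all axes NOT listed in grid_axes
    belong to the same group. Members are listed in row-major order.
    """
    groups = {}
    for coord in product(*(range(n) for n in grid_shape)):
        # The coordinates outside grid_axes identify the group.
        key = tuple(c for axis, c in enumerate(coord) if axis not in grid_axes)
        groups.setdefault(key, []).append(coord)
    return list(groups.values())


def in_group_index(coord, grid_axes):
    """In-group multi-index: the device's coordinates along grid_axes."""
    return tuple(coord[axis] for axis in grid_axes)


# A 2x2 grid split along grid axis 1 yields one group per row of the grid.
print(device_groups((2, 2), grid_axes=[1]))  # [[(0, 0), (0, 1)], [(1, 0), (1, 1)]]
print(in_group_index((1, 1), grid_axes=[1]))  # (1,)
```

Listing every grid axis in `grid_axes` puts all devices into a single group, while listing none makes each device its own group.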

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

Files reviewed:

  • mlir/include/mlir/Dialect/Shard/IR/ShardOps.td: Updated documentation for collective communication operations (all_gather, all_reduce, all_slice, all_to_all, broadcast, gather, recv, reduce, reduce_scatter, scatter, send, shift) with clearer descriptions and corrected examples
  • mlir/docs/Dialects/Shard.md: Refined explanation of device groups with more concise wording
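The data movement described for `all_gather`, `all_slice`, and `all_to_all` within a single device group can be approximated with plain Python lists standing in for 2-D tensors. This is a hedged sketch under stated assumptions, not the dialect's implementation: all helper names are hypothetical, a group is represented as a list of per-device inputs ordered by in-group linear index, and piece i of a split is assumed to go to the device with in-group index i.

```python
def concat(tensors, axis):
    """Concatenate 2-D tensors (lists of rows) along tensor axis 0 or 1."""
    if axis == 0:
        return [list(row) for t in tensors for row in t]
    return [[x for t in tensors for x in t[r]] for r in range(len(tensors[0]))]


def split(tensor, parts, axis):
    """Split a 2-D tensor into `parts` equal pieces along tensor axis 0 or 1."""
    if axis == 0:
        n = len(tensor) // parts
        return [tensor[i * n:(i + 1) * n] for i in range(parts)]
    n = len(tensor[0]) // parts
    return [[row[i * n:(i + 1) * n] for row in tensor] for i in range(parts)]


def all_gather(inputs, gather_axis):
    """Every device in the group gets the concatenation of all local slices."""
    full = concat(inputs, gather_axis)
    return [[list(row) for row in full] for _ in inputs]


def all_slice(inputs, slice_axis):
    """No communication: device i keeps piece i of its own local tensor."""
    k = len(inputs)
    return [split(x, k, slice_axis)[i] for i, x in enumerate(inputs)]


def all_to_all(inputs, split_axis, concat_axis):
    """Device j splits its input into k pieces and sends piece i to device i;
    device i concatenates what it receives in source order."""
    k = len(inputs)
    pieces = [split(x, k, split_axis) for x in inputs]
    return [concat([pieces[j][i] for j in range(k)], concat_axis) for i in range(k)]


group = [[[1, 2]], [[3, 4]]]  # two devices, each holding a 1x2 slice
gathered = all_gather(group, gather_axis=0)
print(gathered)  # [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]
print(all_slice(gathered, slice_axis=0) == group)  # True
```

Gathering and then slicing along the same axis returns each device's original piece, which is the sense in which `all_slice` acts as the inverse of `all_gather` when its input is replicated within the group.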

@llvmbot
llvmbot commented Oct 16, 2025

@llvm/pr-subscribers-mlir

Author: Frank Schlimbach (fschlimb)

Changes

This PR seeks to improve the clarity of the Shard dialect documentation, in particular the descriptions of the communication operations.


Full diff: https://github.com/llvm/llvm-project/pull/163782.diff

2 Files Affected:

  • (modified) mlir/docs/Dialects/Shard.md (+3-3)
  • (modified) mlir/include/mlir/Dialect/Shard/IR/ShardOps.td (+66-51)
diff --git a/mlir/docs/Dialects/Shard.md b/mlir/docs/Dialects/Shard.md
index eb6ff6150e474..573e888e6541f 100644
--- a/mlir/docs/Dialects/Shard.md
+++ b/mlir/docs/Dialects/Shard.md
@@ -27,9 +27,9 @@ the tensor is sharded - not specified manually.
 
 ### Device Groups
 
-Each collective operation runs within a group of devices. You define groups
-using the `grid` and `grid_axes` attributes, which describe how to slice the
-full device grid into smaller groups.
+Collective operations run within groups of devices, which are defined
+using the `grid` and `grid_axes` attributes. These describe
+how the full device grid is sliced into smaller groups.
 
 Devices that have the same coordinates *outside* the listed `grid_axes` belong
 to the same group.
diff --git a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
index b9d7163ea4c1e..60461b9ddc826 100644
--- a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
+++ b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
@@ -494,7 +494,9 @@ def Shard_AllGatherOp : Shard_CollectiveCommunicationOpBase<"all_gather", [
   ]> {
   let summary = "All-gather over a device grid.";
   let description = [{
-    Gathers along the `gather_axis` tensor axis.
+    Concatenates all tensor slices from a device group defined by `grid_axes` along
+    the tensor dimension `gather_axis` and replicates the result across all devices
+    in the group.
 
     Example:
     ```mlir
@@ -546,10 +548,13 @@ def Shard_AllReduceOp : Shard_CollectiveCommunicationOpBase<"all_reduce", [
     SameOperandsAndResultShape]> {
   let summary = "All-reduce over a device grid.";
   let description = [{
-    The accumulation element type is specified by the result type and
-    it does not need to match the input element type.
-    The input element is converted to the result element type before
-    performing the reduction.
+    Reduces the input tensor across all devices within the groups defined by
+    `grid_axes`, using the specified reduction method. The operation performs an
+    element-wise reduction over the tensor slices from all devices in each group.
+    Each device in a group receives a replicated copy of the reduction result.
+    The accumulation element type is determined by the result type and does not
+    need to match the input element type. Before performing the reduction, each
+    input element is converted to the result element type.
 
     Attributes:
     `reduction`: Indicates the reduction method.
@@ -583,13 +588,15 @@ def Shard_AllSliceOp : Shard_CollectiveCommunicationOpBase<"all_slice", [
     SameOperandsAndResultElementType,
     SameOperandsAndResultRank
   ]> {
-  let summary = "All-slice over a device grid. This is the inverse of all-gather.";
+  let summary = "All-slice over a device grid.";
   let description = [{
-    Slice along the `slice_axis` tensor axis.
-    This operation can be thought of as the inverse of all-gather.
-    Technically, it is not required that all processes have the same input tensor.
-    Each process will slice a piece of its local tensor based on its in-group device index.
-    The operation does not communicate data between devices. 
+    Within each device group defined by `grid_axes`, slices the input tensor along
+    the `slice_axis` dimension. It can be viewed as the inverse of an all-gather if
+    the input data is replicated along the `slice_axis`.
+    Each process simply crops its local data to the slice corresponding to its
+    in-group device index.
+    Notice: `AllSliceOp` does not involve any communication between devices, and
+            devices within a group need not have replicated input data.
 
     Example:
     ```mlir
@@ -610,7 +617,7 @@ def Shard_AllSliceOp : Shard_CollectiveCommunicationOpBase<"all_slice", [
     ```
     Result:
     ```
-    gather tensor
+    slice tensor
     axis 1
     ------------>
                      +-------+-------+
@@ -646,8 +653,10 @@ def Shard_AllToAllOp : Shard_CollectiveCommunicationOpBase<"all_to_all", [
     SameOperandsAndResultRank]> {
   let summary = "All-to-all over a device grid.";
   let description = [{
-    Performs an all-to-all on tensor pieces split along `split_axis`.
-    The resulting pieces are concatenated along `concat_axis` on ech device.
+    Each participant logically splits its input along `split_axis`,
+    then scatters the resulting pieces across the group defined by `grid_axes`.
+    After receiving data pieces from other participants' scatters,
+    it concatenates them along `concat_axis` to produce the final result.
 
     Example:
     ```
@@ -702,10 +711,9 @@ def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [
   ]> {
   let summary = "Broadcast over a device grid.";
   let description = [{
-    Broadcast the tensor on `root` to all devices in each respective group.
-    The operation broadcasts along grid axes `grid_axes`.
-    The `root` device specifies the in-group multi-index that is broadcast to
-    all other devices in the group.
+    Copies the input tensor on `root` to all devices in each group defined by
+    `grid_axes`. The `root` device is defined by its in-group multi-index.
+    The contents of input tensors on non-root devices are ignored.
     
     Example:
     ```
@@ -722,7 +730,7 @@ def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [
                      +-------+-------+                   | broadcast
     device (0, 0) -> |  1  2 |  3  4 | <- device (0, 1)  | along axis 0
                      +-------+-------+                   ↓
-    device (1, 0) -> |       |       | <- device (1, 1) 
+    device (1, 0) -> |  *  * |  *  * | <- device (1, 1)
                      +-------+-------+
     ```
 
@@ -758,11 +766,10 @@ def Shard_GatherOp : Shard_CollectiveCommunicationOpBase<"gather", [
   ]> {
   let summary = "Gather over a device grid.";
   let description = [{
-    Gathers on device `root` along the `gather_axis` tensor axis.
-    `root` specifies the coordinates of a device along `grid_axes`.
-    It uniquely identifies the root device for each device group.
-    The result tensor on non-root devices is undefined.
-    Using it will result in undefined behavior.
+    Concatenates all tensor slices from a device group defined by `grid_axes` along
+    the tensor dimension `gather_axis` and returns the resulting tensor on each
+    `root` device. The result on all other (non-root) devices is undefined.
+    The `root` device is defined by its in-group multi-index.
 
     Example:
     ```mlir
@@ -821,7 +828,9 @@ def Shard_RecvOp : Shard_CollectiveCommunicationOpBase<"recv", [
   ]> {
   let summary = "Send over a device grid.";
   let description = [{
-    Receive from a device within a device group.
+    Receives a tensor from device `source`, which is defined by its in-group
+    multi-index. The groups are defined by `grid_axes`.
+    The content of the input tensor is ignored.
   }];
   let arguments = !con(commonArgs, (ins
     AnyNon0RankedTensor:$input,
@@ -845,13 +854,15 @@ def Shard_ReduceOp : Shard_CollectiveCommunicationOpBase<"reduce", [
   ]> {
   let summary = "Reduce over a device grid.";
   let description = [{
-    Reduces on device `root` within each device group.
-    `root` specifies the coordinates of a device along `grid_axes`.
-    It uniquely identifies the root device within its device group.
-    The accumulation element type is specified by the result type and
-    it does not need to match the input element type.
-    The input element is converted to the result element type before
-    performing the reduction.
+    Reduces the input tensor across all devices within the groups defined by
+    `grid_axes`, using the specified reduction method. The operation performs an
+    element-wise reduction over the tensor slices from all devices in each group.
+    The reduction result will be returned on the `root` device of each group.
+    It is undefined on all other (non-root) devices.
+    The `root` device is defined by its in-group multi-index.
+    The accumulation element type is determined by the result type and does not
+    need to match the input element type. Before performing the reduction, each
+    input element is converted to the result element type.
 
     Attributes:
     `reduction`: Indicates the reduction method.
@@ -886,16 +897,18 @@ def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter"
     SameOperandsAndResultRank]> {
   let summary = "Reduce-scatter over a device grid.";
   let description = [{
-    After the reduction, the result is scattered within each device group.
-    The tensor is split along `scatter_axis` and the pieces distributed
-    across the device group.
+    Reduces the input tensor across all devices within the groups defined by
+    `grid_axes` using the specified reduction method. The reduction is performed
+    element-wise across the tensor pieces from all devices in the group.
+    After reduction, the reduction result is scattered (split and distributed)
+    across the device group along `scatter_axis`.
     Example:
     ```
     shard.grid @grid0(shape = 2x2)
     ...
     %1 = shard.reduce_scatter %0 on @grid0 grid_axes = [1]
       reduction = <max> scatter_axis = 0
-      : tensor<3x4xf32> -> tensor<1x4xf64>
+      : tensor<2x2xf32> -> tensor<1x2xf64>
     ```
     Input:
     ```
@@ -916,13 +929,13 @@ def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter"
     Result:
     ```
     +-------+
-    |  6  8 | <- devices (0, 0)
+    |  5  6 | <- devices (0, 0)
     +-------+
-    | 10 12 | <- devices (0, 1)
+    |  7  8 | <- devices (0, 1)
     +-------+
-    | 22 24 | <- devices (1, 0)
+    | 13 14 | <- devices (1, 0)
     +-------+
-    | 26 28 | <- devices (1, 1)
+    | 15 16 | <- devices (1, 1)
     +-------+
     ```
   }];
@@ -950,8 +963,10 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
   ]> {
   let summary = "Scatter over a device grid.";
   let description = [{
-    For each device group split the input tensor on the `root` device along
-    axis `scatter_axis` and scatter the parts across the group devices.
+    For each device group defined by `grid_axes`, the input tensor on the `root`
+    device is split along axis `scatter_axis` and distributed across the group.
+    The content of the input on all other (non-root) devices is ignored.
+    The `root` device is defined by its in-group multi-index.
 
     Example:
     ```
@@ -968,8 +983,8 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
                               (0, 1)
                                  ↓
                      +-------+-------+  | scatter tensor
-    device (0, 0) -> |       |       |  | axis 0
-                     |       |       |  ↓
+    device (0, 0) -> |  *  * |  *  * |  | axis 0
+                     |  *  * |  *  * |  ↓
                      +-------+-------+
     device (1, 0) -> |  1  2 |  5  6 |
                      |  3  4 |  7  8 |
@@ -1018,7 +1033,8 @@ def Shard_SendOp : Shard_CollectiveCommunicationOpBase<"send", [
   ]> {
   let summary = "Send over a device grid.";
   let description = [{
-    Send from one device to another within a device group.
+    Sends the input tensor to device `destination`, which is defined by its in-group
+    multi-index. The groups are defined by `grid_axes`.
   }];
   let arguments = !con(commonArgs, (ins
     AnyNon0RankedTensor:$input,
@@ -1043,12 +1059,11 @@ def Shard_ShiftOp : Shard_CollectiveCommunicationOpBase<"shift", [
   ]> {
   let summary = "Shift over a device grid.";
   let description = [{
-    Within each device group shift along grid axis `shift_axis` by an offset
-    `offset`.
-    The result on devices that do not have a corresponding source is undefined.
-    `shift_axis` must be one of `grid_axes`.
-    If the `rotate` attribute is present,
-    instead of a shift a rotation is done.
+    Within each device group defined by `grid_axes`, shifts input tensors along the
+    device grid's axis `shift_axis` by the specified offset. The `shift_axis` must
+    be one of the `grid_axes`. If the `rotate` attribute is set, the shift is circular.
+    That is, the offset wraps around according to the group size along `shift_axis`.
+    Otherwise, the results on devices without a corresponding source are undefined.
 
     Example:
     ```
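The corrected `reduce_scatter` result can be checked with a small simulation. The diff context hides the example's input, so the per-device inputs below are an assumption (values 1 to 16 laid out row by row on devices (0, 0), (0, 1), (1, 0), (1, 1)). Under that assumption, `max` reproduces the new result, and `sum` reproduces the old 6 8 / 10 12 / 22 24 / 26 28 values, which suggests the old example showed summed values even though it specified `reduction = <max>`. The `reduce_scatter` function here is a hypothetical Python model, not the MLIR op, and it only handles `scatter_axis = 0`.

```python
from itertools import product

# ASSUMED per-device inputs for the `shard.grid @grid0(shape = 2x2)` example;
# the diff context does not show them.
INPUTS = {
    (0, 0): [[1, 2], [3, 4]],
    (0, 1): [[5, 6], [7, 8]],
    (1, 0): [[9, 10], [11, 12]],
    (1, 1): [[13, 14], [15, 16]],
}


def reduce_scatter(inputs, grid_shape, grid_axes, reduce_fn):
    """Hypothetical model: reduce element-wise within each device group, then
    split the reduced tensor along tensor axis 0 (scatter_axis = 0 only) and
    give piece i to the i-th group member in row-major order."""
    groups = {}
    for coord in product(*(range(n) for n in grid_shape)):
        key = tuple(c for axis, c in enumerate(coord) if axis not in grid_axes)
        groups.setdefault(key, []).append(coord)

    results = {}
    for members in groups.values():
        tensors = [inputs[c] for c in members]
        # zip(*tensors) pairs up matching rows across devices; zip(*rows)
        # pairs up matching elements. float() stands in for the f32 -> f64
        # conversion to the result element type.
        reduced = [[float(reduce_fn(vals)) for vals in zip(*rows)]
                   for rows in zip(*tensors)]
        n = len(reduced) // len(members)
        for i, coord in enumerate(members):
            results[coord] = reduced[i * n:(i + 1) * n]
    return results


for coord, piece in reduce_scatter(INPUTS, (2, 2), [1], max).items():
    print(coord, piece)
# (0, 0) [[5.0, 6.0]]
# (0, 1) [[7.0, 8.0]]
# (1, 0) [[13.0, 14.0]]
# (1, 1) [[15.0, 16.0]]
```

Only tensor axis 0 and a row-major member order are modeled; how the real op linearizes groups that span several grid axes is not shown in this diff.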

fschlimb merged commit d74c976 into llvm:main Oct 17, 2025
10 checks passed