Skip to content

Commit d4405ad

Browse files
authored
Collectives Ops : Match example from the StableHLO op description comment with example from the spec (#2481)
as per https://github.com/openxla/stablehlo/blob/main/docs/reference_checklist.md#after-implementing-the-op.
1 parent 3ec5546 commit d4405ad

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

docs/spec.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -829,7 +829,9 @@ Afterwards, within each `process_group`:
829829
"stablehlo.return"(%0) : (tensor<i64>) -> ()
830830
}) {
831831
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
832+
// channel_id = 0
832833
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
834+
// use_global_device_ids = false
833835
} : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
834836
// %result0@(0, 0): [6, 8, 10, 12]
835837
// %result0@(1, 0): [6, 8, 10, 12]
@@ -918,6 +920,7 @@ Afterwards, within each `process_group`:
918920
concat_dimension = 0 : i64,
919921
split_count = 2 : i64,
920922
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
923+
// channel_id = 0
921924
} : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>)
922925
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
923926
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]

stablehlo/dialect/StablehloOps.td

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1356,11 +1356,11 @@ def StableHLO_AllGatherOp : StableHLO_Op<"all_gather",
13561356

13571357
Example:
13581358
```mlir
1359-
%result = "stablehlo.all_gather"(%operand) {
1359+
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
13601360
all_gather_dim = 1 : i64,
13611361
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
13621362
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
1363-
} : (tensor<2x2xi64>) -> tensor<2x4xi64>
1363+
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> (tensor<2x4xi64>, tensor<2x4xi64>)
13641364
```
13651365
}];
13661366

@@ -1395,15 +1395,14 @@ def StableHLO_AllReduceOp : StableHLO_Op<"all_reduce",
13951395

13961396
Example:
13971397
```mlir
1398-
%result = "stablehlo.all_reduce"(%operand) ({
1398+
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
13991399
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
1400-
%0 = stablehlo.add %arg1, %arg2 : tensor<i64>
1401-
stablehlo.return %0 : tensor<i64>
1400+
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
1401+
"stablehlo.return"(%0) : (tensor<i64>) -> ()
14021402
}) {
1403-
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
1403+
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
14041404
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
1405-
// use_global_device_ids = false
1406-
} : (tensor<4xi64>) -> tensor<4xi64>
1405+
} : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
14071406
```
14081407
}];
14091408

@@ -1483,12 +1482,12 @@ def StableHLO_AllToAllOp : StableHLO_Op<"all_to_all",
14831482

14841483
Example:
14851484
```mlir
1486-
%result = "stablehlo.all_to_all"(%operand) {
1485+
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
14871486
split_dimension = 1 : i64,
14881487
concat_dimension = 0 : i64,
14891488
split_count = 2 : i64,
14901489
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
1491-
} : (tensor<2x4xi64>) -> tensor<4x2xi64>
1490+
} : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>)
14921491
```
14931492
}];
14941493

0 commit comments

Comments
 (0)