@@ -156,6 +156,40 @@ def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
156
156
];
157
157
}
158
158
159
+ def Mesh_NeighborsLinearIndicesOp : Mesh_Op<"neighbors_linear_indices", [
160
+ Pure,
161
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>,
162
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
163
+ ]> {
164
+ let summary =
165
+ "For given mesh index get the linear indices of the direct neighbor processes along the given split.";
166
+ let description = [{
167
+ Example:
168
+ ```
169
+ mesh.mesh @mesh0(shape = 10x20x30)
170
+ %c1 = arith.constant 1 : index
171
+ %c2 = arith.constant 2 : index
172
+ %c3 = arith.constant 3 : index
173
+ %idx = mesh.neighbors_linear_indices on @mesh[%c1, %c2, %c3] split_axes = [1] : index
174
+ ```
175
+ The above returns two indices, `633` and `693`, which correspond to the
176
+ index of the previous process `(1, 1, 3)`, and the next process
177
+ `(1, 3, 3) along the split axis `1`.
178
+
179
+ A negative value is returned if there is no neighbor in the respective
180
+ direction along the given `split_axes`.
181
+ }];
182
+ let arguments = (ins FlatSymbolRefAttr:$mesh,
183
+ Variadic<Index>:$device,
184
+ Mesh_MeshAxesAttr:$split_axes);
185
+ let results = (outs Index:$neighbor_down, Index:$neighbor_up);
186
+ let assemblyFormat = [{
187
+ `on` $mesh `[` $device `]`
188
+ `split_axes` `=` $split_axes
189
+ attr-dict `:` type(results)
190
+ }];
191
+ }
192
+
159
193
//===----------------------------------------------------------------------===//
160
194
// Sharding operations.
161
195
//===----------------------------------------------------------------------===//
@@ -1058,12 +1092,12 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
1058
1092
}
1059
1093
1060
1094
def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
1095
+ Pure,
1061
1096
DestinationStyleOpInterface,
1062
1097
TypesMatchWith<
1063
1098
"result has same type as destination",
1064
1099
"result", "destination", "$_self">,
1065
- DeclareOpInterfaceMethods<SymbolUserOpInterface>,
1066
- AttrSizedOperandSegments
1100
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>
1067
1101
]> {
1068
1102
let summary = "Update halo data.";
1069
1103
let description = [{
@@ -1072,7 +1106,7 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
1072
1106
on the remote devices. Changes might be caused by mutating operations
1073
1107
and/or if the new halo regions are larger than the existing ones.
1074
1108
1075
- Source and destination might have different halo sizes .
1109
+ Destination is supposed to be initialized with the local data (not halos) .
1076
1110
1077
1111
Assumes all devices hold tensors with same-sized halo data as specified
1078
1112
by `source_halo_sizes/static_source_halo_sizes` and
@@ -1084,25 +1118,21 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
1084
1118
1085
1119
}];
1086
1120
let arguments = (ins
1087
- AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$source,
1088
1121
AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$destination,
1089
1122
FlatSymbolRefAttr:$mesh,
1090
1123
Mesh_MeshAxesArrayAttr:$split_axes,
1091
- Variadic<I64>:$source_halo_sizes,
1092
- DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_source_halo_sizes,
1093
- Variadic<I64>:$destination_halo_sizes,
1094
- DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_destination_halo_sizes
1124
+ Variadic<I64>:$halo_sizes,
1125
+ DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes
1095
1126
);
1096
1127
let results = (outs
1097
1128
AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$result
1098
1129
);
1099
1130
let assemblyFormat = [{
1100
- $source `into` $ destination
1131
+ $destination
1101
1132
`on` $mesh
1102
1133
`split_axes` `=` $split_axes
1103
- (`source_halo_sizes` `=` custom<DynamicIndexList>($source_halo_sizes, $static_source_halo_sizes)^)?
1104
- (`destination_halo_sizes` `=` custom<DynamicIndexList>($destination_halo_sizes, $static_destination_halo_sizes)^)?
1105
- attr-dict `:` type($source) `->` type($result)
1134
+ (`halo_sizes` `=` custom<DynamicIndexList>($halo_sizes, $static_halo_sizes)^)?
1135
+ attr-dict `:` type($result)
1106
1136
}];
1107
1137
let extraClassDeclaration = [{
1108
1138
MutableOperandRange getDpsInitsMutable() { return getDestinationMutable(); }
0 commit comments