@@ -156,6 +156,40 @@ def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
156156 ];
157157}
158158
159+ def Mesh_NeighborsLinearIndicesOp : Mesh_Op<"neighbors_linear_indices", [
160+ Pure,
161+ DeclareOpInterfaceMethods<SymbolUserOpInterface>,
162+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
163+ ]> {
164+ let summary =
165+ "For given mesh index get the linear indices of the direct neighbor processes along the given split.";
166+ let description = [{
167+ Example:
168+ ```
169+ mesh.mesh @mesh0(shape = 10x20x30)
170+ %c1 = arith.constant 1 : index
171+ %c2 = arith.constant 2 : index
172+ %c3 = arith.constant 3 : index
173+ %idx = mesh.neighbors_linear_indices on @mesh[%c1, %c2, %c3] split_axes = [1] : index
174+ ```
175+ The above returns two indices, `633` and `693`, which correspond to the
176+ index of the previous process `(1, 1, 3)`, and the next process
177+ `(1, 3, 3) along the split axis `1`.
178+
179+ A negative value is returned if there is no neighbor in the respective
180+ direction along the given `split_axes`.
181+ }];
182+ let arguments = (ins FlatSymbolRefAttr:$mesh,
183+ Variadic<Index>:$device,
184+ Mesh_MeshAxesAttr:$split_axes);
185+ let results = (outs Index:$neighbor_down, Index:$neighbor_up);
186+ let assemblyFormat = [{
187+ `on` $mesh `[` $device `]`
188+ `split_axes` `=` $split_axes
189+ attr-dict `:` type(results)
190+ }];
191+ }
192+
159193//===----------------------------------------------------------------------===//
160194// Sharding operations.
161195//===----------------------------------------------------------------------===//
@@ -1058,12 +1092,12 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
10581092}
10591093
10601094def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
1095+ Pure,
10611096 DestinationStyleOpInterface,
10621097 TypesMatchWith<
10631098 "result has same type as destination",
10641099 "result", "destination", "$_self">,
1065- DeclareOpInterfaceMethods<SymbolUserOpInterface>,
1066- AttrSizedOperandSegments
1100+ DeclareOpInterfaceMethods<SymbolUserOpInterface>
10671101]> {
10681102 let summary = "Update halo data.";
10691103 let description = [{
@@ -1072,7 +1106,7 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
10721106 on the remote devices. Changes might be caused by mutating operations
10731107 and/or if the new halo regions are larger than the existing ones.
10741108
1075- Source and destination might have different halo sizes .
1109+ Destination is supposed to be initialized with the local data (not halos) .
10761110
10771111 Assumes all devices hold tensors with same-sized halo data as specified
10781112 by `source_halo_sizes/static_source_halo_sizes` and
@@ -1084,25 +1118,21 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
10841118
10851119 }];
10861120 let arguments = (ins
1087- AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$source,
10881121 AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$destination,
10891122 FlatSymbolRefAttr:$mesh,
10901123 Mesh_MeshAxesArrayAttr:$split_axes,
1091- Variadic<I64>:$source_halo_sizes,
1092- DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_source_halo_sizes,
1093- Variadic<I64>:$destination_halo_sizes,
1094- DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_destination_halo_sizes
1124+ Variadic<I64>:$halo_sizes,
1125+ DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes
10951126 );
10961127 let results = (outs
10971128 AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$result
10981129 );
10991130 let assemblyFormat = [{
1100- $source `into` $ destination
1131+ $destination
11011132 `on` $mesh
11021133 `split_axes` `=` $split_axes
1103- (`source_halo_sizes` `=` custom<DynamicIndexList>($source_halo_sizes, $static_source_halo_sizes)^)?
1104- (`destination_halo_sizes` `=` custom<DynamicIndexList>($destination_halo_sizes, $static_destination_halo_sizes)^)?
1105- attr-dict `:` type($source) `->` type($result)
1134+ (`halo_sizes` `=` custom<DynamicIndexList>($halo_sizes, $static_halo_sizes)^)?
1135+ attr-dict `:` type($result)
11061136 }];
11071137 let extraClassDeclaration = [{
11081138 MutableOperandRange getDpsInitsMutable() { return getDestinationMutable(); }
0 commit comments