@@ -72,14 +72,16 @@ Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
7272 return linearIndex;
7373}
7474
75- // This pattern converts the mesh.update_halo operation to MPI calls
7675struct ConvertProcessMultiIndexOp
7776 : public mlir::OpRewritePattern<mlir::mesh::ProcessMultiIndexOp> {
7877 using OpRewritePattern::OpRewritePattern;
7978
8079 mlir::LogicalResult
8180 matchAndRewrite (mlir::mesh::ProcessMultiIndexOp op,
8281 mlir::PatternRewriter &rewriter) const override {
82+
83+ // Currently converts its linear index to a multi-dimensional index.
84+
8385 SymbolTableCollection symbolTableCollection;
8486 auto loc = op.getLoc ();
8587 auto meshOp = getMesh (op, symbolTableCollection);
@@ -112,16 +114,17 @@ struct ConvertProcessMultiIndexOp
112114 }
113115};
114116
115- // This pattern converts the mesh.update_halo operation to MPI calls.
116- // If it finds a global named "static_mpi_rank" it will use that splat value.
117- // Otherwise it defaults to mpi.comm_rank.
118117struct ConvertProcessLinearIndexOp
119118 : public mlir::OpRewritePattern<mlir::mesh::ProcessLinearIndexOp> {
120119 using OpRewritePattern::OpRewritePattern;
121120
122121 mlir::LogicalResult
123122 matchAndRewrite (mlir::mesh::ProcessLinearIndexOp op,
124123 mlir::PatternRewriter &rewriter) const override {
124+
125+ // Finds a global named "static_mpi_rank" it will use that splat value.
126+ // Otherwise it defaults to mpi.comm_rank.
127+
125128 auto loc = op.getLoc ();
126129 auto rankOpName = StringAttr::get (op->getContext (), " static_mpi_rank" );
127130 if (auto globalOp = SymbolTable::lookupNearestSymbolFrom<memref::GlobalOp>(
@@ -145,14 +148,18 @@ struct ConvertProcessLinearIndexOp
145148 }
146149};
147150
148- // This pattern converts the mesh.update_halo operation to MPI calls
149151struct ConvertNeighborsLinearIndicesOp
150152 : public mlir::OpRewritePattern<mlir::mesh::NeighborsLinearIndicesOp> {
151153 using OpRewritePattern::OpRewritePattern;
152154
153155 mlir::LogicalResult
154156 matchAndRewrite (mlir::mesh::NeighborsLinearIndicesOp op,
155157 mlir::PatternRewriter &rewriter) const override {
158+
159+ // Computes the neighbors indices along a split axis by simply
160+ // adding/subtracting 1 to the current index in that dimension.
161+ // Assigns -1 if neighbor is out of bounds.
162+
156163 auto axes = op.getSplitAxes ();
157164 // For now only single axis sharding is supported
158165 if (axes.size () != 1 ) {
@@ -209,14 +216,14 @@ struct ConvertNeighborsLinearIndicesOp
209216 }
210217};
211218
212- // This pattern converts the mesh.update_halo operation to MPI calls
213219struct ConvertUpdateHaloOp
214220 : public mlir::OpRewritePattern<mlir::mesh::UpdateHaloOp> {
215221 using OpRewritePattern::OpRewritePattern;
216222
217223 mlir::LogicalResult
218224 matchAndRewrite (mlir::mesh::UpdateHaloOp op,
219225 mlir::PatternRewriter &rewriter) const override {
226+
220227 // The input/output memref is assumed to be in C memory order.
221228 // Halos are exchanged as 2 blocks per dimension (one for each side: down
222229 // and up). For each haloed dimension `d`, the exchanged blocks are
0 commit comments