Skip to content

Commit c48c7f0

Browse files
committed
fixing comments
1 parent 1ad7725 commit c48c7f0

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,14 +72,16 @@ Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
7272
return linearIndex;
7373
}
7474

75-
// This pattern converts the mesh.update_halo operation to MPI calls
7675
struct ConvertProcessMultiIndexOp
7776
: public mlir::OpRewritePattern<mlir::mesh::ProcessMultiIndexOp> {
7877
using OpRewritePattern::OpRewritePattern;
7978

8079
mlir::LogicalResult
8180
matchAndRewrite(mlir::mesh::ProcessMultiIndexOp op,
8281
mlir::PatternRewriter &rewriter) const override {
82+
83+
// Currently converts its linear index to a multi-dimensional index.
84+
8385
SymbolTableCollection symbolTableCollection;
8486
auto loc = op.getLoc();
8587
auto meshOp = getMesh(op, symbolTableCollection);
@@ -112,16 +114,17 @@ struct ConvertProcessMultiIndexOp
112114
}
113115
};
114116

115-
// This pattern converts the mesh.update_halo operation to MPI calls.
116-
// If it finds a global named "static_mpi_rank" it will use that splat value.
117-
// Otherwise it defaults to mpi.comm_rank.
118117
struct ConvertProcessLinearIndexOp
119118
: public mlir::OpRewritePattern<mlir::mesh::ProcessLinearIndexOp> {
120119
using OpRewritePattern::OpRewritePattern;
121120

122121
mlir::LogicalResult
123122
matchAndRewrite(mlir::mesh::ProcessLinearIndexOp op,
124123
mlir::PatternRewriter &rewriter) const override {
124+
125+
// Finds a global named "static_mpi_rank" it will use that splat value.
126+
// Otherwise it defaults to mpi.comm_rank.
127+
125128
auto loc = op.getLoc();
126129
auto rankOpName = StringAttr::get(op->getContext(), "static_mpi_rank");
127130
if (auto globalOp = SymbolTable::lookupNearestSymbolFrom<memref::GlobalOp>(
@@ -145,14 +148,18 @@ struct ConvertProcessLinearIndexOp
145148
}
146149
};
147150

148-
// This pattern converts the mesh.update_halo operation to MPI calls
149151
struct ConvertNeighborsLinearIndicesOp
150152
: public mlir::OpRewritePattern<mlir::mesh::NeighborsLinearIndicesOp> {
151153
using OpRewritePattern::OpRewritePattern;
152154

153155
mlir::LogicalResult
154156
matchAndRewrite(mlir::mesh::NeighborsLinearIndicesOp op,
155157
mlir::PatternRewriter &rewriter) const override {
158+
159+
// Computes the neighbors indices along a split axis by simply
160+
// adding/subtracting 1 to the current index in that dimension.
161+
// Assigns -1 if neighbor is out of bounds.
162+
156163
auto axes = op.getSplitAxes();
157164
// For now only single axis sharding is supported
158165
if (axes.size() != 1) {
@@ -209,14 +216,14 @@ struct ConvertNeighborsLinearIndicesOp
209216
}
210217
};
211218

212-
// This pattern converts the mesh.update_halo operation to MPI calls
213219
struct ConvertUpdateHaloOp
214220
: public mlir::OpRewritePattern<mlir::mesh::UpdateHaloOp> {
215221
using OpRewritePattern::OpRewritePattern;
216222

217223
mlir::LogicalResult
218224
matchAndRewrite(mlir::mesh::UpdateHaloOp op,
219225
mlir::PatternRewriter &rewriter) const override {
226+
220227
// The input/output memref is assumed to be in C memory order.
221228
// Halos are exchanged as 2 blocks per dimension (one for each side: down
222229
// and up). For each haloed dimension `d`, the exchanged blocks are

0 commit comments

Comments
 (0)