Skip to content

Commit 27c8373

Browse files
committed
formatting
1 parent e901c39 commit 27c8373

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -521,21 +521,25 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
521521

522522
static mpi::MPI_ReductionOpEnumAttr getMPIReductionOp(ReductionKindAttr kind) {
523523
auto ctx = kind.getContext();
524+
auto getReductionOp = [ctx](mpi::MPI_ReductionOpEnum redOp) {
525+
return mpi::MPI_ReductionOpEnumAttr::get(ctx, redOp);
526+
};
527+
524528
switch (kind.getValue()) {
525529
case ReductionKind::Sum:
526-
return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_SUM);
530+
return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_SUM);
527531
case ReductionKind::Product:
528-
return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_PROD);
532+
return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_PROD);
529533
case ReductionKind::Min:
530-
return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_MIN);
534+
return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_MIN);
531535
case ReductionKind::Max:
532-
return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_MAX);
536+
return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_MAX);
533537
case ReductionKind::BitwiseAnd:
534-
return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_BAND);
538+
return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BAND);
535539
case ReductionKind::BitwiseOr:
536-
return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_BOR);
540+
return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BOR);
537541
case ReductionKind::BitwiseXor:
538-
return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_BXOR);
542+
return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BXOR);
539543
default:
540544
assert(false && "Unknown/unsupported reduction kind");
541545
}
@@ -630,7 +634,8 @@ struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
630634

631635
// If the destination is a memref, cast it to a tensor
632636
if (isa<RankedTensorType>(op.getType()))
633-
buffer = iBuilder.create<bufferization::ToTensorOp>(buffer, true);
637+
buffer = iBuilder.create<bufferization::ToTensorOp>(op.getType(), buffer,
638+
true);
634639

635640
rewriter.replaceOp(op, buffer);
636641
return success();
@@ -908,7 +913,7 @@ struct ConvertMeshToMPIPass
908913

909914
// No mesh dialect should left after conversion...
910915
target.addIllegalDialect<mesh::MeshDialect>();
911-
// ...except the global MeshOp. MeshShapeOp which will get folded separately.
916+
// ...except the global MeshOp. MeshShapeOp which will get folded later.
912917
target.addLegalOp<mesh::MeshOp, mesh::MeshShapeOp>();
913918
// Allow all the stuff that our patterns will convert to
914919
target.addLegalDialect<

0 commit comments

Comments
 (0)