@@ -521,21 +521,25 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
521521
522522static mpi::MPI_ReductionOpEnumAttr getMPIReductionOp (ReductionKindAttr kind) {
523523 auto ctx = kind.getContext ();
524+ auto getReductionOp = [ctx](mpi::MPI_ReductionOpEnum redOp) {
525+ return mpi::MPI_ReductionOpEnumAttr::get (ctx, redOp);
526+ };
527+
524528 switch (kind.getValue ()) {
525529 case ReductionKind::Sum:
526- return mpi::MPI_ReductionOpEnumAttr::get (ctx, mpi::MPI_ReductionOpEnum::MPI_SUM);
530+ return getReductionOp ( mpi::MPI_ReductionOpEnum::MPI_SUM);
527531 case ReductionKind::Product:
528- return mpi::MPI_ReductionOpEnumAttr::get (ctx, mpi::MPI_ReductionOpEnum::MPI_PROD);
532+ return getReductionOp ( mpi::MPI_ReductionOpEnum::MPI_PROD);
529533 case ReductionKind::Min:
530- return mpi::MPI_ReductionOpEnumAttr::get (ctx, mpi::MPI_ReductionOpEnum::MPI_MIN);
534+ return getReductionOp ( mpi::MPI_ReductionOpEnum::MPI_MIN);
531535 case ReductionKind::Max:
532- return mpi::MPI_ReductionOpEnumAttr::get (ctx, mpi::MPI_ReductionOpEnum::MPI_MAX);
536+ return getReductionOp ( mpi::MPI_ReductionOpEnum::MPI_MAX);
533537 case ReductionKind::BitwiseAnd:
534- return mpi::MPI_ReductionOpEnumAttr::get (ctx, mpi::MPI_ReductionOpEnum::MPI_BAND);
538+ return getReductionOp ( mpi::MPI_ReductionOpEnum::MPI_BAND);
535539 case ReductionKind::BitwiseOr:
536- return mpi::MPI_ReductionOpEnumAttr::get (ctx, mpi::MPI_ReductionOpEnum::MPI_BOR);
540+ return getReductionOp ( mpi::MPI_ReductionOpEnum::MPI_BOR);
537541 case ReductionKind::BitwiseXor:
538- return mpi::MPI_ReductionOpEnumAttr::get (ctx, mpi::MPI_ReductionOpEnum::MPI_BXOR);
542+ return getReductionOp ( mpi::MPI_ReductionOpEnum::MPI_BXOR);
539543 default :
540544 assert (false && " Unknown/unsupported reduction kind" );
541545 }
@@ -630,7 +634,8 @@ struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
630634
631635 // If the destination is a memref, cast it to a tensor
632636 if (isa<RankedTensorType>(op.getType ()))
633- buffer = iBuilder.create <bufferization::ToTensorOp>(buffer, true );
637+ buffer = iBuilder.create <bufferization::ToTensorOp>(op.getType (), buffer,
638+ true );
634639
635640 rewriter.replaceOp (op, buffer);
636641 return success ();
@@ -908,7 +913,7 @@ struct ConvertMeshToMPIPass
908913
909914 // No mesh dialect should left after conversion...
910915 target.addIllegalDialect <mesh::MeshDialect>();
911- // ...except the global MeshOp. MeshShapeOp which will get folded separately .
916+ // ...except the global MeshOp. MeshShapeOp which will get folded later .
912917 target.addLegalOp <mesh::MeshOp, mesh::MeshShapeOp>();
913918 // Allow all the stuff that our patterns will convert to
914919 target.addLegalDialect <
0 commit comments