Conversation

@fschlimb
Contributor

Adding a lowering of mesh.allreduce to mpi.allreduce.
Minor restructuring to increase code reuse.

@llvmbot llvmbot added the mlir label Jun 13, 2025
@llvmbot
Member

llvmbot commented Jun 13, 2025

@llvm/pr-subscribers-mlir

Author: Frank Schlimbach (fschlimb)

Changes

Adding a lowering of mesh.allreduce to mpi.allreduce.
Minor restructuring to increase code reuse.


Patch is 63.75 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144060.diff

12 Files Affected:

  • (modified) mlir/include/mlir/Conversion/Passes.td (+2)
  • (modified) mlir/include/mlir/Dialect/MPI/IR/MPI.h (+1)
  • (modified) mlir/include/mlir/Dialect/MPI/IR/MPIOps.td (+6-4)
  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+2-2)
  • (modified) mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h (+5-5)
  • (modified) mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h (+4)
  • (modified) mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp (+132-39)
  • (modified) mlir/lib/Dialect/MPI/IR/MPIOps.cpp (+38)
  • (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+25)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp (+16-6)
  • (modified) mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir (+207-150)
  • (modified) mlir/test/Dialect/Mesh/spmdization.mlir (+30)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index b496ee0114910..5a864865adffc 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -905,6 +905,8 @@ def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
     shard/partition sizes depend on the rank.
   }];
   let dependentDialects = [
+    "affine::AffineDialect",
+    "arith::ArithDialect",
     "memref::MemRefDialect",
     "mpi::MPIDialect",
     "scf::SCFDialect",
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPI.h b/mlir/include/mlir/Dialect/MPI/IR/MPI.h
index f06b911ce3fe3..2b6743cd008c6 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPI.h
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPI.h
@@ -12,6 +12,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
 
 //===----------------------------------------------------------------------===//
 // MPIDialect
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index d78aa92d201e7..c14837f6961eb 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -11,6 +11,7 @@
 
 include "mlir/Dialect/MPI/IR/MPI.td"
 include "mlir/Dialect/MPI/IR/MPITypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
 
 class MPI_Op<string mnemonic, list<Trait> traits = []>
     : Op<MPI_Dialect, mnemonic, traits>;
@@ -41,7 +42,7 @@ def MPI_InitOp : MPI_Op<"init", []> {
 // CommWorldOp
 //===----------------------------------------------------------------------===//
 
-def MPI_CommWorldOp : MPI_Op<"comm_world", []> {
+def MPI_CommWorldOp : MPI_Op<"comm_world", [Pure]> {
   let summary = "Get the World communicator, equivalent to `MPI_COMM_WORLD`";
   let description = [{
     This operation returns the predefined MPI_COMM_WORLD communicator.
@@ -56,7 +57,7 @@ def MPI_CommWorldOp : MPI_Op<"comm_world", []> {
 // CommRankOp
 //===----------------------------------------------------------------------===//
 
-def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
+def MPI_CommRankOp : MPI_Op<"comm_rank", [Pure]> {
   let summary = "Get the current rank, equivalent to "
                 "`MPI_Comm_rank(comm, &rank)`";
   let description = [{
@@ -72,13 +73,14 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
   );
 
   let assemblyFormat = "`(` $comm `)` attr-dict `:` type(results)";
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
 // CommSizeOp
 //===----------------------------------------------------------------------===//
 
-def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
+def MPI_CommSizeOp : MPI_Op<"comm_size", [Pure]> {
   let summary = "Get the size of the group associated to the communicator, "
                 "equivalent to `MPI_Comm_size(comm, &size)`";
   let description = [{
@@ -100,7 +102,7 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
 // CommSplitOp
 //===----------------------------------------------------------------------===//
 
-def MPI_CommSplitOp : MPI_Op<"comm_split", []> {
+def MPI_CommSplitOp : MPI_Op<"comm_split", [Pure]> {
   let summary = "Partition the group associated with the given communicator into "
                 "disjoint subgroups";
   let description = [{
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index f59c4c4c67517..ac05ee243d7be 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -584,11 +584,11 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
     ```
   }];
   let arguments = !con(commonArgs, (ins
-    AnyRankedTensor:$input,
+    AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$input,
     DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction
   ));
   let results = (outs
-    AnyRankedTensor:$result
+    AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$result
   );
   let assemblyFormat = [{
     $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? (`reduction` `=` $reduction^)?
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
index c64da29ca6412..3f1041cb25103 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
@@ -62,9 +62,9 @@ void populateAllReduceEndomorphismSimplificationPatterns(
   auto isEndomorphismOp = [reduction](Operation *op,
                                       std::optional<Operation *> referenceOp) {
     auto allReduceOp = llvm::dyn_cast<AllReduceOp>(op);
-    if (!allReduceOp ||
-        allReduceOp.getInput().getType().getElementType() !=
-            allReduceOp.getResult().getType().getElementType() ||
+    if (!allReduceOp)
+      return false;
+    auto inType = cast<ShapedType>(allReduceOp.getInput().getType());
+    auto outType = cast<ShapedType>(allReduceOp.getResult().getType());
+    if (inType.getElementType() != outType.getElementType() ||
         allReduceOp.getReduction() != reduction) {
       return false;
     }
@@ -83,9 +83,9 @@ void populateAllReduceEndomorphismSimplificationPatterns(
     }
 
     auto refAllReduceOp = llvm::dyn_cast<AllReduceOp>(referenceOp.value());
+    auto refType = cast<ShapedType>(refAllReduceOp.getResult().getType());
     return refAllReduceOp->getAttrs() == allReduceOp->getAttrs() &&
-           allReduceOp.getInput().getType().getElementType() ==
-               refAllReduceOp.getInput().getType().getElementType();
+           inType.getElementType() == refType.getElementType();
   };
   auto isAlgebraicOp = [](Operation *op) {
     return static_cast<bool>(llvm::dyn_cast<AlgebraicOp>(op));
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
index be82e2af399dc..5a1154bf9166e 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
@@ -42,6 +42,10 @@ createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
 TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
                                                ArrayRef<MeshAxis> meshAxes,
                                                ImplicitLocOpBuilder &builder);
+TypedValue<IndexType>
+createProcessLinearIndex(StringRef mesh, ValueRange processInGroupMultiIndex,
+                         ArrayRef<MeshAxis> meshAxes,
+                         ImplicitLocOpBuilder &builder);
 
 } // namespace mesh
 } // namespace mlir
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 823d4d644f586..521569e69b61a 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -12,9 +12,9 @@
 
 #include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -22,6 +22,8 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -289,27 +291,15 @@ struct ConvertProcessMultiIndexOp
 
 class ConvertProcessLinearIndexOp
     : public OpConversionPattern<ProcessLinearIndexOp> {
-  int64_t worldRank; // rank in MPI_COMM_WORLD if available, else < 0
 
 public:
   using OpConversionPattern::OpConversionPattern;
 
-  // Constructor accepting worldRank
-  ConvertProcessLinearIndexOp(const TypeConverter &typeConverter,
-                              MLIRContext *context, int64_t worldRank = -1)
-      : OpConversionPattern(typeConverter, context), worldRank(worldRank) {}
-
   LogicalResult
   matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-
+    // Create mpi::CommRankOp
     Location loc = op.getLoc();
-    if (worldRank >= 0) { // if rank in MPI_COMM_WORLD is known -> use it
-      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, worldRank);
-      return success();
-    }
-
-    // Otherwise call create mpi::CommRankOp
     auto ctx = op.getContext();
     Value commWorld =
         rewriter.create<mpi::CommWorldOp>(loc, mpi::CommType::get(ctx));
@@ -529,6 +519,124 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
   }
 };
 
+static mpi::MPI_OpClassEnumAttr getMPIReduction(ReductionKindAttr kind) {
+  auto ctx = kind.getContext();
+  switch (kind.getValue()) {
+  case ReductionKind::Sum:
+    return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_SUM);
+  case ReductionKind::Product:
+    return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_PROD);
+  case ReductionKind::Min:
+    return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_MIN);
+  case ReductionKind::Max:
+    return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_MAX);
+  case ReductionKind::BitwiseAnd:
+    return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_BAND);
+  case ReductionKind::BitwiseOr:
+    return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_BOR);
+  case ReductionKind::BitwiseXor:
+    return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_BXOR);
+  default:
+    assert(false && "Unknown/unsupported reduction kind");
+  }
+}
+
+struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AllReduceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SymbolTableCollection symbolTableCollection;
+    auto mesh = adaptor.getMesh();
+    auto meshOp = getMesh(op, symbolTableCollection);
+    if (!meshOp)
+      return op->emitError() << "No mesh found for AllReduceOp";
+    if (ShapedType::isDynamicShape(meshOp.getShape()))
+      return op->emitError()
+             << "Dynamic mesh shape not supported in AllReduceOp";
+
+    ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
+    Value input = adaptor.getInput();
+    auto inputShape = cast<ShapedType>(input.getType()).getShape();
+
+    // If the input is a tensor, cast it to a memref.
+    if (isa<RankedTensorType>(input.getType())) {
+      auto memrefType = MemRefType::get(
+          inputShape, cast<ShapedType>(input.getType()).getElementType());
+      input = iBuilder.create<bufferization::ToMemrefOp>(memrefType, input);
+    }
+    MemRefType inType = cast<MemRefType>(input.getType());
+
+    // Get the actual shape to allocate the buffer.
+    SmallVector<OpFoldResult> shape(inType.getRank());
+    for (auto i = 0; i < inType.getRank(); ++i) {
+      auto s = inputShape[i];
+      if (ShapedType::isDynamic(s))
+        shape[i] = iBuilder.create<memref::DimOp>(input, i).getResult();
+      else
+        shape[i] = iBuilder.getIndexAttr(s);
+    }
+
+    // Allocate buffer and copy input to buffer.
+    Value buffer = iBuilder.create<memref::AllocOp>(
+        shape, cast<ShapedType>(op.getType()).getElementType());
+    iBuilder.create<linalg::CopyOp>(input, buffer);
+
+    // Get an MPI_Comm_split for the AllReduce operation.
+    // The color is the linear index of the process in the mesh along the
+    // non-reduced axes. The key is the linear index of the process in the mesh
+    // along the reduced axes.
+    SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
+                                       iBuilder.getIndexType());
+    SmallVector<Value> myMultiIndex =
+        iBuilder.create<ProcessMultiIndexOp>(indexResultTypes, mesh)
+            .getResult();
+    Value zero = iBuilder.create<arith::ConstantIndexOp>(0);
+    SmallVector<Value> multiKey(myMultiIndex.size(), zero);
+
+    auto redAxes = adaptor.getMeshAxes();
+    for (auto axis : redAxes) {
+      multiKey[axis] = myMultiIndex[axis];
+      myMultiIndex[axis] = zero;
+    }
+
+    Value color =
+        createProcessLinearIndex(mesh, myMultiIndex, redAxes, iBuilder);
+    color = iBuilder.create<arith::IndexCastOp>(iBuilder.getI32Type(), color);
+    Value key = createProcessLinearIndex(mesh, multiKey, redAxes, iBuilder);
+    key = iBuilder.create<arith::IndexCastOp>(iBuilder.getI32Type(), key);
+
+    // Finally split the communicator
+    auto commType = mpi::CommType::get(op->getContext());
+    Value commWorld = iBuilder.create<mpi::CommWorldOp>(commType);
+    auto comm =
+        iBuilder.create<mpi::CommSplitOp>(commType, commWorld, color, key)
+            .getNewcomm();
+
+    Value buffer1d = buffer;
+    // Collapse shape to 1d if needed
+    if (inType.getRank() > 1) {
+      ReassociationIndices reassociation(inType.getRank());
+      std::iota(reassociation.begin(), reassociation.end(), 0);
+      buffer1d = iBuilder.create<memref::CollapseShapeOp>(
+          buffer, ArrayRef<ReassociationIndices>(reassociation));
+    }
+
+    // Create the MPI AllReduce operation.
+    iBuilder.create<mpi::AllReduceOp>(
+        TypeRange(), buffer1d, buffer1d,
+        getMPIReduction(adaptor.getReductionAttr()), comm);
+
+    // If the result is a tensor, cast the buffer back to a tensor.
+    if (isa<RankedTensorType>(op.getType()))
+      buffer = iBuilder.create<bufferization::ToTensorOp>(buffer, true);
+
+    rewriter.replaceOp(op, buffer);
+    return success();
+  }
+};
+
 struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -573,10 +681,10 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
     Value array = dest;
     if (isa<RankedTensorType>(array.getType())) {
       // If the destination is a memref, we need to cast it to a tensor
-      auto tensorType = MemRefType::get(
+      auto memrefType = MemRefType::get(
           dstShape, cast<ShapedType>(array.getType()).getElementType());
       array =
-          rewriter.create<bufferization::ToBufferOp>(loc, tensorType, array);
+          rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, array);
     }
     auto rank = cast<ShapedType>(array.getType()).getRank();
     auto opSplitAxes = adaptor.getSplitAxes().getAxes();
@@ -753,22 +861,6 @@ struct ConvertMeshToMPIPass
 
   /// Run the dialect converter on the module.
   void runOnOperation() override {
-    uint64_t worldRank = -1;
-    // Try to get DLTI attribute for MPI:comm_world_rank
-    // If found, set worldRank to the value of the attribute.
-    {
-      auto dltiAttr =
-          dlti::query(getOperation(), {"MPI:comm_world_rank"}, false);
-      if (succeeded(dltiAttr)) {
-        if (!isa<IntegerAttr>(dltiAttr.value())) {
-          getOperation()->emitError()
-              << "Expected an integer attribute for MPI:comm_world_rank";
-          return signalPassFailure();
-        }
-        worldRank = cast<IntegerAttr>(dltiAttr.value()).getInt();
-      }
-    }
-
     auto *ctxt = &getContext();
     RewritePatternSet patterns(ctxt);
     ConversionTarget target(getContext());
@@ -819,10 +911,10 @@ struct ConvertMeshToMPIPass
     // ...except the global MeshOp
     target.addLegalOp<mesh::MeshOp>();
     // Allow all the stuff that our patterns will convert to
-    target.addLegalDialect<BuiltinDialect, mpi::MPIDialect, scf::SCFDialect,
-                           arith::ArithDialect, tensor::TensorDialect,
-                           bufferization::BufferizationDialect,
-                           linalg::LinalgDialect, memref::MemRefDialect>();
+    target.addLegalDialect<
+        BuiltinDialect, mpi::MPIDialect, scf::SCFDialect, arith::ArithDialect,
+        tensor::TensorDialect, bufferization::BufferizationDialect,
+        linalg::LinalgDialect, memref::MemRefDialect, affine::AffineDialect>();
     // Make sure the function signature, calls etc. are legal
     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
       return typeConverter.isSignatureLegal(op.getFunctionType());
@@ -832,9 +924,10 @@ struct ConvertMeshToMPIPass
 
     patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
                  ConvertProcessMultiIndexOp, ConvertGetShardingOp,
-                 ConvertShardingOp, ConvertShardShapeOp>(typeConverter, ctxt);
-    // ConvertProcessLinearIndexOp accepts an optional worldRank
-    patterns.add<ConvertProcessLinearIndexOp>(typeConverter, ctxt, worldRank);
+                 ConvertShardingOp, ConvertShardShapeOp, ConvertAllReduceOp,
+                 ConvertProcessLinearIndexOp>(typeConverter, ctxt);
+    SymbolTableCollection symbolTableCollection;
+    mlir::mesh::populateFoldingPatterns(patterns, symbolTableCollection);
 
     populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
         patterns, typeConverter);
diff --git a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
index 56d8edfbcc025..6d445ca0e4099 100644
--- a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
+++ b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/MPI/IR/MPI.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Builders.h"
@@ -41,6 +42,38 @@ struct FoldCast final : public mlir::OpRewritePattern<OpT> {
     return mlir::success();
   }
 };
+
+struct FoldRank final : public mlir::OpRewritePattern<mlir::mpi::CommRankOp> {
+  using mlir::OpRewritePattern<mlir::mpi::CommRankOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::mpi::CommRankOp op,
+                                mlir::PatternRewriter &b) const override {
+    auto comm = op.getComm();
+    if (!comm.getDefiningOp<mlir::mpi::CommWorldOp>()) {
+      return mlir::failure();
+    }
+
+    // Try to get DLTI attribute for MPI:comm_world_rank
+    // If found, replace the comm_rank op with a constant of that value.
+    {
+      auto dltiAttr = dlti::query(op, {"MPI:comm_world_rank"}, false);
+      if (failed(dltiAttr))
+        return mlir::failure();
+      if (!isa<IntegerAttr>(dltiAttr.value())) {
+        return op->emitError()
+               << "Expected an integer attribute for MPI:comm_world_rank";
+      }
+      Value res = b.create<arith::ConstantIndexOp>(
+          op.getLoc(), cast<IntegerAttr>(dltiAttr.value()).getInt());
+      if (Value retVal = op.getRetval())
+        b.replaceOp(op, {retVal, res});
+      else
+        b.replaceOp(op, res);
+      return mlir::success();
+    }
+  }
+};
+
 } // namespace
 
 void mlir::mpi::SendOp::getCanonicalizationPatterns(
@@ -63,6 +96,11 @@ void mlir::mpi::IRecvOp::getCanonicalizationPatterns(
   results.add<FoldCast<mlir::mpi::IRecvOp>>(context);
 }
 
+void mlir::mpi::CommRankOp::getCanonicalizationPatterns(
+    mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
+  results.add<FoldRank>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 304cb55a35086..b84de2b716b32 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -75,6 +75,31 @@ static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
   return lhs.value() * rhs.value();
 }
 
+/// Converts a vector of OpFoldResults (ints) into a vector of Values of the
+/// provided type.
+SmallVector<Value> mlir::mesh::getMixedAsValues(OpBuilder b,
+                                                const Location &loc,
+                                                llvm::ArrayRef<int64_t> statics,
+                                                ValueRange dynamics,
+                                                Type ...
[truncated]
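Note on the truncated patch above: ConvertAllReduceOp splits MPI_COMM_WORLD so that exactly the processes participating in one reduction share a communicator. As the in-code comment states, the color is the process's linear index along the non-reduced axes and the key is its linear index along the reduced axes. A standalone sketch of that scheme for a hypothetical 2x3 mesh reducing over axis 1 (illustration only, not part of the patch):

// Processes with equal color land in the same sub-communicator created by
// MPI_Comm_split; the key orders the ranks within it.
#include <cstdio>

int main() {
  const int shape[2] = {2, 3}; // mesh shape; axis 1 is the reduced axis
  for (int i = 0; i < shape[0]; ++i) {
    for (int j = 0; j < shape[1]; ++j) {
      int color = i; // linear index along the non-reduced axes (axis 0)
      int key = j;   // linear index along the reduced axes (axis 1)
      std::printf("process (%d,%d) -> color=%d key=%d\n", i, j, color, key);
    }
  }
  return 0;
}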

@rengolin
Member

FYI @BenBrock

@fschlimb
Contributor Author

fschlimb commented Jun 13, 2025

FYI @tkarna @mofeing

@fschlimb fschlimb requested a review from Dinistro June 13, 2025 13:58
Contributor

@Dinistro Dinistro left a comment

Dropped a bunch of comments. The main concern is mixing conversion and non-conversion patterns, which is broken in the general case.

Comment on lines +539 to +540
Contributor

Nit: Drop the default. Clang and GCC complain about uncovered cases at compile time, which they cannot do when there is a default case.
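
A minimal sketch of the diagnostic being referred to, using a made-up enum (illustration, not code from this patch):

// With no default case, -Wswitch (part of -Wall in Clang and GCC) flags
// every enumerator the switch fails to cover.
enum class Kind { Sum, Product, Min };

int toMPI(Kind k) {
  switch (k) { // warning: enumeration value 'Min' not handled in switch
  case Kind::Sum:
    return 0;
  case Kind::Product:
    return 1;
  } // a `default:` label here would suppress that check entirely
  return -1;
}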

Contributor Author

I'll keep it. Compiler checks can be disabled.

Contributor

Nit: Consider factoring the enum conversion into a separate function to avoid duplicating the attribute construction this often.
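
One way to read the suggestion, sketched against the names used in the patch (a sketch, not the author's code):

// The switch yields only the enum value; the attribute is constructed once.
static mpi::MPI_OpClassEnumAttr getMPIReduction(ReductionKindAttr kind) {
  mpi::MPI_OpClassEnum e;
  switch (kind.getValue()) {
  case ReductionKind::Sum:
    e = mpi::MPI_OpClassEnum::MPI_SUM;
    break;
  case ReductionKind::Product:
    e = mpi::MPI_OpClassEnum::MPI_PROD;
    break;
  case ReductionKind::Min:
    e = mpi::MPI_OpClassEnum::MPI_MIN;
    break;
  case ReductionKind::Max:
    e = mpi::MPI_OpClassEnum::MPI_MAX;
    break;
  case ReductionKind::BitwiseAnd:
    e = mpi::MPI_OpClassEnum::MPI_BAND;
    break;
  case ReductionKind::BitwiseOr:
    e = mpi::MPI_OpClassEnum::MPI_BOR;
    break;
  case ReductionKind::BitwiseXor:
    e = mpi::MPI_OpClassEnum::MPI_BXOR;
    break;
  default: // whether to keep a default is the subject of the thread above
    llvm_unreachable("unsupported reduction kind");
  }
  return mpi::MPI_OpClassEnumAttr::get(kind.getContext(), e);
}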

Contributor Author

I'll keep it like this. The extra work of introducing an indirection just to make the lines a few characters shorter doesn't look right to me.

Contributor

Nit: Use a more descriptive name.

Contributor Author

like what?

Contributor

getMPIReductionAttr? getMPIReductionOpAttr?

Contributor

Nit: Is this unused?

Contributor Author

It is used in MeshToMPI.cpp. Another PR will add a use elsewhere.

Contributor

I wonder if we can cut some of the check lines a bit. They contain tons of uninteresting type information that is only a pain to maintain but gives almost no benefits.

Contributor Author

Hm, maybe here and there, but we'd need to keep some anyway, to check that collapsing works the right way etc.
I'll leave it as-is for now. If it becomes an issue, we can re-visit.

Contributor

Using a normal builder is not legal, you need to use the provided rewriter for all IR manipulations.

Contributor Author

I "copied" this type of use from ComplexToStandard. Otherwise, how would I be able to use these helper functions like indexResultTypes?

Contributor

@tkarna tkarna left a comment

Looks good to me!


@github-actions

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff HEAD~1 HEAD --extensions h,cpp -- mlir/include/mlir/Dialect/MPI/IR/MPI.h mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp mlir/lib/Dialect/MPI/IR/MPIOps.cpp mlir/lib/Dialect/Mesh/IR/MeshOps.cpp mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
View the diff from clang-format here.
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index aaf1d39d4..bbae8f32b 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -523,19 +523,26 @@ static mpi::MPI_ReductionOpEnumAttr getMPIReductionOp(ReductionKindAttr kind) {
   auto ctx = kind.getContext();
   switch (kind.getValue()) {
   case ReductionKind::Sum:
-    return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_SUM);
+    return mpi::MPI_ReductionOpEnumAttr::get(ctx,
+                                             mpi::MPI_ReductionOpEnum::MPI_SUM);
   case ReductionKind::Product:
-    return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_PROD);
+    return mpi::MPI_ReductionOpEnumAttr::get(
+        ctx, mpi::MPI_ReductionOpEnum::MPI_PROD);
   case ReductionKind::Min:
-    return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_MIN);
+    return mpi::MPI_ReductionOpEnumAttr::get(ctx,
+                                             mpi::MPI_ReductionOpEnum::MPI_MIN);
   case ReductionKind::Max:
-    return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_MAX);
+    return mpi::MPI_ReductionOpEnumAttr::get(ctx,
+                                             mpi::MPI_ReductionOpEnum::MPI_MAX);
   case ReductionKind::BitwiseAnd:
-    return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_BAND);
+    return mpi::MPI_ReductionOpEnumAttr::get(
+        ctx, mpi::MPI_ReductionOpEnum::MPI_BAND);
   case ReductionKind::BitwiseOr:
-    return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_BOR);
+    return mpi::MPI_ReductionOpEnumAttr::get(ctx,
+                                             mpi::MPI_ReductionOpEnum::MPI_BOR);
   case ReductionKind::BitwiseXor:
-    return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_BXOR);
+    return mpi::MPI_ReductionOpEnumAttr::get(
+        ctx, mpi::MPI_ReductionOpEnum::MPI_BXOR);
   default:
     assert(false && "Unknown/unsupported reduction kind");
   }

@fschlimb fschlimb mentioned this pull request Jun 18, 2025