Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -905,6 +905,8 @@ def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
shard/partition sizes depend on the rank.
}];
let dependentDialects = [
"affine::AffineDialect",
"arith::ArithDialect",
"memref::MemRefDialect",
"mpi::MPIDialect",
"scf::SCFDialect",
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/MPI/IR/MPI.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

//===----------------------------------------------------------------------===//
// MPIDialect
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/MPI/IR/MPI.td
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def MPI_OpMinloc : I32EnumAttrCase<"MPI_MINLOC", 11, "MPI_MINLOC">;
def MPI_OpMaxloc : I32EnumAttrCase<"MPI_MAXLOC", 12, "MPI_MAXLOC">;
def MPI_OpReplace : I32EnumAttrCase<"MPI_REPLACE", 13, "MPI_REPLACE">;

def MPI_OpClassEnum : I32EnumAttr<"MPI_OpClassEnum", "MPI operation class", [
def MPI_ReductionOpEnum : I32EnumAttr<"MPI_ReductionOpEnum", "MPI operation class", [
MPI_OpNull,
MPI_OpMax,
MPI_OpMin,
Expand Down
12 changes: 7 additions & 5 deletions mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

include "mlir/Dialect/MPI/IR/MPI.td"
include "mlir/Dialect/MPI/IR/MPITypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

class MPI_Op<string mnemonic, list<Trait> traits = []>
: Op<MPI_Dialect, mnemonic, traits>;
Expand Down Expand Up @@ -41,7 +42,7 @@ def MPI_InitOp : MPI_Op<"init", []> {
// CommWorldOp
//===----------------------------------------------------------------------===//

def MPI_CommWorldOp : MPI_Op<"comm_world", []> {
def MPI_CommWorldOp : MPI_Op<"comm_world", [Pure]> {
let summary = "Get the World communicator, equivalent to `MPI_COMM_WORLD`";
let description = [{
This operation returns the predefined MPI_COMM_WORLD communicator.
Expand All @@ -56,7 +57,7 @@ def MPI_CommWorldOp : MPI_Op<"comm_world", []> {
// CommRankOp
//===----------------------------------------------------------------------===//

def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
def MPI_CommRankOp : MPI_Op<"comm_rank", [Pure]> {
let summary = "Get the current rank, equivalent to "
"`MPI_Comm_rank(comm, &rank)`";
let description = [{
Expand All @@ -72,13 +73,14 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
);

let assemblyFormat = "`(` $comm `)` attr-dict `:` type(results)";
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
// CommSizeOp
//===----------------------------------------------------------------------===//

def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
def MPI_CommSizeOp : MPI_Op<"comm_size", [Pure]> {
let summary = "Get the size of the group associated to the communicator, "
"equivalent to `MPI_Comm_size(comm, &size)`";
let description = [{
Expand All @@ -100,7 +102,7 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
// CommSplitOp
//===----------------------------------------------------------------------===//

def MPI_CommSplitOp : MPI_Op<"comm_split", []> {
def MPI_CommSplitOp : MPI_Op<"comm_split", [Pure]> {
let summary = "Partition the group associated with the given communicator into "
"disjoint subgroups";
let description = [{
Expand Down Expand Up @@ -281,7 +283,7 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
let arguments = (
ins AnyMemRef : $sendbuf,
AnyMemRef : $recvbuf,
MPI_OpClassEnum : $op,
MPI_ReductionOpEnum : $op,
MPI_Comm : $comm
);

Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,11 @@ void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
OpOperand &operand,
OpBuilder &builder);

/// Converts a vector of OpFoldResults (ints) into vector of Values of the
/// provided type.
SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Is this unused?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is used in meshtompi.cpp. Another PR will add a use elsewhere.

llvm::ArrayRef<int64_t> statics,
ValueRange dynamics, Type type = Type());
} // namespace mesh
} // namespace mlir

Expand Down
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -584,11 +584,11 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
```
}];
let arguments = !con(commonArgs, (ins
AnyRankedTensor:$input,
AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$input,
DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction
));
let results = (outs
AnyRankedTensor:$result
AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$result
);
let assemblyFormat = [{
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? (`reduction` `=` $reduction^)?
Expand Down
10 changes: 5 additions & 5 deletions mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ void populateAllReduceEndomorphismSimplificationPatterns(
auto isEndomorphismOp = [reduction](Operation *op,
std::optional<Operation *> referenceOp) {
auto allReduceOp = llvm::dyn_cast<AllReduceOp>(op);
if (!allReduceOp ||
allReduceOp.getInput().getType().getElementType() !=
allReduceOp.getResult().getType().getElementType() ||
auto inType = cast<ShapedType>(allReduceOp.getInput().getType());
auto outType = cast<ShapedType>(allReduceOp.getResult().getType());
if (!allReduceOp || inType.getElementType() != outType.getElementType() ||
allReduceOp.getReduction() != reduction) {
return false;
}
Expand All @@ -83,9 +83,9 @@ void populateAllReduceEndomorphismSimplificationPatterns(
}

auto refAllReduceOp = llvm::dyn_cast<AllReduceOp>(referenceOp.value());
auto refType = cast<ShapedType>(refAllReduceOp.getResult().getType());
return refAllReduceOp->getAttrs() == allReduceOp->getAttrs() &&
allReduceOp.getInput().getType().getElementType() ==
refAllReduceOp.getInput().getType().getElementType();
inType.getElementType() == refType.getElementType();
};
auto isAlgebraicOp = [](Operation *op) {
return static_cast<bool>(llvm::dyn_cast<AlgebraicOp>(op));
Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
ArrayRef<MeshAxis> meshAxes,
ImplicitLocOpBuilder &builder);
// Get process linear index from a multi-index along the given mesh axes .
TypedValue<IndexType>
createProcessLinearIndex(StringRef mesh, ValueRange processInGroupMultiIndex,
ArrayRef<MeshAxis> meshAxes,
ImplicitLocOpBuilder &builder);

} // namespace mesh
} // namespace mlir
Expand Down
62 changes: 31 additions & 31 deletions mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ class MPIImplTraits {
/// enum value.
virtual Value getMPIOp(const Location loc,
ConversionPatternRewriter &rewriter,
mpi::MPI_OpClassEnum opAttr) = 0;
mpi::MPI_ReductionOpEnum opAttr) = 0;
};

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -199,49 +199,49 @@ class MPICHImplTraits : public MPIImplTraits {
}

Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
mpi::MPI_OpClassEnum opAttr) override {
mpi::MPI_ReductionOpEnum opAttr) override {
int32_t op = MPI_NO_OP;
switch (opAttr) {
case mpi::MPI_OpClassEnum::MPI_OP_NULL:
case mpi::MPI_ReductionOpEnum::MPI_OP_NULL:
op = MPI_NO_OP;
break;
case mpi::MPI_OpClassEnum::MPI_MAX:
case mpi::MPI_ReductionOpEnum::MPI_MAX:
op = MPI_MAX;
break;
case mpi::MPI_OpClassEnum::MPI_MIN:
case mpi::MPI_ReductionOpEnum::MPI_MIN:
op = MPI_MIN;
break;
case mpi::MPI_OpClassEnum::MPI_SUM:
case mpi::MPI_ReductionOpEnum::MPI_SUM:
op = MPI_SUM;
break;
case mpi::MPI_OpClassEnum::MPI_PROD:
case mpi::MPI_ReductionOpEnum::MPI_PROD:
op = MPI_PROD;
break;
case mpi::MPI_OpClassEnum::MPI_LAND:
case mpi::MPI_ReductionOpEnum::MPI_LAND:
op = MPI_LAND;
break;
case mpi::MPI_OpClassEnum::MPI_BAND:
case mpi::MPI_ReductionOpEnum::MPI_BAND:
op = MPI_BAND;
break;
case mpi::MPI_OpClassEnum::MPI_LOR:
case mpi::MPI_ReductionOpEnum::MPI_LOR:
op = MPI_LOR;
break;
case mpi::MPI_OpClassEnum::MPI_BOR:
case mpi::MPI_ReductionOpEnum::MPI_BOR:
op = MPI_BOR;
break;
case mpi::MPI_OpClassEnum::MPI_LXOR:
case mpi::MPI_ReductionOpEnum::MPI_LXOR:
op = MPI_LXOR;
break;
case mpi::MPI_OpClassEnum::MPI_BXOR:
case mpi::MPI_ReductionOpEnum::MPI_BXOR:
op = MPI_BXOR;
break;
case mpi::MPI_OpClassEnum::MPI_MINLOC:
case mpi::MPI_ReductionOpEnum::MPI_MINLOC:
op = MPI_MINLOC;
break;
case mpi::MPI_OpClassEnum::MPI_MAXLOC:
case mpi::MPI_ReductionOpEnum::MPI_MAXLOC:
op = MPI_MAXLOC;
break;
case mpi::MPI_OpClassEnum::MPI_REPLACE:
case mpi::MPI_ReductionOpEnum::MPI_REPLACE:
op = MPI_REPLACE;
break;
}
Expand Down Expand Up @@ -336,49 +336,49 @@ class OMPIImplTraits : public MPIImplTraits {
}

Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
mpi::MPI_OpClassEnum opAttr) override {
mpi::MPI_ReductionOpEnum opAttr) override {
StringRef op;
switch (opAttr) {
case mpi::MPI_OpClassEnum::MPI_OP_NULL:
case mpi::MPI_ReductionOpEnum::MPI_OP_NULL:
op = "ompi_mpi_no_op";
break;
case mpi::MPI_OpClassEnum::MPI_MAX:
case mpi::MPI_ReductionOpEnum::MPI_MAX:
op = "ompi_mpi_max";
break;
case mpi::MPI_OpClassEnum::MPI_MIN:
case mpi::MPI_ReductionOpEnum::MPI_MIN:
op = "ompi_mpi_min";
break;
case mpi::MPI_OpClassEnum::MPI_SUM:
case mpi::MPI_ReductionOpEnum::MPI_SUM:
op = "ompi_mpi_sum";
break;
case mpi::MPI_OpClassEnum::MPI_PROD:
case mpi::MPI_ReductionOpEnum::MPI_PROD:
op = "ompi_mpi_prod";
break;
case mpi::MPI_OpClassEnum::MPI_LAND:
case mpi::MPI_ReductionOpEnum::MPI_LAND:
op = "ompi_mpi_land";
break;
case mpi::MPI_OpClassEnum::MPI_BAND:
case mpi::MPI_ReductionOpEnum::MPI_BAND:
op = "ompi_mpi_band";
break;
case mpi::MPI_OpClassEnum::MPI_LOR:
case mpi::MPI_ReductionOpEnum::MPI_LOR:
op = "ompi_mpi_lor";
break;
case mpi::MPI_OpClassEnum::MPI_BOR:
case mpi::MPI_ReductionOpEnum::MPI_BOR:
op = "ompi_mpi_bor";
break;
case mpi::MPI_OpClassEnum::MPI_LXOR:
case mpi::MPI_ReductionOpEnum::MPI_LXOR:
op = "ompi_mpi_lxor";
break;
case mpi::MPI_OpClassEnum::MPI_BXOR:
case mpi::MPI_ReductionOpEnum::MPI_BXOR:
op = "ompi_mpi_bxor";
break;
case mpi::MPI_OpClassEnum::MPI_MINLOC:
case mpi::MPI_ReductionOpEnum::MPI_MINLOC:
op = "ompi_mpi_minloc";
break;
case mpi::MPI_OpClassEnum::MPI_MAXLOC:
case mpi::MPI_ReductionOpEnum::MPI_MAXLOC:
op = "ompi_mpi_maxloc";
break;
case mpi::MPI_OpClassEnum::MPI_REPLACE:
case mpi::MPI_ReductionOpEnum::MPI_REPLACE:
op = "ompi_mpi_replace";
break;
}
Expand Down
Loading
Loading