Skip to content
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
1fbae54
Add `MPI_Comm`, `MPI_Request`, `MPI_Status`, `MPI_Op` type definitions
mofeing Jan 16, 2025
dc84ca4
Add `MPI_CommSize`, `MPI_ISend`, `MPI_IRecv` ops
mofeing Jan 16, 2025
2ee10ab
Fix typo
mofeing Jan 16, 2025
539bf43
Finish types
mofeing Jan 25, 2025
662998d
Define `MPI_Op` enum & attr
mofeing Jan 26, 2025
c1ec63c
Add communicator argument to mpi ops as optional input argument
mofeing Jan 26, 2025
7eda791
Add summary of new mpi types
mofeing Jan 26, 2025
b97a541
format code
mofeing Jan 26, 2025
d5725a8
Add `mpi.comm_split` op
mofeing Jan 26, 2025
1a68b34
Add `mpi.barrier` op
mofeing Jan 26, 2025
80a4259
Format code
mofeing Jan 26, 2025
cfb81af
Fix ops returning `MPI_Request`
mofeing Jan 26, 2025
740cf0b
Add `mpi.wait` op
mofeing Jan 26, 2025
1af1425
Add `mpi.allreduce` op
mofeing Jan 26, 2025
c11a60f
Fix assembly formats
mofeing Jan 27, 2025
d971d83
add some tests
mofeing Jan 27, 2025
beb5764
Fix input specifier
mofeing Jan 27, 2025
2317994
Comment predefined constant MPI_Ops
mofeing Jan 27, 2025
63ccc33
Replace `MPI_Op` new type for region
mofeing Jan 27, 2025
d318c60
Go back to only use predefined MPI_Ops
mofeing Jan 28, 2025
8e3aa18
Remove `MPI_Operation` type
mofeing Jan 29, 2025
9c708d4
Add `mpi.comm_world` op to return `MPI_COMM_WORLD`
mofeing Jan 29, 2025
326b13f
Add tests
mofeing Jan 29, 2025
2baf33f
Merge branch 'main' into mlir-mpi
mofeing Jan 29, 2025
1fd5578
Fix anchor of assembly format
mofeing Jan 29, 2025
016b856
Fix more anchors
mofeing Jan 29, 2025
1931b8e
Fix anchors again
mofeing Jan 29, 2025
aec9fbd
fix another anchor
mofeing Jan 29, 2025
d4684fb
fix optional format of `MPI_BarrierOp`
mofeing Jan 29, 2025
794fa25
fix more anchors
mofeing Jan 29, 2025
f0d0f44
fix anchors in `MPI_ISendOp` and `MPI_IRecvOp`
mofeing Jan 29, 2025
92f2cca
fix format
mofeing Jan 29, 2025
3688915
Define `getCanonicalizationPatterns` for `ISendOp`, `IRecvOp`, `AllRe…
mofeing Jan 29, 2025
3abe925
remove duplicated `getCanonicalizationPatterns`
mofeing Jan 29, 2025
7a9fa9c
Remove canonicalization for `AllReduceOp`
mofeing Jan 29, 2025
1926bda
fix test
mofeing Jan 29, 2025
89ec111
fix some assembly formats
mofeing Jan 29, 2025
30fb673
fix syntax
mofeing Jan 29, 2025
6abba5a
Remove MPI_Comm type
mofeing Jan 29, 2025
452f760
fix tests
mofeing Jan 29, 2025
56868e8
change order of results of `MPI_CommRankOp`
mofeing Jan 29, 2025
b9988b3
format code
mofeing Jan 29, 2025
2075c02
format code
mofeing Jan 29, 2025
8477428
refactor assembly format of `isend`, `irecv` and fix tests
mofeing Jan 30, 2025
1259cbc
last fixes
mofeing Jan 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions mlir/include/mlir/Dialect/MPI/IR/MPI.td
Original file line number Diff line number Diff line change
Expand Up @@ -215,4 +215,44 @@ def MPI_ErrorClassAttr : EnumAttr<MPI_Dialect, MPI_ErrorClassEnum, "errclass"> {
let assemblyFormat = "`<` $value `>`";
}

// TODO is it ok to have them as I32?
// def MPI_OpNull : I32EnumAttrCase<"MPI_OP_NULL", 0, "MPI_OP_NULL">;
// def MPI_OpMax : I32EnumAttrCase<"MPI_MAX", 1, "MPI_MAX">;
// def MPI_OpMin : I32EnumAttrCase<"MPI_MIN", 2, "MPI_MIN">;
// def MPI_OpSum : I32EnumAttrCase<"MPI_SUM", 3, "MPI_SUM">;
// def MPI_OpProd : I32EnumAttrCase<"MPI_PROD", 4, "MPI_PROD">;
// def MPI_OpLand : I32EnumAttrCase<"MPI_LAND", 5, "MPI_LAND">;
// def MPI_OpBand : I32EnumAttrCase<"MPI_BAND", 6, "MPI_BAND">;
// def MPI_OpLor : I32EnumAttrCase<"MPI_LOR", 7, "MPI_LOR">;
// def MPI_OpBor : I32EnumAttrCase<"MPI_BOR", 8, "MPI_BOR">;
// def MPI_OpLxor : I32EnumAttrCase<"MPI_LXOR", 9, "MPI_LXOR">;
// def MPI_OpBxor : I32EnumAttrCase<"MPI_BXOR", 10, "MPI_BXOR">;
// def MPI_OpMinloc : I32EnumAttrCase<"MPI_MINLOC", 11, "MPI_MINLOC">;
// def MPI_OpMaxloc : I32EnumAttrCase<"MPI_MAXLOC", 12, "MPI_MAXLOC">;
// def MPI_OpReplace : I32EnumAttrCase<"MPI_REPLACE", 13, "MPI_REPLACE">;

// def MPI_OpClassEnum : I32EnumAttr<"MPI_OpClassEnum", "MPI operation class", [
// MPI_OpNull,
// MPI_OpMax,
// MPI_OpMin,
// MPI_OpSum,
// MPI_OpProd,
// MPI_OpLand,
// MPI_OpBand,
// MPI_OpLor,
// MPI_OpBor,
// MPI_OpLxor,
// MPI_OpBxor,
// MPI_OpMinloc,
// MPI_OpMaxloc,
// MPI_OpReplace
// ]> {
// let genSpecializedAttr = 0;
// let cppNamespace = "::mlir::mpi";
// }

// def MPI_OpClassAttr : EnumAttr<MPI_Dialect, MPI_OpClassEnum, "opclass"> {
// let assemblyFormat = "`<` $value `>`";
// }

#endif // MLIR_DIALECT_MPI_IR_MPI_TD
239 changes: 222 additions & 17 deletions mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -43,20 +43,73 @@ def MPI_InitOp : MPI_Op<"init", []> {

def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
let summary = "Get the current rank, equivalent to "
"`MPI_Comm_rank(MPI_COMM_WORLD, &rank)`";
"`MPI_Comm_rank(comm, &rank)`";
let description = [{
Communicators other than `MPI_COMM_WORLD` are not supported for now.
If communicator is not specified, `MPI_COMM_WORLD` is used by default.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];

let arguments = (ins Optional<MPI_Comm> : $comm);

let results = (
outs Optional<MPI_Retval> : $retval,
I32 : $rank
);

let assemblyFormat = "attr-dict `:` type(results)";
let assemblyFormat = "(`(` $comm `)`)? attr-dict `:` type(results)";
}

//===----------------------------------------------------------------------===//
// CommSizeOp
//===----------------------------------------------------------------------===//

def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
let summary = "Get the size of the group associated to the communicator, "
"equivalent to `MPI_Comm_size(comm, &size)`";
let description = [{
If communicator is not specified, `MPI_COMM_WORLD` is used by default.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];

let arguments = (ins Optional<MPI_Comm> : $comm);

let results = (
outs Optional<MPI_Retval> : $retval,
I32 : $size
);

let assemblyFormat = "(`(` $comm `)`)? attr-dict `:` type(results)";
}

//===----------------------------------------------------------------------===//
// CommSplitOp
//===----------------------------------------------------------------------===//

def MPI_CommSplit : MPI_Op<"comm_split", []> {
let summary = "Partition the group associated to the given communicator into "
"disjoint subgroups";
let description = [{
This operation splits the communicator into multiple sub-communicators.
The color value determines the group of processes that will be part of the
new communicator. The key value determines the rank of the calling process
in the new communicator.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];

let arguments = (ins MPI_Comm : $comm, I32 : $color, I32 : $key);

let results = (
outs Optional<MPI_Retval> : $retval,
MPI_Comm : $newcomm
);

let assemblyFormat = "`(` $comm `,` $color `,` $key `)` attr-dict `:` type(results)";
}

//===----------------------------------------------------------------------===//
Expand All @@ -65,59 +118,214 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> {

def MPI_SendOp : MPI_Op<"send", []> {
let summary =
"Equivalent to `MPI_Send(ptr, size, dtype, dest, tag, MPI_COMM_WORLD)`";
"Equivalent to `MPI_Send(ptr, size, dtype, dest, tag, comm)`";
let description = [{
MPI_Send performs a blocking send of `size` elements of type `dtype` to rank
`dest`. The `tag` value and communicator enables the library to determine
the matching of multiple sends and receives between the same ranks.

Communicators other than `MPI_COMM_WORLD` are not supprted for now.
If communicator is not specified, `MPI_COMM_WORLD` is used by default.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];

let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank);
let arguments = (
ins AnyMemRef : $ref,
I32 : $tag,
I32 : $rank,
Optional<MPI_Comm> : $comm
);

let results = (outs Optional<MPI_Retval>:$retval);

let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` "
let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm)? `)` attr-dict `:` "
"type($ref) `,` type($tag) `,` type($rank)"
"(`->` type($retval)^)?";
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
// ISendOp
//===----------------------------------------------------------------------===//

def MPI_ISendOp : MPI_Op<"isend", []> {
let summary =
"Equivalent to `MPI_Isend(ptr, size, dtype, dest, tag, comm)`";
let description = [{
MPI_Isend begins a non-blocking send of `size` elements of type `dtype` to
rank `dest`. The `tag` value and communicator enables the library to
determine the matching of multiple sends and receives between the same
ranks.

If communicator is not specified, `MPI_COMM_WORLD` is used by default.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];

let arguments = (
ins AnyMemRef : $ref,
I32 : $tag,
I32 : $rank,
Optional<MPI_Comm> : $comm
);

let results = (outs Optional<MPI_Retval>:$retval, MPI_Request : $req);

let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm)?`)` attr-dict "
"`:` type($ref) `,` type($tag) `,` type($rank) "
"(`,` type($comm))? `->` (type($retval) `,` ^)? type($req)";
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
// RecvOp
//===----------------------------------------------------------------------===//

def MPI_RecvOp : MPI_Op<"recv", []> {
let summary = "Equivalent to `MPI_Recv(ptr, size, dtype, dest, tag, "
"MPI_COMM_WORLD, MPI_STATUS_IGNORE)`";
"comm, MPI_STATUS_IGNORE)`";
let description = [{
MPI_Recv performs a blocking receive of `size` elements of type `dtype`
from rank `dest`. The `tag` value and communicator enables the library to
determine the matching of multiple sends and receives between the same
ranks.

Communicators other than `MPI_COMM_WORLD` are not supprted for now.
If communicator is not specified, `MPI_COMM_WORLD` is used by default.
The MPI_Status is set to `MPI_STATUS_IGNORE`, as the status object
is not yet ported to MLIR.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];

let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank);
let arguments = (
ins AnyMemRef : $ref,
I32 : $tag, I32 : $rank,
Optional<MPI_Comm> : $comm
);

let results = (outs Optional<MPI_Retval>:$retval);

let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` "
"type($ref) `,` type($tag) `,` type($rank)"
"(`->` type($retval)^)?";
let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm)?`)` attr-dict "
"`:` type($ref) `,` type($tag) `,` type($rank) "
"(`,` type($comm))? (`->` type($retval)^)?";
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
// IRecvOp
//===----------------------------------------------------------------------===//

def MPI_IRecvOp : MPI_Op<"irecv", []> {
let summary = "Equivalent to `MPI_Irecv(ptr, size, dtype, dest, tag, "
"comm, &req)`";
let description = [{
MPI_Irecv begins a non-blocking receive of `size` elements of type `dtype`
from rank `dest`. The `tag` value and communicator enables the library to
determine the matching of multiple sends and receives between the same
ranks.

If communicator is not specified, `MPI_COMM_WORLD` is used by default.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];

let arguments = (
ins AnyMemRef : $ref,
I32 : $tag,
I32 : $rank,
Optional<MPI_Comm> : $comm
);

let results = (outs Optional<MPI_Retval>:$retval, MPI_Request : $req);

let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm)?`)` attr-dict "
"`:` type($ref) `,` type($tag) `,` type($rank)"
"(`,` type($comm))? `->` (type($retval) `,` ^)? type($req)";
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
// AllReduceOp
//===----------------------------------------------------------------------===//

def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
let summary = "Equivalent to `MPI_Allreduce(sendbuf, recvbuf, op, comm)`";
let description = [{
MPI_Allreduce performs a reduction operation on the values in the sendbuf
array and stores the result in the recvbuf array. The operation is
performed across all processes in the communicator.

If communicator is not specified, `MPI_COMM_WORLD` is used by default.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];

let arguments = (
ins AnyMemRef : $sendbuf,
AnyMemRef : $recvbuf,
Optional<MPI_Comm> : $comm
);

let regions = (region SizedRegion<1>:$op);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tobiasgrosser is there some way to impose a trait to input regions? speaking with @vchuravy, he suggested that we might need the region (that implements the MPI op) to be "pure"; i.e. independent from the whatever is above the region.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You add IsolatedFromAbove to the traits of an Op.


let results = (outs Optional<MPI_Retval>:$retval);

let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op (`,` $comm)?`)` "
"attr-dict `:` type($sendbuf) `,` type($recvbuf) `,` "
"type($op) (`,` type($comm))? (`->` type($retval)^)?";
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
// BarrierOp
//===----------------------------------------------------------------------===//

def MPI_Barrier : MPI_Op<"barrier", []> {
let summary = "Equivalent to `MPI_Barrier(comm)`";
let description = [{
MPI_Barrier blocks execution until all processes in the communicator have
reached this routine.

If communicator is not specified, `MPI_COMM_WORLD` is used by default.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];

let arguments = (ins Optional<MPI_Comm> : $comm);

let results = (outs Optional<MPI_Retval>:$retval);

let assemblyFormat = "(`(` $comm `)`)? attr-dict `:` type($retval)^";
}

//===----------------------------------------------------------------------===//
// WaitOp
//===----------------------------------------------------------------------===//

def MPI_Wait : MPI_Op<"wait", []> {
let summary = "Equivalent to `MPI_Wait(req, MPI_STATUS_IGNORE)`";
let description = [{
MPI_Wait blocks execution until the request has completed.

The MPI_Status is set to `MPI_STATUS_IGNORE`, as the status object
is not yet ported to MLIR.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];

let arguments = (ins MPI_Request : $req);

let results = (outs Optional<MPI_Retval>:$retval);

let assemblyFormat = "`(` $req `)` attr-dict `:` type($req) `->` type($retval)^";
}

//===----------------------------------------------------------------------===//
// FinalizeOp
Expand All @@ -139,7 +347,6 @@ def MPI_FinalizeOp : MPI_Op<"finalize", []> {
let assemblyFormat = "attr-dict (`:` type($retval)^)?";
}


//===----------------------------------------------------------------------===//
// RetvalCheckOp
//===----------------------------------------------------------------------===//
Expand All @@ -163,10 +370,8 @@ def MPI_RetvalCheckOp : MPI_Op<"retval_check", []> {
let assemblyFormat = "$val `=` $errclass attr-dict `:` type($res)";
}



//===----------------------------------------------------------------------===//
// RetvalCheckOp
// ErrorClassOp
//===----------------------------------------------------------------------===//

def MPI_ErrorClassOp : MPI_Op<"error_class", []> {
Expand Down
Loading