diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index 284ba72af9768..baa279c62a16c 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -37,26 +37,44 @@ def MPI_InitOp : MPI_Op<"init", []> { let assemblyFormat = "attr-dict (`:` type($retval)^)?"; } +//===----------------------------------------------------------------------===// +// CommWorldOp +//===----------------------------------------------------------------------===// + +def MPI_CommWorldOp : MPI_Op<"comm_world", []> { + let summary = "Get the World communicator, equivalent to `MPI_COMM_WORLD`"; + let description = [{ + This operation returns the predefined MPI_COMM_WORLD communicator. + }]; + + let results = (outs MPI_Comm : $comm); + + let assemblyFormat = "attr-dict `:` type(results)"; +} + //===----------------------------------------------------------------------===// // CommRankOp //===----------------------------------------------------------------------===// def MPI_CommRankOp : MPI_Op<"comm_rank", []> { let summary = "Get the current rank, equivalent to " - "`MPI_Comm_rank(MPI_COMM_WORLD, &rank)`"; + "`MPI_Comm_rank(comm, &rank)`"; let description = [{ - Communicators other than `MPI_COMM_WORLD` are not supported for now. + If communicator is not specified, `MPI_COMM_WORLD` is used by default. This operation can optionally return an `!mpi.retval` value that can be used to check for errors. }]; + let arguments = (ins Optional : $comm); + let results = ( outs Optional : $retval, I32 : $rank ); - let assemblyFormat = "attr-dict `:` type(results)"; + let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict (`:` type($comm) ^ `->`):" + "(`:`)? type(results)"; } //===----------------------------------------------------------------------===// @@ -65,20 +83,52 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> { def MPI_CommSizeOp : MPI_Op<"comm_size", []> { let summary = "Get the size of the group associated to the communicator, " - "equivalent to `MPI_Comm_size(MPI_COMM_WORLD, &size)`"; + "equivalent to `MPI_Comm_size(comm, &size)`"; let description = [{ - Communicators other than `MPI_COMM_WORLD` are not supported for now. + If communicator is not specified, `MPI_COMM_WORLD` is used by default. This operation can optionally return an `!mpi.retval` value that can be used to check for errors. }]; + let arguments = (ins Optional : $comm); + let results = ( outs Optional : $retval, I32 : $size ); - let assemblyFormat = "attr-dict `:` type(results)"; + let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict (`:` type($comm) ^ `->`):" + "(`:`)? type(results)"; +} + +//===----------------------------------------------------------------------===// +// CommSplitOp +//===----------------------------------------------------------------------===// + +def MPI_CommSplit : MPI_Op<"comm_split", []> { + let summary = "Partition the group associated to the given communicator into " + "disjoint subgroups"; + let description = [{ + This operation splits the communicator into multiple sub-communicators. + The color value determines the group of processes that will be part of the + new communicator. The key value determines the rank of the calling process + in the new communicator. + + This operation can optionally return an `!mpi.retval` value that can be used + to check for errors. + }]; + + let arguments = (ins MPI_Comm : $comm, I32 : $color, I32 : $key); + + let results = ( + outs Optional : $retval, + MPI_Comm : $newcomm + ); + + let assemblyFormat = "`(` $comm `,` $color `,` $key `)` attr-dict `:` " + "type($comm) `,` type($color) `,` type($key) `->` " + "type(results)"; } //===----------------------------------------------------------------------===// @@ -87,13 +137,13 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", []> { def MPI_SendOp : MPI_Op<"send", []> { let summary = - "Equivalent to `MPI_Send(ptr, size, dtype, dest, tag, MPI_COMM_WORLD)`"; + "Equivalent to `MPI_Send(ptr, size, dtype, dest, tag, comm)`"; let description = [{ MPI_Send performs a blocking send of `size` elements of type `dtype` to rank `dest`. The `tag` value and communicator enables the library to determine the matching of multiple sends and receives between the same ranks. - Communicators other than `MPI_COMM_WORLD` are not supported for now. + If communicator is not specified, `MPI_COMM_WORLD` is used by default. This operation can optionally return an `!mpi.retval` value that can be used to check for errors. @@ -102,12 +152,13 @@ def MPI_SendOp : MPI_Op<"send", []> { let arguments = ( ins AnyMemRef : $ref, I32 : $tag, - I32 : $rank + I32 : $rank, + Optional : $comm ); let results = (outs Optional:$retval); - let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` " + let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm ^)? `)` attr-dict `:` " "type($ref) `,` type($tag) `,` type($rank)" "(`->` type($retval)^)?"; let hasCanonicalizer = 1; @@ -119,14 +170,14 @@ def MPI_SendOp : MPI_Op<"send", []> { def MPI_ISendOp : MPI_Op<"isend", []> { let summary = - "Equivalent to `MPI_Isend(ptr, size, dtype, dest, tag, MPI_COMM_WORLD)`"; + "Equivalent to `MPI_Isend(ptr, size, dtype, dest, tag, comm)`"; let description = [{ MPI_Isend begins a non-blocking send of `size` elements of type `dtype` to rank `dest`. The `tag` value and communicator enables the library to determine the matching of multiple sends and receives between the same ranks. - Communicators other than `MPI_COMM_WORLD` are not supported for now. + If communicator is not specified, `MPI_COMM_WORLD` is used by default. This operation can optionally return an `!mpi.retval` value that can be used to check for errors. @@ -135,7 +186,8 @@ def MPI_ISendOp : MPI_Op<"isend", []> { let arguments = ( ins AnyMemRef : $ref, I32 : $tag, - I32 : $rank + I32 : $rank, + Optional : $comm ); let results = ( @@ -143,9 +195,9 @@ def MPI_ISendOp : MPI_Op<"isend", []> { MPI_Request : $req ); - let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict " + let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm ^)?`)` attr-dict " "`:` type($ref) `,` type($tag) `,` type($rank) " - "`->` type(results)"; + "(`,` type($comm) ^)? `->` type(results)"; let hasCanonicalizer = 1; } @@ -155,14 +207,14 @@ def MPI_ISendOp : MPI_Op<"isend", []> { def MPI_RecvOp : MPI_Op<"recv", []> { let summary = "Equivalent to `MPI_Recv(ptr, size, dtype, dest, tag, " - "MPI_COMM_WORLD, MPI_STATUS_IGNORE)`"; + "comm, MPI_STATUS_IGNORE)`"; let description = [{ MPI_Recv performs a blocking receive of `size` elements of type `dtype` from rank `dest`. The `tag` value and communicator enables the library to determine the matching of multiple sends and receives between the same ranks. - Communicators other than `MPI_COMM_WORLD` are not supported for now. + If communicator is not specified, `MPI_COMM_WORLD` is used by default. The MPI_Status is set to `MPI_STATUS_IGNORE`, as the status object is not yet ported to MLIR. @@ -172,14 +224,15 @@ def MPI_RecvOp : MPI_Op<"recv", []> { let arguments = ( ins AnyMemRef : $ref, - I32 : $tag, I32 : $rank + I32 : $tag, I32 : $rank, + Optional : $comm ); let results = (outs Optional:$retval); - let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:`" - "type($ref) `,` type($tag) `,` type($rank)" - "(`->` type($retval)^)?"; + let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm ^)?`)` attr-dict" + " `:` type($ref) `,` type($tag) `,` type($rank) " + "(`,` type($comm) ^)? (`->` type($retval)^)?"; let hasCanonicalizer = 1; } @@ -189,14 +242,14 @@ def MPI_RecvOp : MPI_Op<"recv", []> { def MPI_IRecvOp : MPI_Op<"irecv", []> { let summary = "Equivalent to `MPI_Irecv(ptr, size, dtype, dest, tag, " - "MPI_COMM_WORLD, &req)`"; + "comm, &req)`"; let description = [{ MPI_Irecv begins a non-blocking receive of `size` elements of type `dtype` from rank `dest`. The `tag` value and communicator enables the library to determine the matching of multiple sends and receives between the same ranks. - Communicators other than `MPI_COMM_WORLD` are not supported for now. + If communicator is not specified, `MPI_COMM_WORLD` is used by default. This operation can optionally return an `!mpi.retval` value that can be used to check for errors. @@ -205,7 +258,8 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> { let arguments = ( ins AnyMemRef : $ref, I32 : $tag, - I32 : $rank + I32 : $rank, + Optional : $comm ); let results = ( @@ -213,9 +267,9 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> { MPI_Request : $req ); - let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:`" - "type($ref) `,` type($tag) `,` type($rank) `->`" - "type(results)"; + let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm ^)?`)` attr-dict " + "`:` type($ref) `,` type($tag) `,` type($rank)" + "(`,` type($comm) ^)? `->` type(results)"; let hasCanonicalizer = 1; } @@ -224,8 +278,7 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> { //===----------------------------------------------------------------------===// def MPI_AllReduceOp : MPI_Op<"allreduce", []> { - let summary = "Equivalent to `MPI_Allreduce(sendbuf, recvbuf, op, " - "MPI_COMM_WORLD)`"; + let summary = "Equivalent to `MPI_Allreduce(sendbuf, recvbuf, op, comm)`"; let description = [{ MPI_Allreduce performs a reduction operation on the values in the sendbuf array and stores the result in the recvbuf array. The operation is @@ -235,7 +288,7 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> { Currently only the `MPI_Op` predefined in the standard (e.g. `MPI_SUM`) are supported. - Communicators other than `MPI_COMM_WORLD` are not supported for now. + If communicator is not specified, `MPI_COMM_WORLD` is used by default. This operation can optionally return an `!mpi.retval` value that can be used to check for errors. @@ -244,14 +297,15 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> { let arguments = ( ins AnyMemRef : $sendbuf, AnyMemRef : $recvbuf, - MPI_OpClassAttr : $op + MPI_OpClassAttr : $op, + Optional : $comm ); let results = (outs Optional:$retval); - let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op `)` attr-dict `:`" - "type($sendbuf) `,` type($recvbuf)" - "(`->` type($retval)^)?"; + let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op (`,` $comm ^)?`)` " + "attr-dict `:` type($sendbuf) `,` type($recvbuf) " + "(`,` type($comm) ^)? (`->` type($retval)^)?"; } //===----------------------------------------------------------------------===// @@ -259,20 +313,33 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> { //===----------------------------------------------------------------------===// def MPI_Barrier : MPI_Op<"barrier", []> { - let summary = "Equivalent to `MPI_Barrier(MPI_COMM_WORLD)`"; + let summary = "Equivalent to `MPI_Barrier(comm)`"; let description = [{ MPI_Barrier blocks execution until all processes in the communicator have reached this routine. - Communicators other than `MPI_COMM_WORLD` are not supported for now. + If communicator is not specified, `MPI_COMM_WORLD` is used by default. This operation can optionally return an `!mpi.retval` value that can be used to check for errors. }]; + let arguments = (ins Optional : $comm); + let results = (outs Optional:$retval); - let assemblyFormat = "attr-dict (`:` type($retval) ^)?"; + // TODO fix assembly format + // let assemblyFormat = "(" + // "(attr-dict) ^" + // "(attr-dict `:` type($retval)) ^" + // "(`(` $comm `)` attr-dict `:` type($comm)) ^" + // "(`(` $comm `)` attr-dict `:` type($comm) `->` type($retval))" + // ")?"; + let assemblyFormat = [{ + (`(` $comm ^ `)`)? attr-dict + (`:` type($comm) ^ `->`):(`:`)? + type(results) + }]; } //===----------------------------------------------------------------------===// @@ -295,8 +362,7 @@ def MPI_Wait : MPI_Op<"wait", []> { let results = (outs Optional:$retval); - let assemblyFormat = "`(` $req `)` attr-dict `:` type($req) " - "(`->` type($retval) ^)?"; + let assemblyFormat = "`(` $req `)` attr-dict `:` type($req) (`->` type($retval) ^)?"; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td index fafea0eac8bb7..868132a62abc4 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td @@ -40,6 +40,17 @@ def MPI_Retval : MPI_Type<"Retval", "retval"> { }]; } +//===----------------------------------------------------------------------===// +// mpi::CommType +//===----------------------------------------------------------------------===// + +def MPI_Comm : MPI_Type<"Comm", "comm"> { + let summary = "MPI communicator handler"; + let description = [{ + This type represents a handler to the MPI communicator. + }]; +} + //===----------------------------------------------------------------------===// // mpi::RequestType //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/MPI/ops.mlir b/mlir/test/Dialect/MPI/ops.mlir index f23a7e18a2ee9..fad203ded1d06 100644 --- a/mlir/test/Dialect/MPI/ops.mlir +++ b/mlir/test/Dialect/MPI/ops.mlir @@ -6,44 +6,89 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () { // CHECK: %0 = mpi.init : !mpi.retval %err = mpi.init : !mpi.retval + // CHECK-NEXT: %comm = mpi.comm_world : !mpi.comm + %comm = mpi.comm_world : !mpi.comm + + // CHECK-NEXT: %rank = mpi.comm_rank : i32 + %rank = mpi.comm_rank : i32 + // CHECK-NEXT: %retval, %rank = mpi.comm_rank : !mpi.retval, i32 %retval, %rank = mpi.comm_rank : !mpi.retval, i32 + // CHECK-NEXT: %retval, %rank = mpi.comm_rank : !mpi.comm -> i32 + %rank = mpi.comm_rank(%comm) : !mpi.comm -> i32 + + // CHECK-NEXT: %retval, %rank = mpi.comm_rank : !mpi.comm -> !mpi.retval, i32 + %retval, %rank = mpi.comm_rank(%comm) : !mpi.comm -> !mpi.retval, i32 + + // CHECK-NEXT: %size = mpi.comm_size : i32 + %size = mpi.comm_size : i32 + // CHECK-NEXT: %retval_0, %size = mpi.comm_size : !mpi.retval, i32 %retval_0, %size = mpi.comm_size : !mpi.retval, i32 + // CHECK-NEXT: %size = mpi.comm_size : !mpi.comm -> i32 + %size = mpi.comm_size(%comm) : !mpi.comm -> i32 + + // CHECK-NEXT: %retval_0, %size = mpi.comm_size : !mpi.retval, i32 + %retval_0, %size = mpi.comm_size(%comm) : !mpi.comm -> !mpi.retval, i32 + + // CHECK-NEXT: %new_comm = mpi.comm_split(%comm, %rank, %rank) : !mpi.comm, i32, i32 -> !mpi.comm + %new_comm = mpi.comm_split(%comm, %rank, %rank) : !mpi.comm, i32, i32 + + // CHECK-NEXT: %retval3, %new_comm = mpi.comm_split(%comm, %rank, %rank) : !mpi.comm, i32, i32 -> !mpi.retval, !mpi.comm + %retval3, %new_comm = mpi.comm_split(%comm, %rank, %rank) : !mpi.comm, i32, i32 -> !mpi.retval, !mpi.comm + // CHECK-NEXT: mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 mpi.send(%ref, %rank, %rank) : memref<100xf32>, i32, i32 // CHECK-NEXT: %1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval %err2 = mpi.send(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval + // CHECK-NEXT: mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32, !mpi.comm + mpi.send(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32, !mpi.comm + // CHECK-NEXT: mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 mpi.recv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 // CHECK-NEXT: %2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval %err3 = mpi.recv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval + // CHECK-NEXT: mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32, !mpi.comm + mpi.recv(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32, !mpi.comm + // CHECK-NEXT: %req = mpi.isend(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.request %req = mpi.isend(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.request // CHECK-NEXT: %retval_1, %req_2 = mpi.isend(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request %err4, %req2 = mpi.isend(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request + // CHECK-NEXT: %3 = mpi.isend(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32, mpi.comm -> !mpi.request + %req1 = mpi.isend(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32, !mpi.comm -> !mpi.request + // CHECK-NEXT: %req_3 = mpi.irecv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.request %req3 = mpi.irecv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.request // CHECK-NEXT: %retval_4, %req_5 = mpi.irecv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request %err5, %req4 = mpi.irecv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request + // CHECK-NEXT: %6 = mpi.irecv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32, mpi.comm -> mpi.request + %req3 = mpi.irecv(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32, !mpi.comm -> !mpi.request + // CHECK-NEXT: mpi.wait(%req) : !mpi.request mpi.wait(%req) : !mpi.request // CHECK-NEXT: %3 = mpi.wait(%req_2) : !mpi.request -> !mpi.retval %err6 = mpi.wait(%req2) : !mpi.request -> !mpi.retval - // CHECK-NEXT: mpi.barrier : !mpi.retval - mpi.barrier : !mpi.retval + // CHECK-NEXT: mpi.barrier + mpi.barrier + + // CHECK-NEXT: %5 = mpi.barrier : !mpi.retval + %err7 = mpi.barrier : !mpi.retval + + // CHECK-NEXT: mpi.barrier(%comm) + mpi.barrier(%comm) // CHECK-NEXT: %5 = mpi.barrier : !mpi.retval %err7 = mpi.barrier : !mpi.retval @@ -54,6 +99,9 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () { // CHECK-NEXT: mpi.allreduce(%arg0, %arg0, ) : memref<100xf32>, memref<100xf32> -> !mpi.retval %err8 = mpi.allreduce(%ref, %ref, ) : memref<100xf32>, memref<100xf32> -> !mpi.retval + // CHECK-NEXT: mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>, !mpi.comm + mpi.allreduce(%ref, %ref, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>, !mpi.comm + // CHECK-NEXT: %7 = mpi.finalize : !mpi.retval %rval = mpi.finalize : !mpi.retval