Skip to content

Commit 5da9199

Browse files
committed
Use mpi. prefix for ops
1 parent 9358d05 commit 5da9199

File tree

10 files changed

+20
-21
lines changed

10 files changed

+20
-21
lines changed

src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,7 +1070,7 @@ def MPICommRankOp : EnzymeXLA_Op<"mpi.comm_rank", []> {
10701070
let assemblyFormat = "attr-dict `:` type(results)";
10711071
}
10721072

1073-
def MPICommSizeOp : EnzymeXLA_Op<"comm_size", []> {
1073+
def MPICommSizeOp : EnzymeXLA_Op<"mpi.comm_size", []> {
10741074
let summary = "Equivalent to MPI_Comm_size(MPI_COMM_WORLD, &size)";
10751075

10761076
let arguments = (
@@ -1084,13 +1084,12 @@ def MPICommSizeOp : EnzymeXLA_Op<"comm_size", []> {
10841084
let assemblyFormat = "`(` operands `)` attr-dict `:` functional-type(operands, results)";
10851085
}
10861086

1087-
// We call it "mpi_barrier" to not collide with BarrierOp above
1088-
def MPIBarrierOp : EnzymeXLA_Op<"mpi_barrier", []> {
1087+
def MPIBarrierOp : EnzymeXLA_Op<"mpi.barrier", []> {
10891088
let summary = "Equivalent to MPI_Barrier(MPI_COMM_WORLD)";
10901089
let assemblyFormat = "attr-dict";
10911090
}
10921091

1093-
def MPISendOp : EnzymeXLA_Op<"send", []> {
1092+
def MPISendOp : EnzymeXLA_Op<"mpi.send", []> {
10941093
let summary = "Equivalent to "
10951094
"`MPI_Send(&buf, count, datatype, dest, tag, comm)`";
10961095

@@ -1105,7 +1104,7 @@ def MPISendOp : EnzymeXLA_Op<"send", []> {
11051104
let assemblyFormat = "`(` operands `)` attr-dict `:` type(operands)";
11061105
}
11071106

1108-
def MPIRecvOp : EnzymeXLA_Op<"recv", []> {
1107+
def MPIRecvOp : EnzymeXLA_Op<"mpi.recv", []> {
11091108
let summary = "Equivalent to "
11101109
"`MPI_Recv(&buf, count, datatype, source, tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE)`";
11111110

@@ -1124,7 +1123,7 @@ def MPIRecvOp : EnzymeXLA_Op<"recv", []> {
11241123
let assemblyFormat = "`(` operands `)` attr-dict `:` functional-type(operands, results)";
11251124
}
11261125

1127-
def MPIIsendOp : EnzymeXLA_Op<"isend", []> {
1126+
def MPIIsendOp : EnzymeXLA_Op<"mpi.isend", []> {
11281127
let summary = "Equivalent to "
11291128
"`MPI_Isend(&buf, count, datatype, dest, tag, MPI_COMM_WORLD, &request)`";
11301129

@@ -1144,7 +1143,7 @@ def MPIIsendOp : EnzymeXLA_Op<"isend", []> {
11441143
let assemblyFormat = "`(` operands `)` attr-dict `:` functional-type(operands, results)";
11451144
}
11461145

1147-
def MPIIrecvOp : EnzymeXLA_Op<"irecv", []> {
1146+
def MPIIrecvOp : EnzymeXLA_Op<"mpi.irecv", []> {
11481147
let summary = "Equivalent to "
11491148
"`MPI_Irecv(&buf, count, datatype, source, tag, MPI_COMM_WORLD, &request)`";
11501149

@@ -1165,14 +1164,14 @@ def MPIIrecvOp : EnzymeXLA_Op<"irecv", []> {
11651164
let assemblyFormat = "`(` operands `)` attr-dict `:` functional-type(operands, results)";
11661165
}
11671166

1168-
def MPIWaitOp : EnzymeXLA_Op<"wait", []> {
1167+
def MPIWaitOp : EnzymeXLA_Op<"mpi.wait", []> {
11691168
let summary = "Equivalent to "
11701169
"`MPI_Wait(&request, &status)`";
11711170
let arguments = (ins TensorOf<[I64]> : $inrequest);
11721171
let assemblyFormat = "`(` operands `)` attr-dict `:` type(operands)";
11731172
}
11741173

1175-
def MPIAllreduceOp : EnzymeXLA_Op<"allreduce", []> {
1174+
def MPIAllreduceOp : EnzymeXLA_Op<"mpi.allreduce", []> {
11761175
let summary = "Equivalent to "
11771176
"`MPI_Allreduce(&sendbuf, &recvbuf, count, datatype, op, MPI_COMM_WORLD)`";
11781177

test/lit_tests/mpi/allreduce.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ module {
44
func.func @main(%arg0: tensor<i64> {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}) -> tensor<i64> attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
55
%c = stablehlo.constant dense<0> : tensor<i64>
66
%c_0 = stablehlo.constant dense<1> : tensor<i32>
7-
%0 = enzymexla.allreduce(%arg0, %c, %c_0) {datatype = "MPI_SOME_TYPE", op="MPI_SOME_OP"} : (tensor<i64>, tensor<i64>, tensor<i32>) -> tensor<i64>
7+
%0 = enzymexla.mpi.allreduce(%arg0, %c, %c_0) {datatype = "MPI_SOME_TYPE", op="MPI_SOME_OP"} : (tensor<i64>, tensor<i64>, tensor<i32>) -> tensor<i64>
88
return %0 : tensor<i64>
99
}
1010
}

test/lit_tests/mpi/barrier.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
module {
44
func.func @main() attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
5-
enzymexla.mpi_barrier
5+
enzymexla.mpi.barrier
66
return
77
}
88
}

test/lit_tests/mpi/comm_size.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
module {
44
func.func @main() -> tensor<i32> attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
55
%c = stablehlo.constant dense<-1> : tensor<i32>
6-
%0 = enzymexla.comm_size(%c) : (tensor<i32>) -> tensor<i32>
6+
%0 = enzymexla.mpi.comm_size(%c) : (tensor<i32>) -> tensor<i32>
77
return %0 : tensor<i32>
88
}
99
}

test/lit_tests/mpi/irecv-wait.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ module {
77
%c_0 = stablehlo.constant dense<42> : tensor<i32>
88
%c_1 = stablehlo.constant dense<5> : tensor<i32>
99
%c_2 = stablehlo.constant dense<-1> : tensor<i64>
10-
%outbuf, %outrequest = enzymexla.irecv(%0, %c_1, %c, %c_0, %c_2) {datatype = "MPI_DOUBLE"} : (tensor<5xf64>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i64>) -> (tensor<5xf64>, tensor<i64>)
11-
enzymexla.wait(%outrequest) : tensor<i64>
10+
%outbuf, %outrequest = enzymexla.mpi.irecv(%0, %c_1, %c, %c_0, %c_2) {datatype = "MPI_DOUBLE"} : (tensor<5xf64>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i64>) -> (tensor<5xf64>, tensor<i64>)
11+
enzymexla.mpi.wait(%outrequest) : tensor<i64>
1212
%1 = stablehlo.transpose %outbuf, dims = [0] : (tensor<5xf64>) -> tensor<5xf64>
1313
return %1 : tensor<5xf64>
1414
}

test/lit_tests/mpi/irecv.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ module {
77
%c_0 = stablehlo.constant dense<42> : tensor<i32>
88
%c_1 = stablehlo.constant dense<5> : tensor<i32>
99
%c_2 = stablehlo.constant dense<-1> : tensor<i64>
10-
%outbuf, %outrequest = enzymexla.irecv(%0, %c_1, %c, %c_0, %c_2) {datatype = "MPI_XYZ"} : (tensor<5xf64>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i64>) -> (tensor<5xf64>, tensor<i64>)
11-
// enzymexla.wait(%outrequest) : tensor<i64>
10+
%outbuf, %outrequest = enzymexla.mpi.irecv(%0, %c_1, %c, %c_0, %c_2) {datatype = "MPI_XYZ"} : (tensor<5xf64>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i64>) -> (tensor<5xf64>, tensor<i64>)
11+
// enzymexla.mpi.wait(%outrequest) : tensor<i64>
1212
%1 = stablehlo.transpose %outbuf, dims = [0] : (tensor<5xf64>) -> tensor<5xf64>
1313
return %1 : tensor<5xf64>
1414
}

test/lit_tests/mpi/isend.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ module {
77
%c_0 = stablehlo.constant dense<42> : tensor<i32>
88
%c_1 = stablehlo.constant dense<5> : tensor<i32>
99
%c_2 = stablehlo.constant dense<-1> : tensor<i64>
10-
%1 = enzymexla.isend(%0, %c_1, %c, %c_0, %c_2) {datatype = "MPI_XYZ"} : (tensor<5xf64>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i64>) -> tensor<i64>
11-
// enzymexla.wait(%1) : tensor<i64>
10+
%1 = enzymexla.mpi.isend(%0, %c_1, %c, %c_0, %c_2) {datatype = "MPI_XYZ"} : (tensor<5xf64>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i64>) -> tensor<i64>
11+
// enzymexla.mpi.wait(%1) : tensor<i64>
1212
%2 = stablehlo.transpose %0, dims = [0] : (tensor<5xf64>) -> tensor<5xf64>
1313
return %2 : tensor<5xf64>
1414
}

test/lit_tests/mpi/recv.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ module {
66
%c = stablehlo.constant dense<43> : tensor<i32>
77
%c_0 = stablehlo.constant dense<0> : tensor<i32>
88
%c_1 = stablehlo.constant dense<5> : tensor<i32>
9-
%1 = enzymexla.recv(%0, %c_1, %c_0, %c) {datatype = "MPI_XYZ"} : (tensor<5xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<5xf64>
9+
%1 = enzymexla.mpi.recv(%0, %c_1, %c_0, %c) {datatype = "MPI_XYZ"} : (tensor<5xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<5xf64>
1010
%2 = stablehlo.transpose %1, dims = [0] : (tensor<5xf64>) -> tensor<5xf64>
1111
return %2 : tensor<5xf64>
1212
}

test/lit_tests/mpi/send.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ module {
66
%c = stablehlo.constant dense<43> : tensor<i32>
77
%c_0 = stablehlo.constant dense<1> : tensor<i32>
88
%c_1 = stablehlo.constant dense<5> : tensor<i32>
9-
enzymexla.send(%0, %c_1, %c_0, %c) {datatype = "MPI_XYZ"} : tensor<5xf64>, tensor<i32>, tensor<i32>, tensor<i32>
9+
enzymexla.mpi.send(%0, %c_1, %c_0, %c) {datatype = "MPI_XYZ"} : tensor<5xf64>, tensor<i32>, tensor<i32>, tensor<i32>
1010
%1 = stablehlo.transpose %0, dims = [0] : (tensor<5xf64>) -> tensor<5xf64>
1111
return %1 : tensor<5xf64>
1212
}

test/lit_tests/mpi/wait.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
module {
44
func.func @main() {
55
%c_2 = stablehlo.constant dense<-1> : tensor<i64>
6-
enzymexla.wait(%c_2) : tensor<i64>
6+
enzymexla.mpi.wait(%c_2) : tensor<i64>
77
return
88
}
99
}

0 commit comments

Comments
 (0)