Skip to content

Commit 1bd063b

Browse files
authored
use cfunction closures for user-defined operators (#284)
* use cfunction closures for user-defined operators * add tests for sum of DoubleFloats * precompile test environment to avoid race conditions
1 parent e0c2578 commit 1bd063b

File tree

9 files changed

+117
-85
lines changed

9 files changed

+117
-85
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@ Sockets = "6462fe0b-24de-5631-8697-dd941f90decc"
1414
julia = "1"
1515

1616
[extras]
17+
DoubleFloats = "497a8b3b-efae-58df-a0af-a86822472b78"
1718
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
19+
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
1820
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1921

2022
[targets]
21-
test = ["LinearAlgebra", "Test"]
23+
test = ["DoubleFloats", "LinearAlgebra", "Pkg", "Test"]

src/MPI.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ include("handle.jl")
3131
include("info.jl")
3232
include("comm.jl")
3333
include("environment.jl")
34-
include("operators.jl")
3534
include("datatypes.jl")
35+
include("operators.jl")
3636
include("pointtopoint.jl")
3737
include("collective.jl")
3838
include("topology.jl")

src/collective.jl

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,11 @@ end
9292

9393
# Convert user-provided functions to MPI.Op
9494
Reduce!(sendbuf::MPIBuffertype{T}, recvbuf::MPIBuffertype{T},
95-
count::Integer, opfunc::Function, root::Integer,
96-
comm::Comm) where {T} =
97-
Reduce!(sendbuf, recvbuf, count, user_op(opfunc), root, comm)
95+
count::Integer, opfunc, root::Integer, comm::Comm) where {T} =
96+
Reduce!(sendbuf, recvbuf, count, Op(opfunc, T), root, comm)
9897

9998
function Reduce!(sendbuf::MPIBuffertype{T}, recvbuf::MPIBuffertype{T},
100-
op::Union{Op, Function}, root::Integer, comm::Comm) where T
99+
op, root::Integer, comm::Comm) where T
101100
Reduce!(sendbuf, recvbuf, length(sendbuf), op, root, comm)
102101
end
103102

@@ -113,25 +112,26 @@ To specify the output buffer, see [`Reduce!`](@ref).
113112
To perform the reduction in place, see [`Reduce_in_place!`](@ref).
114113
"""
115114
function Reduce(sendbuf::MPIBuffertype{T}, count::Integer,
116-
op::Union{Op, Function}, root::Integer, comm::Comm) where T
115+
op, root::Integer, comm::Comm) where T
117116
isroot = Comm_rank(comm) == root
118117
recvbuf = Array{T}(undef, isroot ? count : 0)
119118
Reduce!(sendbuf, recvbuf, count, op, root, comm)
120119
end
121120

122-
function Reduce(sendbuf::Array{T,N}, op::Union{Op,Function},
121+
function Reduce(sendbuf::Array{T,N}, op,
123122
root::Integer, comm::Comm) where {T,N}
124123
isroot = Comm_rank(comm) == root
125124
recvbuf = Array{T,N}(undef, isroot ? size(sendbuf) : Tuple(zeros(Int, ndims(sendbuf))))
126125
Reduce!(sendbuf, recvbuf, length(sendbuf), op, root, comm)
127126
end
128127

129-
function Reduce(sendbuf::SubArray{T}, op::Union{Op,Function}, root::Integer, comm::Comm) where T
128+
function Reduce(sendbuf::SubArray{T}, op, root::Integer, comm::Comm) where T
130129
@assert Base.iscontiguous(sendbuf)
131130
Reduce(sendbuf, length(sendbuf), op, root, comm)
132131
end
133132

134-
function Reduce(object::T, op::Union{Op,Function}, root::Integer, comm::Comm) where T
133+
function Reduce(object::T, op
134+
, root::Integer, comm::Comm) where T
135135
isroot = Comm_rank(comm) == root
136136
sendbuf = T[object]
137137
recvbuf = Reduce(sendbuf, op, root, comm)
@@ -174,9 +174,9 @@ function Reduce_in_place!(buf::MPIBuffertype{T}, count::Integer,
174174
end
175175

176176
# Convert to MPI.Op
177-
Reduce_in_place!(buf::MPIBuffertype{T}, count::Integer, op::Function,
177+
Reduce_in_place!(buf::MPIBuffertype{T}, count::Integer, op,
178178
root::Integer, comm::Comm) where T =
179-
Reduce_in_place!(buf, count, user_op(op), root, comm)
179+
Reduce_in_place!(buf, count, Op(op,T), root, comm)
180180

181181
"""
182182
Allreduce!(sendbuf, recvbuf[, count=length(sendbuf)], op, comm)
@@ -206,11 +206,11 @@ end
206206

207207
# Convert user-provided functions to MPI.Op
208208
Allreduce!(sendbuf::MPIBuffertypeOrConst{T}, recvbuf::MPIBuffertype{T},
209-
count::Integer, opfunc::Function, comm::Comm) where {T} =
210-
Allreduce!(sendbuf, recvbuf, count, user_op(opfunc), comm)
209+
count::Integer, opfunc, comm::Comm) where {T} =
210+
Allreduce!(sendbuf, recvbuf, count, Op(opfunc,T), comm)
211211

212212
function Allreduce!(sendbuf::MPIBuffertypeOrConst{T}, recvbuf::MPIBuffertype{T},
213-
op::Union{Op,Function}, comm::Comm) where T
213+
op, comm::Comm) where T
214214
Allreduce!(sendbuf, recvbuf, length(recvbuf), op, comm)
215215
end
216216

@@ -222,7 +222,7 @@ the results on all the processes in the group.
222222
223223
Equivalent to calling `Allreduce!(MPI.IN_PLACE, buf, op, comm)`
224224
"""
225-
function Allreduce!(buf::MPIBuffertype{T}, op::Union{Op, Function}, comm::Comm) where T
225+
function Allreduce!(buf::MPIBuffertype{T}, op, comm::Comm) where T
226226
Allreduce!(MPI.IN_PLACE, buf, length(buf), op, comm)
227227
end
228228

@@ -234,18 +234,18 @@ output buffer in all processes of the group.
234234
235235
To specify the output buffer or perform the operation in place, see [`Allreduce!`](@ref).
236236
"""
237-
function Allreduce(sendbuf::MPIBuffertype{T}, op::Union{Op,Function}, comm::Comm) where T
237+
function Allreduce(sendbuf::MPIBuffertype{T}, op, comm::Comm) where T
238238

239239
recvbuf = similar(sendbuf)
240240
Allreduce!(sendbuf, recvbuf, length(recvbuf), op, comm)
241241
end
242242

243-
function Allreduce(sendbuf::Array{T, N}, op::Union{Op, Function}, comm::Comm) where {T, N}
243+
function Allreduce(sendbuf::Array{T, N}, op, comm::Comm) where {T, N}
244244
recvbuf = Array{T,N}(undef, size(sendbuf))
245245
Allreduce!(sendbuf, recvbuf, length(sendbuf), op, comm)
246246
end
247247

248-
function Allreduce(obj::T, op::Union{Op,Function}, comm::Comm) where T
248+
function Allreduce(obj::T, op, comm::Comm) where T
249249
objref = Ref(obj)
250250
outref = Ref{T}()
251251
Allreduce!(objref, outref, 1, op, comm)
@@ -255,7 +255,7 @@ end
255255

256256
# Deprecation warning for lowercase allreduce that was used until v. 0.7.2
257257
# Should be removed at some point in the future
258-
function allreduce(sendbuf::MPIBuffertype{T}, op::Union{Op,Function},
258+
function allreduce(sendbuf::MPIBuffertype{T}, op,
259259
comm::Comm) where T
260260
@warn "`allreduce` is deprecated, use `Allreduce` instead."
261261
Allreduce(sendbuf, op, comm)

src/datatypes.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
const DATATYPE_NULL = _Datatype(MPI_DATATYPE_NULL)
44
Datatype() = Datatype(DATATYPE_NULL.val)
55

6+
const MPIInteger = Union{Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64}
7+
const MPIFloatingPoint = Union{Float32, Float64}
8+
const MPIComplex = Union{ComplexF32, ComplexF64}
9+
610
const MPIDatatype = Union{Char,
711
Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64,
812
UInt64,

src/handle.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ macro mpi_handle(def, extrafields...)
88
Base.@__doc__ mutable struct $(esc(def))
99
val::$mpiname
1010
$(extrafields...)
11+
12+
$(esc(name))(val::$mpiname, $(extrafields...)) = new(val, $(extrafields...))
1113
end
1214

1315
# const initializer
@@ -29,6 +31,5 @@ macro mpi_handle(def, extrafields...)
2931
function Base.unsafe_convert(::Type{Ptr{$mpiname}}, obj::$name)
3032
convert(Ptr{$mpiname}, pointer_from_objref(obj))
3133
end
32-
3334
end
3435
end

src/operators.jl

Lines changed: 68 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,81 +1,87 @@
1-
@mpi_handle Op
1+
"""
2+
Op
23
3-
const OP_NULL = _Op(MPI_OP_NULL)
4-
const BAND = _Op(MPI_BAND)
5-
const BOR = _Op(MPI_BOR)
6-
const BXOR = _Op(MPI_BXOR)
7-
const LAND = _Op(MPI_LAND)
8-
const LOR = _Op(MPI_LOR)
9-
const LXOR = _Op(MPI_LXOR)
10-
const MAX = _Op(MPI_MAX)
11-
const MIN = _Op(MPI_MIN)
12-
const NO_OP = _Op(MPI_NO_OP)
13-
const PROD = _Op(MPI_PROD)
14-
const REPLACE = _Op(MPI_REPLACE)
15-
const SUM = _Op(MPI_SUM)
4+
An MPI reduction operator, for use with [`Reduce!`](@ref)/[`Reduce`](@ref),
5+
[`Reduce_in_place!`](@ref), [`Allreduce!`](@ref)/[`Allreduce`](@ref), [`Scan`](@ref) or [`Exscan`](@ref).
166
7+
## Usage
178
9+
Op(op, T=Any; iscommutative=false)
1810
19-
import Base.Threads: nthreads, threadid
11+
Wrap the Julia reduction function `op` for arguments of type `T`. `op` is assumed to be
12+
associative, and if `iscommutative` is true, assumed to be commutative as well.
13+
"""
14+
@mpi_handle Op fptr
2015

21-
# Implement user-defined MPI reduction operations, by passing Julia
22-
# functions as callbacks to MPI.
16+
const OP_NULL = _Op(MPI_OP_NULL, nothing)
17+
const BAND = _Op(MPI_BAND, nothing)
18+
const BOR = _Op(MPI_BOR, nothing)
19+
const BXOR = _Op(MPI_BXOR, nothing)
20+
const LAND = _Op(MPI_LAND, nothing)
21+
const LOR = _Op(MPI_LOR, nothing)
22+
const LXOR = _Op(MPI_LXOR, nothing)
23+
const MAX = _Op(MPI_MAX, nothing)
24+
const MIN = _Op(MPI_MIN, nothing)
25+
const PROD = _Op(MPI_PROD, nothing)
26+
const REPLACE = _Op(MPI_REPLACE, nothing)
27+
const SUM = _Op(MPI_SUM, nothing)
2328

24-
# Unfortunately, MPI_Op_create takes a function that does not accept
25-
# a void* "thunk" parameter, making it impossible to fully simulate
26-
# a closure. So, we have to use a global variable instead. (Since
27-
# the reduction functions are barriers, being re-entrant is probably
28-
# not important in practice, fortunately.) For MPI_THREAD_MULTIPLE
29-
# using Julia native threading, however, we do make this global thread-local
30-
const _user_functions = Array{Function}(undef, 1) # resized to nthreads() at runtime
31-
const _user_op = _Op(MPI_OP_NULL)
29+
if @isdefined(MPI_NO_OP)
30+
const NO_OP = _Op(MPI_NO_OP, nothing)
31+
end
32+
33+
Op(::typeof(min), ::Type{T}; iscommutative=true) where {T<:Union{MPIInteger,MPIFloatingPoint}} = MIN
34+
Op(::typeof(max), ::Type{T}; iscommutative=true) where {T<:Union{MPIInteger,MPIFloatingPoint}} = MAX
35+
Op(::typeof(+), ::Type{T}; iscommutative=true) where {T<:Union{MPIInteger,MPIFloatingPoint,MPIComplex}} = SUM
36+
Op(::typeof(*), ::Type{T}; iscommutative=true) where {T<:Union{MPIInteger,MPIFloatingPoint,MPIComplex}} = PROD
37+
Op(::typeof(&), ::Type{T}; iscommutative=true) where {T<:MPIInteger} = BAND
38+
Op(::typeof(|), ::Type{T}; iscommutative=true) where {T<:MPIInteger} = BOR
39+
Op(::typeof(⊻), ::Type{T}; iscommutative=true) where {T<:MPIInteger} = BXOR
40+
41+
42+
function free(op::Op)
43+
if op.val != OP_NULL.val
44+
@mpichk ccall((:MPI_Op_free, libmpi), Cint, (Ptr{MPI_Op},), op)
45+
refcount_dec()
46+
end
47+
return nothing
48+
end
3249

33-
# C callback function corresponding to MPI_User_function
34-
function _mpi_user_function(_a::Ptr{Nothing}, _b::Ptr{Nothing}, _len::Ptr{Cint}, t::Ptr{MPI_Datatype})
50+
struct OpWrapper{F,T}
51+
f::F
52+
end
53+
54+
function (w::OpWrapper{F,T})(_a::Ptr{Cvoid}, _b::Ptr{Cvoid}, _len::Ptr{Cint}, t::Ptr{MPI_Datatype}) where {F,T}
3555
len = unsafe_load(_len)
36-
T = mpitype_dict_inverse[unsafe_load(t)]
37-
a = Ptr{T}(_a)
38-
b = Ptr{T}(_b)
39-
f = _user_functions[threadid()]
56+
if isconcretetype(T)
57+
S = T
58+
else
59+
S = mpitype_dict_inverse[unsafe_load(t)]
60+
end
61+
a = Ptr{S}(_a)
62+
b = Ptr{S}(_b)
4063
for i = 1:len
41-
unsafe_store!(b, f(unsafe_load(a,i), unsafe_load(b,i)), i)
64+
unsafe_store!(b, w.f(unsafe_load(a,i), unsafe_load(b,i)), i)
4265
end
4366
return nothing
4467
end
4568

46-
function user_op(opfunc::Function)
69+
70+
function Op(f, T=Any; iscommutative=false)
4771
if Sys.iswindows() && Sys.WORD_SIZE == 32
48-
error("Custom reduction operators are not supported on 32-bit Windows.\nSee https://github.com/JuliaParallel/MPI.jl/issues/246 for more details.")
72+
error("User-defined reduction operators are not supported on 32-bit Windows.\nSee https://github.com/JuliaParallel/MPI.jl/issues/246 for more details.")
4973
end
74+
w = OpWrapper{typeof(f),T}(f)
75+
fptr = @cfunction($w, Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}, Ptr{MPI_Datatype}))
5076

51-
# we must initialize these at runtime, but it can't be done in __init__
52-
# since MPI.Init is not called yet. So we do it lazily here:
53-
if _user_op.val == OP_NULL.val
54-
# FIXME: to be thread-safe, there should really be a mutex lock
55-
# of some sort so that this initialization only occurs once.
56-
# To do when native threading in Julia stabilizes (and is documented).
57-
resize!(_user_functions, nthreads())
58-
user_function = @cfunction(_mpi_user_function, Nothing, (Ptr{Nothing}, Ptr{Nothing}, Ptr{Cint}, Ptr{MPI_Datatype}))
59-
@mpichk ccall((:MPI_Op_create, libmpi), Cint,
60-
(Ptr{Cvoid}, Ref{Cint}, Ptr{MPI_Op}),
61-
user_function, false, _user_op)
62-
end
77+
op = Op(OP_NULL.val, fptr)
78+
# int MPI_Op_create(MPI_User_function* user_fn, int commute, MPI_Op* op)
79+
@mpichk ccall((:MPI_Op_create, libmpi), Cint,
80+
(Ptr{Cvoid}, Cint, Ptr{MPI_Op}),
81+
fptr, iscommutative, op)
6382

64-
_user_functions[threadid()] = opfunc
65-
return _user_op
83+
refcount_inc()
84+
finalizer(free, op)
85+
return op
6686
end
6787

68-
69-
70-
# Match predefined ops to Julia functions
71-
for (f,op) in [
72-
(+, :SUM),
73-
(*, :PROD),
74-
(min, :MIN),
75-
(max, :MAX),
76-
(&, :BAND),
77-
(|, :BOR),
78-
(⊻, :BXOR),
79-
]
80-
@eval user_op(::$(typeof(f))) = $op.val
81-
end

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
using Pkg
2+
pkg"precompile"
3+
14
using MPI
25
using Test
36

test/test_allreduce.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,5 +59,7 @@ for typ=[Int]
5959
end
6060

6161
MPI.Barrier( MPI.COMM_WORLD )
62+
63+
GC.gc()
6264
MPI.Finalize()
6365
@test MPI.Finalized()

test/test_reduce.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using Test
22

33
using MPI
4-
using LinearAlgebra
4+
using LinearAlgebra, DoubleFloats
55

66
MPI.Init()
77

@@ -86,7 +86,21 @@ for typ=[Int]
8686
end
8787
end
8888

89-
9089
MPI.Barrier( MPI.COMM_WORLD )
90+
91+
if !(Sys.iswindows() && Sys.WORD_SIZE == 32)
92+
93+
send_arr = [Double64(i)/10 for i = 1:10]
94+
95+
if rank == root
96+
@test MPI.Reduce(send_arr, +, root, MPI.COMM_WORLD) ≈ [Double64(sz*i)/10 for i = 1:10] rtol=sz*eps(Double64)
97+
else
98+
@test MPI.Reduce(send_arr, +, root, MPI.COMM_WORLD) === nothing
99+
end
100+
101+
MPI.Barrier( MPI.COMM_WORLD )
102+
end
103+
104+
GC.gc()
91105
MPI.Finalize()
92106
@test MPI.Finalized()

0 commit comments

Comments
 (0)