Skip to content

Commit c91d5a7

Browse files
vchuravy and giordano authored
add macro to create custom Ops also on aarch64 (#871)
Co-authored-by: Mosè Giordano <[email protected]>
1 parent aac9688 commit c91d5a7

File tree

6 files changed

+120
-26
lines changed

6 files changed

+120
-26
lines changed

docs/examples/03-reduce.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,18 +31,23 @@ function pool(S1::SummaryStat, S2::SummaryStat)
3131
SummaryStat(m,v,n)
3232
end
3333

34+
# Register the custom reduction operator. This is necessary only on platforms
35+
# where Julia doesn't support closures as cfunctions (e.g. ARM), but can be used
36+
# on all platforms for consistency.
37+
MPI.@RegisterOp(pool, SummaryStat)
38+
3439
X = randn(10,3) .* [1,3,7]'
3540

3641
# Perform a scalar reduction
37-
summ = MPI.Reduce(SummaryStat(X), pool, root, comm)
42+
summ = MPI.Reduce(SummaryStat(X), pool, comm; root)
3843

3944
if MPI.Comm_rank(comm) == root
4045
@show summ.var
4146
end
4247

4348
# Perform a vector reduction:
4449
# the reduction operator is applied elementwise
45-
col_summ = MPI.Reduce(mapslices(SummaryStat,X,dims=1), pool, root, comm)
50+
col_summ = MPI.Reduce(mapslices(SummaryStat,X,dims=1), pool, comm; root)
4651

4752
if MPI.Comm_rank(comm) == root
4853
col_var = map(summ -> summ.var, col_summ)

docs/src/knownissues.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,3 +210,5 @@ However they have two limitations:
210210
211211
* [Julia's C-compatible function pointers](https://docs.julialang.org/en/v1/manual/calling-c-and-fortran-code/index.html#Creating-C-Compatible-Julia-Function-Pointers-1) cannot be used where the `stdcall` calling convention is expected, which is the case for 32-bit Microsoft MPI,
212212
* closure cfunctions in Julia are based on LLVM trampolines, which are not supported on ARM architecture.
213+
214+
As an alternative [`MPI.@RegisterOp`](@ref) may be used to statically register reduction operations.

docs/src/reference/advanced.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ MPI.Types.duplicate
2626

2727
```@docs
2828
MPI.Op
29+
MPI.@RegisterOp
2930
```
3031

3132
## Info objects

src/operators.jl

Lines changed: 91 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ associative, and if `iscommutative` is true, assumed to be commutative as well.
1616
- [`Allreduce!`](@ref)/[`Allreduce`](@ref)
1717
- [`Scan!`](@ref)/[`Scan`](@ref)
1818
- [`Exscan!`](@ref)/[`Exscan`](@ref)
19+
- [`@RegisterOp`](@ref)
1920
"""
2021
mutable struct Op
2122
val::MPI_Op
@@ -81,21 +82,36 @@ end
8182

8283
function (w::OpWrapper{F,T})(_a::Ptr{Cvoid}, _b::Ptr{Cvoid}, _len::Ptr{Cint}, t::Ptr{MPI_Datatype}) where {F,T}
8384
len = unsafe_load(_len)
84-
@assert isconcretetype(T)
85-
a = Ptr{T}(_a)
86-
b = Ptr{T}(_b)
87-
for i = 1:len
88-
unsafe_store!(b, w.f(unsafe_load(a,i), unsafe_load(b,i)), i)
85+
if !isconcretetype(T)
86+
concrete_T = to_type(Datatype(unsafe_load(t))) # Ptr might actually point to a Julia object so we could unsafe_pointer_to_objref?
87+
else
88+
concrete_T = T
8989
end
90+
function copy(::Type{T}) where T
91+
@assert isconcretetype(T)
92+
a = Ptr{T}(_a)
93+
b = Ptr{T}(_b)
94+
for i = 1:len
95+
unsafe_store!(b, w.f(unsafe_load(a,i), unsafe_load(b,i)), i)
96+
end
97+
end
98+
copy(concrete_T)
9099
return nothing
91100
end
92101

93-
94102
function Op(f, T=Any; iscommutative=false)
95103
@static if MPI_LIBRARY == "MicrosoftMPI" && Sys.WORD_SIZE == 32
96-
error("User-defined reduction operators are not supported on 32-bit Windows.\nSee https://github.com/JuliaParallel/MPI.jl/issues/246 for more details.")
104+
error("""
105+
User-defined reduction operators are not supported on 32-bit Windows.
106+
See https://github.com/JuliaParallel/MPI.jl/issues/246 for more details.
107+
""")
97108
elseif Sys.ARCH ∈ (:aarch64, :ppc64le, :powerpc64le) || startswith(lowercase(String(Sys.ARCH)), "arm")
98-
error("User-defined reduction operators are currently not supported on non-Intel architectures.\nSee https://github.com/JuliaParallel/MPI.jl/issues/404 for more details.")
109+
error("""
110+
User-defined reduction operators are currently not supported on non-Intel architectures.
111+
See https://github.com/JuliaParallel/MPI.jl/issues/404 for more details.
112+
113+
You may want to use `@RegisterOp` to statically register `f`.
114+
""")
99115
end
100116
w = OpWrapper{typeof(f),T}(f)
101117
fptr = @cfunction($w, Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}, Ptr{MPI_Datatype}))
@@ -107,3 +123,70 @@ function Op(f, T=Any; iscommutative=false)
107123
finalizer(free, op)
108124
return op
109125
end
126+
127+
"""
128+
@RegisterOp(f, T)
129+
130+
Register a custom operator [`Op`](@ref) using the function `f` statically.
131+
On platforms like AArch64, Julia does not support runtime closures
132+
being passed to C. The generic version of [`Op`](@ref) uses runtime closures
133+
to support arbitrary functions being passed as MPI reduction operators.
134+
`@RegisterOp` statically adds a function to the set of functions allowed as
135+
an MPI operator.
136+
137+
```julia
138+
function my_reduce(x, y)
139+
2x+y-x
140+
end
141+
MPI.@RegisterOp(my_reduce, Int)
142+
# ...
143+
MPI.Reduce!(send_arr, recv_arr, my_reduce, MPI.COMM_WORLD; root=root)
144+
#...
145+
```
146+
!!! warning
147+
Note that `@RegisterOp` works by introducing a new method of the generic function `Op`.
148+
It can only be used as a top-level statement and may trigger method invalidations.
149+
150+
!!! note
151+
`T` can be `Any`, but this will lead to a runtime dispatch.
152+
"""
153+
macro RegisterOp(f, T)
154+
name_wrapper = gensym(Symbol(f, :_, T, :_wrapper))
155+
name_fptr = gensym(Symbol(f, :_, T, :_ptr))
156+
name_module = gensym(Symbol(f, :_, T, :_module))
157+
# The gist is that we can use a method very similar to how we handle `min`/`max`
158+
# but since this might be used from user code we can't use add_load_time_hook!
159+
# this is why we introduce a new module that has a `__init__` function.
160+
# If this module approach is too costly for loading MPI.jl for internal use we could use
161+
# `add_load_time_hook`
162+
expr = quote
163+
module $(name_module)
164+
# import ..$f, ..$T
165+
$(Expr(:import, Expr(:., :., :., f), Expr(:., :., :., T))) # Julia 1.6 struggles with import ..$f, ..$T
166+
const $(name_wrapper) = $OpWrapper{typeof($f),$T}($f)
167+
const $(name_fptr) = Ref(@cfunction($(name_wrapper), Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}, Ptr{$MPI_Datatype})))
168+
function __init__()
169+
$(name_fptr)[] = @cfunction($(name_wrapper), Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}, Ptr{$MPI_Datatype}))
170+
end
171+
import MPI: Op
172+
# we can't create a const Op since MPI needs to be initialized?
173+
function Op(::typeof($f), ::Type{<:$T}; iscommutative=false)
174+
op = Op($OP_NULL.val, $(name_fptr)[])
175+
# int MPI_Op_create(MPI_User_function* user_fn, int commute, MPI_Op* op)
176+
$API.MPI_Op_create($(name_fptr)[], iscommutative, op)
177+
178+
finalizer($free, op)
179+
end
180+
end
181+
end
182+
expr.head = :toplevel
183+
esc(expr)
184+
end
185+
186+
@RegisterOp(min, Any)
187+
@RegisterOp(max, Any)
188+
@RegisterOp(+, Any)
189+
@RegisterOp(*, Any)
190+
@RegisterOp(&, Any)
191+
@RegisterOp(|, Any)
192+
@RegisterOp(⊻, Any)

test/Project.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
66
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
77
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
88

9+
[weakdeps]
10+
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
11+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
12+
913
[compat]
1014
AMDGPU = "0.6, 0.7, 0.8, 0.9, 1"
1115
CUDA = "3, 4, 5"
@@ -16,7 +20,3 @@ TOML = "< 0.0.1, 1.0"
1620
[extras]
1721
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
1822
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
19-
20-
[weakdeps]
21-
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
22-
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

test/test_reduce.jl

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,15 @@ if isroot
5959
@test sum_mesg == sz .* mesg
6060
end
6161

62+
function my_reduce(x, y)
63+
2x+y-x
64+
end
65+
MPI.@RegisterOp(my_reduce, Any)
66+
6267
if can_do_closures
63-
operators = [MPI.SUM, +, (x,y) -> 2x+y-x]
68+
operators = [MPI.SUM, +, my_reduce, (x,y) -> 2x+y-x]
6469
else
65-
operators = [MPI.SUM, +]
70+
operators = [MPI.SUM, +, my_reduce]
6671
end
6772

6873
for T = [Int]
@@ -117,19 +122,17 @@ end
117122

118123
MPI.Barrier( MPI.COMM_WORLD )
119124

120-
if can_do_closures
121-
send_arr = [Double64(i)/10 for i = 1:10]
122-
123-
result = MPI.Reduce(send_arr, +, MPI.COMM_WORLD; root=root)
124-
if rank == root
125-
@test result ≈ [Double64(sz*i)/10 for i = 1:10] rtol=sz*eps(Double64)
126-
else
127-
@test result === nothing
128-
end
125+
send_arr = [Double64(i)/10 for i = 1:10]
129126

130-
MPI.Barrier( MPI.COMM_WORLD )
127+
result = MPI.Reduce(send_arr, +, MPI.COMM_WORLD; root=root)
128+
if rank == root
129+
@test result ≈ [Double64(sz*i)/10 for i = 1:10] rtol=sz*eps(Double64)
130+
else
131+
@test result === nothing
131132
end
132133

134+
MPI.Barrier( MPI.COMM_WORLD )
135+
133136
GC.gc()
134137
MPI.Finalize()
135138
@test MPI.Finalized()

0 commit comments

Comments
 (0)