Skip to content

Commit b9ada6c

Browse files
Merge pull request #488 from SciML/sync
sync parallel build functions
2 parents 0aea62f + 8a8fdb1 commit b9ada6c

File tree

2 files changed

+46
-10
lines changed

2 files changed

+46
-10
lines changed

src/build_function.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ function _build_and_inject_function(mod::Module, ex)
140140
params = typeof(runtimefn).parameters
141141
fn_expr = GeneralizedGenerated.NGG.from_type(params[3])
142142

143-
# Inject our externally registered module functions
143+
# Inject our externally registered module functions
144144
new_expr = ModelingToolkit.inject_registered_module_functions(fn_expr)
145145

146146
# Reconstruct the RuntimeFn's Body
@@ -273,6 +273,7 @@ function _build_function(target::JuliaTarget, rhss, args...;
273273
end
274274
end)
275275
ip_let_expr.args[2] = ModelingToolkit.build_expr(:block, threaded_exprs)
276+
ip_let_expr = :(@sync begin $ip_let_expr end)
276277
elseif parallel isa DistributedForm
277278
numworks = Distributed.nworkers()
278279
lens = Int(ceil(length(ip_let_expr.args[2].args)/numworks))
@@ -291,9 +292,11 @@ function _build_function(target::JuliaTarget, rhss, args...;
291292
resunpack_exprs = [:($(Symbol(reducevars[iter])) = fetch($(spawnvars[iter]))) for iter in 1:numworks]
292293

293294
ip_let_expr.args[2] = quote
294-
$spawn_exprs
295-
$(resunpack_exprs...)
296-
$(ip_let_expr.args[2])
295+
@sync begin
296+
$spawn_exprs
297+
$(resunpack_exprs...)
298+
$(ip_let_expr.args[2])
299+
end
297300
end
298301
elseif parallel isa DaggerForm
299302
@assert HAS_DAGGER[] "Dagger.jl is not loaded; please do `using Dagger`"
@@ -308,9 +311,11 @@ function _build_function(target::JuliaTarget, rhss, args...;
308311
$(Symbol(reducevar)) = collect(Dagger.delayed(vcat)($(computevars...)))
309312
end
310313
ip_let_expr.args[2] = quote
311-
$delayed_exprs
312-
$reduce_expr
313-
$(ip_let_expr.args[2])
314+
@sync begin
315+
$delayed_exprs
316+
$reduce_expr
317+
$(ip_let_expr.args[2])
318+
end
314319
end
315320
end
316321

test/bigsystem.jl

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,38 @@ end
4646
f(du,u,nothing,0.0)
4747

4848
multithreadedf = eval(ModelingToolkit.build_function(du,u,parallel=ModelingToolkit.MultithreadedForm())[2])
49-
_du = rand(N,N,3)
50-
_u = rand(N,N,3)
51-
multithreadedf(_du,_u)
49+
50+
MyA = zeros(N,N);
51+
AMx = zeros(N,N);
52+
DA = zeros(N,N);
53+
# Loop to catch syncronization issues
54+
for i in 1:100
55+
_du = rand(N,N,3)
56+
_u = rand(N,N,3)
57+
multithreadedf(_du,_u)
58+
_du2 = copy(_du)
59+
f(_du2,_u,nothing,0.0)
60+
@test _du _du2
61+
end
62+
63+
#=
64+
jac = sparse(ModelingToolkit.jacobian(vec(du),vec(u)))
65+
fjac = eval(ModelingToolkit.build_function(jac,u,parallel=ModelingToolkit.SerialForm())[2])
66+
multithreadedfjac = eval(ModelingToolkit.build_function(jac,u,parallel=ModelingToolkit.MultithreadedForm())[2])
67+
68+
u = rand(N,N,3)
69+
J = similar(jac,Float64)
70+
fjac(J,u)
71+
72+
J2 = similar(jac,Float64)
73+
multithreadedfjac(J2,u)
74+
@test J ≈ J2
75+
76+
using FiniteDiff
77+
J3 = Array(similar(jac,Float64))
78+
FiniteDiff.finite_difference_jacobian!(J2,(du,u)->f!(du,u,nothing,nothing),u)
79+
maximum(J2 .- Array(J)) < 1e-5
80+
=#
5281

5382
using Distributed
5483
addprocs(4)
@@ -66,6 +95,8 @@ daggerjac = eval(ModelingToolkit.build_function(vec(jac),u,parallel=ModelingTool
6695
MyA = zeros(N,N)
6796
AMx = zeros(N,N)
6897
DA = zeros(N,N)
98+
_du = rand(N,N,3)
99+
_u = rand(N,N,3)
69100

70101
f(_du,_u,nothing,0.0)
71102
multithreadedf(_du,_u)

0 commit comments

Comments
 (0)