Skip to content

Commit 81954b6

Browse files
sync parallel build functions
Fixes #487
1 parent 0aea62f commit 81954b6

File tree

2 files changed

+41
-10
lines changed

2 files changed

+41
-10
lines changed

src/build_function.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ function _build_and_inject_function(mod::Module, ex)
140140
params = typeof(runtimefn).parameters
141141
fn_expr = GeneralizedGenerated.NGG.from_type(params[3])
142142

143-
# Inject our externally registered module functions
143+
# Inject our externally registered module functions
144144
new_expr = ModelingToolkit.inject_registered_module_functions(fn_expr)
145145

146146
# Reconstruct the RuntimeFn's Body
@@ -273,6 +273,7 @@ function _build_function(target::JuliaTarget, rhss, args...;
273273
end
274274
end)
275275
ip_let_expr.args[2] = ModelingToolkit.build_expr(:block, threaded_exprs)
276+
ip_let_expr = :(@sync begin $ip_let_expr end)
276277
elseif parallel isa DistributedForm
277278
numworks = Distributed.nworkers()
278279
lens = Int(ceil(length(ip_let_expr.args[2].args)/numworks))
@@ -291,9 +292,11 @@ function _build_function(target::JuliaTarget, rhss, args...;
291292
resunpack_exprs = [:($(Symbol(reducevars[iter])) = fetch($(spawnvars[iter]))) for iter in 1:numworks]
292293

293294
ip_let_expr.args[2] = quote
294-
$spawn_exprs
295-
$(resunpack_exprs...)
296-
$(ip_let_expr.args[2])
295+
@sync begin
296+
$spawn_exprs
297+
$(resunpack_exprs...)
298+
$(ip_let_expr.args[2])
299+
end
297300
end
298301
elseif parallel isa DaggerForm
299302
@assert HAS_DAGGER[] "Dagger.jl is not loaded; please do `using Dagger`"
@@ -308,9 +311,11 @@ function _build_function(target::JuliaTarget, rhss, args...;
308311
$(Symbol(reducevar)) = collect(Dagger.delayed(vcat)($(computevars...)))
309312
end
310313
ip_let_expr.args[2] = quote
311-
$delayed_exprs
312-
$reduce_expr
313-
$(ip_let_expr.args[2])
314+
@sync begin
315+
$delayed_exprs
316+
$reduce_expr
317+
$(ip_let_expr.args[2])
318+
end
314319
end
315320
end
316321

test/bigsystem.jl

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,35 @@ end
4646
f(du,u,nothing,0.0)
4747

4848
multithreadedf = eval(ModelingToolkit.build_function(du,u,parallel=ModelingToolkit.MultithreadedForm())[2])
49-
_du = rand(N,N,3)
50-
_u = rand(N,N,3)
51-
multithreadedf(_du,_u)
49+
50+
# Loop to catch syncronization issues
51+
for i in 1:100
52+
_du = rand(N,N,3)
53+
_u = rand(N,N,3)
54+
multithreadedf(_du,_u)
55+
_du2 = copy(_du)
56+
f(_du2,_u,nothing,0.0)
57+
@test _du _du2
58+
end
59+
60+
#=
61+
jac = sparse(ModelingToolkit.jacobian(vec(du),vec(u)))
62+
fjac = eval(ModelingToolkit.build_function(jac,u,parallel=ModelingToolkit.SerialForm())[2])
63+
multithreadedfjac = eval(ModelingToolkit.build_function(jac,u,parallel=ModelingToolkit.MultithreadedForm())[2])
64+
65+
u = rand(N,N,3)
66+
J = similar(jac,Float64)
67+
fjac(J,u)
68+
69+
J2 = similar(jac,Float64)
70+
multithreadedfjac(J2,u)
71+
@test J ≈ J2
72+
73+
using FiniteDiff
74+
J3 = Array(similar(jac,Float64))
75+
FiniteDiff.finite_difference_jacobian!(J2,(du,u)->f!(du,u,nothing,nothing),u)
76+
maximum(J2 .- Array(J)) < 1e-5
77+
=#
5278

5379
using Distributed
5480
addprocs(4)

0 commit comments

Comments
 (0)