Skip to content

Commit ae4f53c

Browse files
authored
Merge pull request #798 from SciML/s/threading
Bring back multithreading
2 parents f0fefb1 + e67934e commit ae4f53c

File tree

7 files changed

+84
-68
lines changed

7 files changed

+84
-68
lines changed

Project.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ Unitful = "1.1"
6767
julia = "1.2"
6868

6969
[extras]
70-
Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
7170
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
7271
GalacticOptim = "a75be94c-b780-496d-a8a9-0878b188d577"
7372
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
@@ -79,4 +78,4 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
7978
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
8079

8180
[targets]
82-
test = ["Dagger", "ForwardDiff", "GalacticOptim", "NonlinearSolve", "OrdinaryDiffEq", "Optim", "Random", "SteadyStateDiffEq", "Test", "StochasticDiffEq"]
81+
test = ["ForwardDiff", "GalacticOptim", "NonlinearSolve", "OrdinaryDiffEq", "Optim", "Random", "SteadyStateDiffEq", "Test", "StochasticDiffEq"]

docs/src/comparison.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,7 @@ excels in many areas due to purposeful design decisions:
2424
- Parallelism: ModelingToolkit.jl has pervasive parallelism. The
2525
symbolic simplification via [SymbolicUtils.jl](https://github.com/JuliaSymbolics/SymbolicUtils.jl)
2626
has built-in parallelism, ModelingToolkit.jl builds functions that
27-
parallelizes across threads and multiprocesses across clusters,
28-
and it has dynamic scheduling through tools like [Dagger.jl](https://github.com/JuliaParallel/Dagger.jl).
29-
ModelingToolkit.jl is compatible with GPU libraries like CUDA.jl.
27+
parallelize across threads. ModelingToolkit.jl is compatible with GPU libraries like CUDA.jl.
3028
- Scientific Machine Learning (SciML): ModelingToolkit.jl is made to synergize
3129
with the high performance Julia SciML ecosystem in many ways. At a
3230
base level, all expressions and built functions are compatible with

src/ModelingToolkit.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -296,9 +296,4 @@ export @register
296296
export modelingtoolkitize
297297
export @variables, @parameters
298298

299-
const HAS_DAGGER = Ref{Bool}(false)
300-
function __init__()
301-
@require Dagger="d58978e5-989f-55fb-8d15-ea34adc7bf54" include("dagger.jl")
302-
end
303-
304299
end # module

src/build_function.jl

Lines changed: 72 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using SymbolicUtils.Code
2+
using Base.Threads
23

34
abstract type BuildTargets end
45
struct JuliaTarget <: BuildTargets end
@@ -8,9 +9,10 @@ struct MATLABTarget <: BuildTargets end
89

910
abstract type ParallelForm end
1011
struct SerialForm <: ParallelForm end
11-
struct MultithreadedForm <: ParallelForm end
12-
struct DistributedForm <: ParallelForm end
13-
struct DaggerForm <: ParallelForm end
12+
struct MultithreadedForm <: ParallelForm
13+
ntasks::Int
14+
end
15+
MultithreadedForm() = MultithreadedForm(2*nthreads())
1416

1517
"""
1618
`build_function`
@@ -60,18 +62,6 @@ function unflatten_long_ops(op, N=4)
6062
Rewriters.Fixpoint(Rewriters.Postwalk(Rewriters.Chain([rule1, rule2])))(op)
6163
end
6264

63-
function observed_let(eqs)
64-
process -> ex -> begin
65-
isempty(eqs) && return ex
66-
67-
assignments = map(eq -> :($(process(eq.lhs)) = $(process(eq.rhs))), eqs)
68-
letexpr = :(let $(assignments...)
69-
end)
70-
# avoid a superfluous `begin ... end` block
71-
letexpr.args[2] = ex
72-
return letexpr
73-
end
74-
end
7565

7666
# Scalar output
7767

@@ -140,7 +130,7 @@ Build function target: JuliaTarget
140130
function _build_function(target::JuliaTarget, rhss, args...;
141131
conv = toexpr, expression = Val{true},
142132
checkbounds = false,
143-
linenumbers = false, multithread=nothing,
133+
linenumbers = false,
144134
headerfun = addheader, outputidxs=nothing,
145135
convert_oop = true, force_SA = false,
146136
skipzeros = outputidxs===nothing,
@@ -168,10 +158,6 @@ Special Keyword Arguments:
168158
- `SerialForm()`: Serial execution.
169159
- `MultithreadedForm()`: Multithreaded execution with a static split, evenly
170160
splitting the number of expressions per thread.
171-
- `DistributedForm()`: Multiprocessing using Julia's Distributed with a static
172-
schedule, evenly splitting the number of expressions per process.
173-
- `DaggerForm()`: Multithreading and multiprocessing using Julia's Dagger.jl
174-
for dynamic scheduling and load balancing.
175161
- `conv`: The conversion function of the Operation to Expr. By default this uses
176162
the `toexpr` function.
177163
- `checkbounds`: For whether to enable bounds checking inside of the generated
@@ -194,7 +180,7 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
194180
expression = Val{true},
195181
expression_module = @__MODULE__(),
196182
checkbounds = false,
197-
linenumbers = false, multithread=nothing,
183+
linenumbers = false,
198184
outputidxs=nothing,
199185
skipzeros = false,
200186
wrap_code = (nothing, nothing),
@@ -204,13 +190,13 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
204190
dargs = map(destructure_arg, [args...])
205191
i = findfirst(x->x isa DestructuredArgs, dargs)
206192
similarto = i === nothing ? Array : dargs[i].name
207-
oop_expr = Func(dargs, [], _make_array(rhss, similarto))
193+
oop_expr = Func(dargs, [], make_array(parallel, rhss, similarto))
208194
if !isnothing(wrap_code[1])
209195
oop_expr = wrap_code[1](oop_expr)
210196
end
211197

212198
out = Sym{Any}(gensym("out"))
213-
ip_expr = Func([out, dargs...], [], _set_array(out, outputidxs, rhss, checkbounds, skipzeros))
199+
ip_expr = Func([out, dargs...], [], set_array(parallel, out, outputidxs, rhss, checkbounds, skipzeros))
214200

215201
if !isnothing(wrap_code[2])
216202
ip_expr = wrap_code[2](ip_expr)
@@ -224,6 +210,24 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
224210
end
225211
end
226212

213+
function make_array(s, arr, similarto)
214+
Base.@warn("Parallel form of $(typeof(s)) not implemented")
215+
_make_array(arr, similarto)
216+
end
217+
218+
function make_array(s::SerialForm, arr, similarto)
219+
_make_array(arr, similarto)
220+
end
221+
222+
function make_array(s::MultithreadedForm, arr, similarto)
223+
per_task = ceil(Int, length(arr) / s.ntasks)
224+
slices = collect(Iterators.partition(arr, per_task))
225+
arrays = map(slices) do slice
226+
_make_array(slice, similarto)
227+
end
228+
SpawnFetch{Multithreaded}(arrays, vcat)
229+
end
230+
227231
function _make_array(rhss::AbstractSparseArray, similarto)
228232
arr = map(x->_make_array(x, similarto), rhss)
229233
if !(arr isa AbstractSparseArray)
@@ -233,19 +237,59 @@ function _make_array(rhss::AbstractSparseArray, similarto)
233237
end
234238
end
235239

240+
function _make_array(rhss::AbstractArray, similarto)
241+
arr = map(x->_make_array(x, similarto), rhss)
242+
# Ugh reshaped array of a sparse array when mapped gives a sparse array
243+
if arr isa AbstractSparseArray
244+
_make_array(arr, similarto)
245+
else
246+
MakeArray(arr, similarto)
247+
end
248+
end
249+
250+
_make_array(x, similarto) = x
251+
236252
## In-place version
237-
function _set_array(out, outputidxs, rhss::AbstractArray, checkbounds, skipzeros)
238-
if rhss isa Union{SparseVector, SparseMatrixCSC}
239-
return SetArray(checkbounds, LiteralExpr(:($out.nzval)), rhss.nzval)
240-
elseif outputidxs === nothing
253+
254+
function set_array(p, args...)
255+
Base.@warn("Parallel form of $(typeof(p)) not implemented")
256+
_set_array(args...)
257+
end
258+
259+
function set_array(s::SerialForm, args...)
260+
_set_array(args...)
261+
end
262+
263+
function set_array(s::MultithreadedForm, out, outputidxs, rhss, checkbounds, skipzeros)
264+
if rhss isa AbstractSparseArray
265+
return set_array(LiteralExpr(:($out.nzval)),
266+
nothing,
267+
rhss.nzval,
268+
checkbounds,
269+
skipzeros)
270+
end
271+
if outputidxs === nothing
241272
outputidxs = collect(eachindex(rhss))
242273
end
274+
per_task = ceil(Int, length(rhss) / s.ntasks)
275+
# TODO: do better partitioning when skipzeros is present
276+
slices = collect(Iterators.partition(zip(outputidxs, rhss), per_task))
277+
arrays = map(slices) do slice
278+
idxs, vals = first.(slice), last.(slice)
279+
_set_array(out, idxs, vals, checkbounds, skipzeros)
280+
end
281+
SpawnFetch{Multithreaded}(arrays, @inline noop(args...) = nothing)
282+
end
243283

284+
function _set_array(out, outputidxs, rhss::AbstractArray, checkbounds, skipzeros)
285+
if outputidxs === nothing
286+
outputidxs = collect(eachindex(rhss))
287+
end
244288
# sometimes outputidxs is a Tuple
245289
ii = findall(i->!(rhss[i] isa AbstractArray) && !(skipzeros && _iszero(rhss[i])), eachindex(outputidxs))
246290
jj = findall(i->rhss[i] isa AbstractArray, eachindex(outputidxs))
247291
exprs = []
248-
push!(exprs, SetArray(checkbounds, out, AtIndex.(vec(collect(outputidxs[ii])), vec(rhss[ii]))))
292+
push!(exprs, SetArray(!checkbounds, out, AtIndex.(vec(collect(outputidxs[ii])), vec(rhss[ii]))))
249293
for j in jj
250294
push!(exprs, _set_array(LiteralExpr(:($out[$j])), nothing, rhss[j], checkbounds, skipzeros))
251295
end
@@ -257,18 +301,6 @@ end
257301
_set_array(out, outputidxs, rhs, checkbounds, skipzeros) = rhs
258302

259303

260-
function _make_array(rhss::AbstractArray, similarto)
261-
arr = map(x->_make_array(x, similarto), rhss)
262-
# Ugh reshaped array of a sparse array when mapped gives a sparse array
263-
if arr isa AbstractSparseArray
264-
_make_array(arr, similarto)
265-
else
266-
MakeArray(arr, similarto)
267-
end
268-
end
269-
270-
_make_array(x, similarto) = x
271-
272304
function vars_to_pairs(name,vs::Union{Tuple, AbstractArray}, symsdict=Dict())
273305
vs_names = tosymbol.(vs)
274306
for (v,k) in zip(vs_names, vs)

src/dagger.jl

Lines changed: 0 additions & 3 deletions
This file was deleted.

test/bigsystem.jl

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -80,18 +80,9 @@ FiniteDiff.finite_difference_jacobian!(J2,(du,u)->f!(du,u,nothing,nothing),u)
8080
maximum(J2 .- Array(J)) < 1e-5
8181
=#
8282

83-
using Distributed
84-
addprocs(4)
85-
distributedf = eval(ModelingToolkit.build_function(du,u,parallel=ModelingToolkit.DistributedForm())[2])
86-
87-
using Dagger
88-
daggerf = eval(ModelingToolkit.build_function(du,u,parallel=ModelingToolkit.DaggerForm())[2])
89-
9083
jac = ModelingToolkit.sparsejacobian(vec(du),vec(u))
9184
serialjac = eval(ModelingToolkit.build_function(vec(jac),u)[2])
9285
multithreadedjac = eval(ModelingToolkit.build_function(vec(jac),u,parallel=ModelingToolkit.MultithreadedForm())[2])
93-
distributedjac = eval(ModelingToolkit.build_function(vec(jac),u,parallel=ModelingToolkit.DistributedForm())[2])
94-
daggerjac = eval(ModelingToolkit.build_function(vec(jac),u,parallel=ModelingToolkit.DaggerForm())[2])
9586

9687
MyA = zeros(N,N)
9788
AMx = zeros(N,N)
@@ -101,19 +92,13 @@ _u = rand(N,N,3)
10192

10293
f(_du,_u,nothing,0.0)
10394
multithreadedf(_du,_u)
104-
#distributedf(_du,_u)
105-
#daggerf(_du,_u)
10695

10796
#=
10897
using BenchmarkTools
10998
@btime f(_du,_u,nothing,0.0)
11099
@btime multithreadedf(_du,_u)
111-
@btime distributedf(_du,_u)
112-
@btime daggerf(_du,_u)
113100
114101
_jac = similar(jac,Float64)
115102
@btime serialjac(_jac,_u)
116103
@btime multithreadedjac(_jac,_u)
117-
@btime distributedjac(_jac,_u)
118-
@btime daggerjac(_jac,_u)
119104
=#

test/build_function.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,15 @@ end
1414

1515
h_str = ModelingToolkit.build_function(h, [a], [b], [c1, c2, c3], [d], [e], [g])
1616
h_oop = eval(h_str[1])
17+
h_str_par = ModelingToolkit.build_function(h, [a], [b], [c1, c2, c3], [d], [e], [g], parallel=ModelingToolkit.MultithreadedForm())
18+
h_oop_par = eval(h_str_par[1])
1719
h_ip! = eval(h_str[2])
1820
h_ip_skip! = eval(ModelingToolkit.build_function(h, [a], [b], [c1, c2, c3], [d], [e], [g], skipzeros=true, fillzeros=false)[2])
21+
h_ip_skip_par! = eval(ModelingToolkit.build_function(h, [a], [b], [c1, c2, c3], [d], [e], [g], skipzeros=true, parallel=ModelingToolkit.MultithreadedForm(), fillzeros=false)[2])
1922
inputs = ([1], [2], [3, 4, 5], [6], [7], [8])
2023

2124
@test h_oop(inputs...) == h_julia(inputs...)
25+
@test h_oop_par(inputs...) == h_julia(inputs...)
2226
out_1 = similar(h, Int)
2327
out_2 = similar(out_1)
2428
h_ip!(out_1, inputs...)
@@ -30,6 +34,12 @@ h_ip_skip!(out_1, inputs...)
3034
out_1[3] = 0
3135
@test out_1 == out_2
3236

37+
fill!(out_1, 10)
38+
h_ip_skip_par!(out_1, inputs...)
39+
@test out_1[3] == 10
40+
out_1[3] = 0
41+
@test out_1 == out_2
42+
3343
# Multiple input matrix, some unused arguments
3444
h_skip = [a + b + c1; c2 + c3 + g] # skip d, e
3545
h_julia_skip(a, b, c, d, e, g) = [a[1] + b[1] + c[1]; c[2] + c[3] + g[1]]

0 commit comments

Comments
 (0)