Skip to content

Commit e2f4799

Browse files
committed
multi-threading
1 parent 75610c1 commit e2f4799

File tree

1 file changed

+69
-21
lines changed

1 file changed

+69
-21
lines changed

src/build_function.jl

Lines changed: 69 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using SymbolicUtils.Code
2+
using Base.Threads
23

34
abstract type BuildTargets end
45
struct JuliaTarget <: BuildTargets end
@@ -128,7 +129,7 @@ Build function target: JuliaTarget
128129
function _build_function(target::JuliaTarget, rhss, args...;
129130
conv = toexpr, expression = Val{true},
130131
checkbounds = false,
131-
linenumbers = false, multithread=nothing,
132+
linenumbers = false,
132133
headerfun = addheader, outputidxs=nothing,
133134
convert_oop = true, force_SA = false,
134135
skipzeros = outputidxs===nothing,
@@ -182,7 +183,7 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
182183
expression = Val{true},
183184
expression_module = @__MODULE__(),
184185
checkbounds = false,
185-
linenumbers = false, multithread=nothing,
186+
linenumbers = false,
186187
outputidxs=nothing,
187188
skipzeros = false,
188189
wrap_code = (nothing, nothing),
@@ -192,13 +193,13 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
192193
dargs = map(destructure_arg, [args...])
193194
i = findfirst(x->x isa DestructuredArgs, dargs)
194195
similarto = i === nothing ? Array : dargs[i].name
195-
oop_expr = Func(dargs, [], _make_array(rhss, similarto))
196+
oop_expr = Func(dargs, [], make_array(parallel, rhss, similarto))
196197
if !isnothing(wrap_code[1])
197198
oop_expr = wrap_code[1](oop_expr)
198199
end
199200

200201
out = Sym{Any}(gensym("out"))
201-
ip_expr = Func([out, dargs...], [], _set_array(out, outputidxs, rhss, checkbounds, skipzeros))
202+
ip_expr = Func([out, dargs...], [], set_array(parallel, out, outputidxs, rhss, checkbounds, skipzeros))
202203

203204
if !isnothing(wrap_code[2])
204205
ip_expr = wrap_code[2](ip_expr)
@@ -212,6 +213,25 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
212213
end
213214
end
214215

216+
function make_array(s, arr, similarto)
217+
Base.@warn("Parallel form of $(typeof(s)) not implemented")
218+
_make_array(arr, similarto)
219+
end
220+
221+
function make_array(s::SerialForm, arr, similarto)
222+
_make_array(arr, similarto)
223+
end
224+
225+
function make_array(s::MultithreadedForm, arr, similarto)
226+
ntasks = 2 * nthreads() # oversubscribe a little bit
227+
per_task = ceil(Int, length(arr) / ntasks)
228+
slices = collect(Iterators.partition(arr, per_task))
229+
arrays = map(slices) do slice
230+
_make_array(slice, similarto)
231+
end
232+
Par{Multithreaded}(arrays, vcat)
233+
end
234+
215235
function _make_array(rhss::AbstractSparseArray, similarto)
216236
arr = map(x->_make_array(x, similarto), rhss)
217237
if !(arr isa AbstractSparseArray)
@@ -221,19 +241,59 @@ function _make_array(rhss::AbstractSparseArray, similarto)
221241
end
222242
end
223243

244+
function _make_array(rhss::AbstractArray, similarto)
245+
arr = map(x->_make_array(x, similarto), rhss)
246+
# Ugh reshaped array of a sparse array when mapped gives a sparse array
247+
if arr isa AbstractSparseArray
248+
_make_array(arr, similarto)
249+
else
250+
MakeArray(arr, similarto)
251+
end
252+
end
253+
254+
_make_array(x, similarto) = x
255+
224256
## In-place version
225-
function _set_array(out, outputidxs, rhss::AbstractArray, checkbounds, skipzeros)
226-
if rhss isa Union{SparseVector, SparseMatrixCSC}
227-
return SetArray(checkbounds, LiteralExpr(:($out.nzval)), rhss.nzval)
228-
elseif outputidxs === nothing
257+
258+
function set_array(p, args...)
259+
Base.@warn("Parallel form of $(typeof(p)) not implemented")
260+
_set_array(args...)
261+
end
262+
263+
function set_array(s::SerialForm, args...)
264+
_set_array(args...)
265+
end
266+
267+
function set_array(s::MultithreadedForm, out, outputidxs, rhss, checkbounds, skipzeros)
268+
if rhss isa AbstractSparseArray
269+
return set_array(LiteralExpr(:($out.nzval)),
270+
nothing,
271+
rhss.nzval,
272+
checkbounds,
273+
skipzeros)
274+
end
275+
if outputidxs === nothing
229276
outputidxs = collect(eachindex(rhss))
230277
end
278+
ntasks = 2 * nthreads() # oversubscribe a little bit
279+
per_task = ceil(Int, length(rhss) / ntasks)
280+
slices = collect(Iterators.partition(zip(outputidxs, rhss), per_task))
281+
arrays = map(slices) do slice
282+
idxs, vals = first.(slice), last.(slice)
283+
_set_array(out, idxs, vals, checkbounds, skipzeros)
284+
end
285+
Par{Multithreaded}(arrays, @inline f(args...) = nothing)
286+
end
231287

288+
function _set_array(out, outputidxs, rhss::AbstractArray, checkbounds, skipzeros)
289+
if outputidxs === nothing
290+
outputidxs = collect(eachindex(rhss))
291+
end
232292
# sometimes outputidxs is a Tuple
233293
ii = findall(i->!(rhss[i] isa AbstractArray) && !(skipzeros && _iszero(rhss[i])), eachindex(outputidxs))
234294
jj = findall(i->rhss[i] isa AbstractArray, eachindex(outputidxs))
235295
exprs = []
236-
push!(exprs, SetArray(checkbounds, out, AtIndex.(vec(collect(outputidxs[ii])), vec(rhss[ii]))))
296+
push!(exprs, SetArray(!checkbounds, out, AtIndex.(vec(collect(outputidxs[ii])), vec(rhss[ii]))))
237297
for j in jj
238298
push!(exprs, _set_array(LiteralExpr(:($out[$j])), nothing, rhss[j], checkbounds, skipzeros))
239299
end
@@ -245,18 +305,6 @@ end
245305
_set_array(out, outputidxs, rhs, checkbounds, skipzeros) = rhs
246306

247307

248-
function _make_array(rhss::AbstractArray, similarto)
249-
arr = map(x->_make_array(x, similarto), rhss)
250-
# Ugh reshaped array of a sparse array when mapped gives a sparse array
251-
if arr isa AbstractSparseArray
252-
_make_array(arr, similarto)
253-
else
254-
MakeArray(arr, similarto)
255-
end
256-
end
257-
258-
_make_array(x, similarto) = x
259-
260308
function vars_to_pairs(name,vs::Union{Tuple, AbstractArray}, symsdict=Dict())
261309
vs_names = tosymbol.(vs)
262310
for (v,k) in zip(vs_names, vs)

0 commit comments

Comments
 (0)