1
1
using SymbolicUtils. Code
2
+ using Base. Threads
2
3
3
4
abstract type BuildTargets end
4
5
struct JuliaTarget <: BuildTargets end
@@ -128,7 +129,7 @@ Build function target: JuliaTarget
128
129
function _build_function(target::JuliaTarget, rhss, args...;
129
130
conv = toexpr, expression = Val{true},
130
131
checkbounds = false,
131
- linenumbers = false, multithread=nothing,
132
+ linenumbers = false,
132
133
headerfun = addheader, outputidxs=nothing,
133
134
convert_oop = true, force_SA = false,
134
135
skipzeros = outputidxs===nothing,
@@ -182,7 +183,7 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
182
183
expression = Val{true },
183
184
expression_module = @__MODULE__ (),
184
185
checkbounds = false ,
185
- linenumbers = false , multithread = nothing ,
186
+ linenumbers = false ,
186
187
outputidxs= nothing ,
187
188
skipzeros = false ,
188
189
wrap_code = (nothing , nothing ),
@@ -192,13 +193,13 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
192
193
dargs = map (destructure_arg, [args... ])
193
194
i = findfirst (x-> x isa DestructuredArgs, dargs)
194
195
similarto = i === nothing ? Array : dargs[i]. name
195
- oop_expr = Func (dargs, [], _make_array ( rhss, similarto))
196
+ oop_expr = Func (dargs, [], make_array (parallel, rhss, similarto))
196
197
if ! isnothing (wrap_code[1 ])
197
198
oop_expr = wrap_code[1 ](oop_expr)
198
199
end
199
200
200
201
out = Sym {Any} (gensym (" out" ))
201
- ip_expr = Func ([out, dargs... ], [], _set_array ( out, outputidxs, rhss, checkbounds, skipzeros))
202
+ ip_expr = Func ([out, dargs... ], [], set_array (parallel, out, outputidxs, rhss, checkbounds, skipzeros))
202
203
203
204
if ! isnothing (wrap_code[2 ])
204
205
ip_expr = wrap_code[2 ](ip_expr)
@@ -212,6 +213,25 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
212
213
end
213
214
end
214
215
216
+ function make_array (s, arr, similarto)
217
+ Base. @warn (" Parallel form of $(typeof (s)) not implemented" )
218
+ _make_array (arr, similarto)
219
+ end
220
+
221
+ function make_array (s:: SerialForm , arr, similarto)
222
+ _make_array (arr, similarto)
223
+ end
224
+
225
+ function make_array (s:: MultithreadedForm , arr, similarto)
226
+ ntasks = 2 * nthreads () # oversubscribe a little bit
227
+ per_task = ceil (Int, length (arr) / ntasks)
228
+ slices = collect (Iterators. partition (arr, per_task))
229
+ arrays = map (slices) do slice
230
+ _make_array (slice, similarto)
231
+ end
232
+ Par {Multithreaded} (arrays, vcat)
233
+ end
234
+
215
235
function _make_array (rhss:: AbstractSparseArray , similarto)
216
236
arr = map (x-> _make_array (x, similarto), rhss)
217
237
if ! (arr isa AbstractSparseArray)
@@ -221,19 +241,59 @@ function _make_array(rhss::AbstractSparseArray, similarto)
221
241
end
222
242
end
223
243
244
+ function _make_array (rhss:: AbstractArray , similarto)
245
+ arr = map (x-> _make_array (x, similarto), rhss)
246
+ # Ugh reshaped array of a sparse array when mapped gives a sparse array
247
+ if arr isa AbstractSparseArray
248
+ _make_array (arr, similarto)
249
+ else
250
+ MakeArray (arr, similarto)
251
+ end
252
+ end
253
+
254
+ _make_array (x, similarto) = x
255
+
224
256
# # In-place version
225
- function _set_array (out, outputidxs, rhss:: AbstractArray , checkbounds, skipzeros)
226
- if rhss isa Union{SparseVector, SparseMatrixCSC}
227
- return SetArray (checkbounds, LiteralExpr (:($ out. nzval)), rhss. nzval)
228
- elseif outputidxs === nothing
257
+
258
+ function set_array (p, args... )
259
+ Base. @warn (" Parallel form of $(typeof (p)) not implemented" )
260
+ _set_array (args... )
261
+ end
262
+
263
+ function set_array (s:: SerialForm , args... )
264
+ _set_array (args... )
265
+ end
266
+
267
+ function set_array (s:: MultithreadedForm , out, outputidxs, rhss, checkbounds, skipzeros)
268
+ if rhss isa AbstractSparseArray
269
+ return set_array (LiteralExpr (:($ out. nzval)),
270
+ nothing ,
271
+ rhss. nzval,
272
+ checkbounds,
273
+ skipzeros)
274
+ end
275
+ if outputidxs === nothing
229
276
outputidxs = collect (eachindex (rhss))
230
277
end
278
+ ntasks = 2 * nthreads () # oversubscribe a little bit
279
+ per_task = ceil (Int, length (rhss) / ntasks)
280
+ slices = collect (Iterators. partition (zip (outputidxs, rhss), per_task))
281
+ arrays = map (slices) do slice
282
+ idxs, vals = first .(slice), last .(slice)
283
+ _set_array (out, idxs, vals, checkbounds, skipzeros)
284
+ end
285
+ Par {Multithreaded} (arrays, @inline f (args... ) = nothing )
286
+ end
231
287
288
+ function _set_array (out, outputidxs, rhss:: AbstractArray , checkbounds, skipzeros)
289
+ if outputidxs === nothing
290
+ outputidxs = collect (eachindex (rhss))
291
+ end
232
292
# sometimes outputidxs is a Tuple
233
293
ii = findall (i-> ! (rhss[i] isa AbstractArray) && ! (skipzeros && _iszero (rhss[i])), eachindex (outputidxs))
234
294
jj = findall (i-> rhss[i] isa AbstractArray, eachindex (outputidxs))
235
295
exprs = []
236
- push! (exprs, SetArray (checkbounds, out, AtIndex .(vec (collect (outputidxs[ii])), vec (rhss[ii]))))
296
+ push! (exprs, SetArray (! checkbounds, out, AtIndex .(vec (collect (outputidxs[ii])), vec (rhss[ii]))))
237
297
for j in jj
238
298
push! (exprs, _set_array (LiteralExpr (:($ out[$ j])), nothing , rhss[j], checkbounds, skipzeros))
239
299
end
245
305
_set_array (out, outputidxs, rhs, checkbounds, skipzeros) = rhs
246
306
247
307
248
- function _make_array (rhss:: AbstractArray , similarto)
249
- arr = map (x-> _make_array (x, similarto), rhss)
250
- # Ugh reshaped array of a sparse array when mapped gives a sparse array
251
- if arr isa AbstractSparseArray
252
- _make_array (arr, similarto)
253
- else
254
- MakeArray (arr, similarto)
255
- end
256
- end
257
-
258
- _make_array (x, similarto) = x
259
-
260
308
function vars_to_pairs (name,vs:: Union{Tuple, AbstractArray} , symsdict= Dict ())
261
309
vs_names = tosymbol .(vs)
262
310
for (v,k) in zip (vs_names, vs)
0 commit comments