1
1
using SymbolicUtils. Code
2
+ using Base. Threads
2
3
3
4
# Abstract supertype of the code-generation targets (e.g. `JuliaTarget`, `MATLABTarget`).
abstract type BuildTargets end
4
5
# Singleton target used to dispatch `_build_function` to the Julia code generator.
struct JuliaTarget <: BuildTargets end
@@ -8,9 +9,10 @@ struct MATLABTarget <: BuildTargets end
8
9
9
10
# Abstract supertype of the parallelization strategies (`SerialForm`, `MultithreadedForm`).
abstract type ParallelForm end
10
11
# Parallel form selecting serial execution of the generated code.
struct SerialForm <: ParallelForm end
11
- struct MultithreadedForm <: ParallelForm end
12
- struct DistributedForm <: ParallelForm end
13
- struct DaggerForm <: ParallelForm end
12
"""
    MultithreadedForm(ntasks)
    MultithreadedForm()

Parallel form that splits the generated expressions into at most `ntasks` chunks,
each of which is handed to `SpawnFetch{Multithreaded}`. The zero-argument
constructor uses `2 * nthreads()` tasks.

Throws an `ArgumentError` if `ntasks` is not positive.
"""
struct MultithreadedForm <: ParallelForm
    ntasks::Int

    function MultithreadedForm(ntasks)
        n = convert(Int, ntasks)
        # The chunk size is obtained by dividing by `ntasks`, so a non-positive
        # count can never yield a valid partition; reject it up front instead of
        # failing later inside code generation.
        n > 0 || throw(ArgumentError("ntasks must be positive, got $n"))
        return new(n)
    end
end
MultithreadedForm() = MultithreadedForm(2 * nthreads())
14
16
15
17
"""
16
18
`build_function`
@@ -60,18 +62,6 @@ function unflatten_long_ops(op, N=4)
60
62
Rewriters. Fixpoint (Rewriters. Postwalk (Rewriters. Chain ([rule1, rule2])))(op)
61
63
end
62
64
63
- function observed_let (eqs)
64
- process -> ex -> begin
65
- isempty (eqs) && return ex
66
-
67
- assignments = map (eq -> :($ (process (eq. lhs)) = $ (process (eq. rhs))), eqs)
68
- letexpr = :(let $ (assignments... )
69
- end )
70
- # avoid a superfluous `begin ... end` block
71
- letexpr. args[2 ] = ex
72
- return letexpr
73
- end
74
- end
75
65
76
66
# Scalar output
77
67
@@ -140,7 +130,7 @@ Build function target: JuliaTarget
140
130
function _build_function(target::JuliaTarget, rhss, args...;
141
131
conv = toexpr, expression = Val{true},
142
132
checkbounds = false,
143
- linenumbers = false, multithread=nothing,
133
+ linenumbers = false,
144
134
headerfun = addheader, outputidxs=nothing,
145
135
convert_oop = true, force_SA = false,
146
136
skipzeros = outputidxs===nothing,
@@ -168,10 +158,6 @@ Special Keyword Argumnets:
168
158
- `SerialForm()`: Serial execution.
169
159
- `MultithreadedForm()`: Multithreaded execution with a static split, evenly
170
160
splitting the number of expressions per thread.
171
- - `DistributedForm()`: Multiprocessing using Julia's Distributed with a static
172
- schedule, evenly splitting the number of expressions per process.
173
- - `DaggerForm()`: Multithreading and multiprocessing using Julia's Dagger.jl
174
- for dynamic scheduling and load balancing.
175
161
- `conv`: The conversion function of the Operation to Expr. By default this uses
176
162
the `toexpr` function.
177
163
- `checkbounds`: For whether to enable bounds checking inside of the generated
@@ -194,7 +180,7 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
194
180
expression = Val{true },
195
181
expression_module = @__MODULE__ (),
196
182
checkbounds = false ,
197
- linenumbers = false , multithread = nothing ,
183
+ linenumbers = false ,
198
184
outputidxs= nothing ,
199
185
skipzeros = false ,
200
186
wrap_code = (nothing , nothing ),
@@ -204,13 +190,13 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
204
190
dargs = map (destructure_arg, [args... ])
205
191
i = findfirst (x-> x isa DestructuredArgs, dargs)
206
192
similarto = i === nothing ? Array : dargs[i]. name
207
- oop_expr = Func (dargs, [], _make_array ( rhss, similarto))
193
+ oop_expr = Func (dargs, [], make_array (parallel, rhss, similarto))
208
194
if ! isnothing (wrap_code[1 ])
209
195
oop_expr = wrap_code[1 ](oop_expr)
210
196
end
211
197
212
198
out = Sym {Any} (gensym (" out" ))
213
- ip_expr = Func ([out, dargs... ], [], _set_array ( out, outputidxs, rhss, checkbounds, skipzeros))
199
+ ip_expr = Func ([out, dargs... ], [], set_array (parallel, out, outputidxs, rhss, checkbounds, skipzeros))
214
200
215
201
if ! isnothing (wrap_code[2 ])
216
202
ip_expr = wrap_code[2 ](ip_expr)
@@ -224,6 +210,24 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
224
210
end
225
211
end
226
212
213
"""
    make_array(s, arr, similarto)

Build the out-of-place array-construction expression for `arr` according to the
parallel form `s`.

- Unknown parallel forms emit a warning and fall back to the serial `_make_array`.
- `SerialForm` forwards directly to `_make_array`.
- `MultithreadedForm` partitions `arr` into chunks of `cld(length(arr), s.ntasks)`
  elements, builds one array expression per chunk and combines them with `vcat`
  through `SpawnFetch{Multithreaded}`.
"""
function make_array(s, arr, similarto)
    Base.@warn("Parallel form of $(typeof(s)) not implemented")
    return _make_array(arr, similarto)
end

function make_array(s::SerialForm, arr, similarto)
    return _make_array(arr, similarto)
end

function make_array(s::MultithreadedForm, arr, similarto)
    # An empty input would give a chunk length of 0, which `Iterators.partition`
    # rejects; there is nothing to parallelize, so build it serially.
    isempty(arr) && return _make_array(arr, similarto)

    # Integer ceiling division: number of expressions handled by each task.
    per_task = cld(length(arr), s.ntasks)
    slices = collect(Iterators.partition(arr, per_task))
    arrays = map(slice -> _make_array(slice, similarto), slices)
    # NOTE(review): the chunks are concatenated with `vcat`, so a multi-dimensional
    # `arr` presumably comes back flattened — confirm against callers.
    return SpawnFetch{Multithreaded}(arrays, vcat)
end
230
+
227
231
function _make_array (rhss:: AbstractSparseArray , similarto)
228
232
arr = map (x-> _make_array (x, similarto), rhss)
229
233
if ! (arr isa AbstractSparseArray)
@@ -233,19 +237,59 @@ function _make_array(rhss::AbstractSparseArray, similarto)
233
237
end
234
238
end
235
239
240
# Recursively convert every element of `rhss`, then wrap the result in `MakeArray`.
function _make_array(rhss::AbstractArray, similarto)
    converted = map(rhs -> _make_array(rhs, similarto), rhss)
    # Mapping over a reshaped sparse array yields a sparse array again; route that
    # result back through `_make_array` so the sparse method takes over.
    converted isa AbstractSparseArray && return _make_array(converted, similarto)
    return MakeArray(converted, similarto)
end

# Base case of the recursion: anything that is not an array is returned unchanged.
_make_array(x, similarto) = x
251
+
236
252
# # In-place version
237
- function _set_array (out, outputidxs, rhss:: AbstractArray , checkbounds, skipzeros)
238
- if rhss isa Union{SparseVector, SparseMatrixCSC}
239
- return SetArray (checkbounds, LiteralExpr (:($ out. nzval)), rhss. nzval)
240
- elseif outputidxs === nothing
253
+
254
"""
    set_array(p, out, outputidxs, rhss, checkbounds, skipzeros)

Build the in-place assignment expression that writes `rhss` into `out`, according
to the parallel form `p`.

- Unknown parallel forms emit a warning and fall back to the serial `_set_array`.
- `SerialForm` forwards directly to `_set_array`.
- `MultithreadedForm` pairs each output index with its right-hand side, partitions
  the pairs into chunks of `cld(length(rhss), s.ntasks)` elements, builds one
  `_set_array` expression per chunk and runs them through
  `SpawnFetch{Multithreaded}`, discarding the per-chunk results.
"""
function set_array(p, args...)
    Base.@warn("Parallel form of $(typeof(p)) not implemented")
    return _set_array(args...)
end

function set_array(s::SerialForm, args...)
    return _set_array(args...)
end

function set_array(s::MultithreadedForm, out, outputidxs, rhss, checkbounds, skipzeros)
    if rhss isa AbstractSparseArray
        # Only the stored values of a sparse output are written. `s` must be passed
        # along: without it the call dispatches to the generic fallback with the
        # `LiteralExpr` taken as the parallel form, which then forwards only four
        # arguments to `_set_array`.
        return set_array(s,
                         LiteralExpr(:($out.nzval)),
                         nothing,
                         rhss.nzval,
                         checkbounds,
                         skipzeros)
    end

    # An empty input would give a chunk length of 0, which `Iterators.partition`
    # rejects; there is nothing to parallelize, so build it serially.
    isempty(rhss) && return _set_array(out, outputidxs, rhss, checkbounds, skipzeros)

    if outputidxs === nothing
        outputidxs = collect(eachindex(rhss))
    end

    # Integer ceiling division: number of assignments handled by each task.
    per_task = cld(length(rhss), s.ntasks)
    # TODO: do better partitioning when skipzeros is present
    slices = collect(Iterators.partition(zip(outputidxs, rhss), per_task))
    arrays = map(slices) do slice
        idxs, vals = first.(slice), last.(slice)
        _set_array(out, idxs, vals, checkbounds, skipzeros)
    end
    # The chunks write into `out` as a side effect, so the combiner ignores their
    # values and returns `nothing`.
    return SpawnFetch{Multithreaded}(arrays, @inline noop(args...) = nothing)
end
243
283
284
+ function _set_array (out, outputidxs, rhss:: AbstractArray , checkbounds, skipzeros)
285
+ if outputidxs === nothing
286
+ outputidxs = collect (eachindex (rhss))
287
+ end
244
288
# sometimes outputidxs is a Tuple
245
289
ii = findall (i-> ! (rhss[i] isa AbstractArray) && ! (skipzeros && _iszero (rhss[i])), eachindex (outputidxs))
246
290
jj = findall (i-> rhss[i] isa AbstractArray, eachindex (outputidxs))
247
291
exprs = []
248
- push! (exprs, SetArray (checkbounds, out, AtIndex .(vec (collect (outputidxs[ii])), vec (rhss[ii]))))
292
+ push! (exprs, SetArray (! checkbounds, out, AtIndex .(vec (collect (outputidxs[ii])), vec (rhss[ii]))))
249
293
for j in jj
250
294
push! (exprs, _set_array (LiteralExpr (:($ out[$ j])), nothing , rhss[j], checkbounds, skipzeros))
251
295
end
257
301
# Fallback for a right-hand side that did not match the `AbstractArray` method:
# the value is returned unchanged and the remaining arguments are ignored.
_set_array (out, outputidxs, rhs, checkbounds, skipzeros) = rhs
258
302
259
303
260
- function _make_array (rhss:: AbstractArray , similarto)
261
- arr = map (x-> _make_array (x, similarto), rhss)
262
- # Ugh reshaped array of a sparse array when mapped gives a sparse array
263
- if arr isa AbstractSparseArray
264
- _make_array (arr, similarto)
265
- else
266
- MakeArray (arr, similarto)
267
- end
268
- end
269
-
270
- _make_array (x, similarto) = x
271
-
272
304
function vars_to_pairs (name,vs:: Union{Tuple, AbstractArray} , symsdict= Dict ())
273
305
vs_names = tosymbol .(vs)
274
306
for (v,k) in zip (vs_names, vs)
0 commit comments