@@ -9,7 +9,10 @@ struct MATLABTarget <: BuildTargets end
9
9
10
10
# Supertype of the parallelization tag types declared below. `make_array` and
# `set_array` have methods that dispatch on its subtypes.
abstract type ParallelForm end
11
11
# Field-less tag type; `make_array(s::SerialForm, arr, similarto)` dispatches on it.
struct SerialForm <: ParallelForm end
12
- struct MultithreadedForm <: ParallelForm end
12
"""
    MultithreadedForm(ntasks)
    MultithreadedForm()

Parallel form whose `ntasks` field sets how many chunks the work is split into.
The `MultithreadedForm` methods of `make_array` and `set_array` derive the chunk
size as `ceil(Int, length(work) / ntasks)`.

# Arguments
- `ntasks`: number of chunks; converted to `Int` and required to be at least 1.
  The zero-argument constructor uses `2 * nthreads()`.

# Throws
- `ArgumentError` if `ntasks < 1`. Without this check a zero value makes the
  chunk-size division non-finite, and a negative value yields a non-positive
  partition length, both of which only fail later inside the array builders.
"""
struct MultithreadedForm <: ParallelForm
    ntasks::Int

    # Untyped argument + `convert` keeps the same accepted inputs as the
    # implicit default constructor would have had.
    function MultithreadedForm(ntasks)
        n = convert(Int, ntasks)
        n >= 1 || throw(ArgumentError("ntasks must be at least 1, got $n"))
        return new(n)
    end
end

# Default: two tasks per thread, i.e. oversubscribe a little bit.
MultithreadedForm() = MultithreadedForm(2 * nthreads())
13
16
# Field-less tag type. NOTE(review): presumably selects a Distributed-based code
# path; the methods that dispatch on it are outside this excerpt — confirm.
struct DistributedForm <: ParallelForm end
14
17
# Field-less tag type. NOTE(review): presumably selects a Dagger-based code path;
# the methods that dispatch on it are outside this excerpt — confirm.
struct DaggerForm <: ParallelForm end
15
18
@@ -223,8 +226,7 @@ function make_array(s::SerialForm, arr, similarto)
223
226
end
224
227
225
228
function make_array (s:: MultithreadedForm , arr, similarto)
226
- ntasks = 2 * nthreads () # oversubscribe a little bit
227
- per_task = ceil (Int, length (arr) / ntasks)
229
+ per_task = ceil (Int, length (arr) / s. ntasks)
228
230
slices = collect (Iterators. partition (arr, per_task))
229
231
arrays = map (slices) do slice
230
232
_make_array (slice, similarto)
@@ -275,8 +277,7 @@ function set_array(s::MultithreadedForm, out, outputidxs, rhss, checkbounds, ski
275
277
if outputidxs === nothing
276
278
outputidxs = collect (eachindex (rhss))
277
279
end
278
- ntasks = 2 * nthreads () # oversubscribe a little bit
279
- per_task = ceil (Int, length (rhss) / ntasks)
280
+ per_task = ceil (Int, length (rhss) / s. ntasks)
280
281
# TODO : do better partitioning when skipzeros is present
281
282
slices = collect (Iterators. partition (zip (outputidxs, rhss), per_task))
282
283
arrays = map (slices) do slice
0 commit comments