Skip to content

Commit cca7173

Browse files
committed
store ntasks in MultithreadedForm
1 parent f6668a6 commit cca7173

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

src/build_function.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@ struct MATLABTarget <: BuildTargets end
99

1010
abstract type ParallelForm end
1111
struct SerialForm <: ParallelForm end
12-
struct MultithreadedForm <: ParallelForm end
12+
struct MultithreadedForm <: ParallelForm
13+
ntasks::Int
14+
end
15+
MultithreadedForm() = MultithreadedForm(2*nthreads())
1316
struct DistributedForm <: ParallelForm end
1417
struct DaggerForm <: ParallelForm end
1518

@@ -223,8 +226,7 @@ function make_array(s::SerialForm, arr, similarto)
223226
end
224227

225228
function make_array(s::MultithreadedForm, arr, similarto)
226-
ntasks = 2 * nthreads() # oversubscribe a little bit
227-
per_task = ceil(Int, length(arr) / ntasks)
229+
per_task = ceil(Int, length(arr) / s.ntasks)
228230
slices = collect(Iterators.partition(arr, per_task))
229231
arrays = map(slices) do slice
230232
_make_array(slice, similarto)
@@ -275,8 +277,7 @@ function set_array(s::MultithreadedForm, out, outputidxs, rhss, checkbounds, ski
275277
if outputidxs === nothing
276278
outputidxs = collect(eachindex(rhss))
277279
end
278-
ntasks = 2 * nthreads() # oversubscribe a little bit
279-
per_task = ceil(Int, length(rhss) / ntasks)
280+
per_task = ceil(Int, length(rhss) / s.ntasks)
280281
# TODO: do better partitioning when skipzeros is present
281282
slices = collect(Iterators.partition(zip(outputidxs, rhss), per_task))
282283
arrays = map(slices) do slice

0 commit comments

Comments
 (0)