Commit 94e38d7

Add a ParallelFinalReduction option for slow reducing operators.
1 parent 2affc74 commit 94e38d7

File tree: 4 files changed (+121, -32 lines changed)
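For orientation, here is a minimal sketch of how the new option might be used from the call site. It relies only on the `final_reduction_mode` keyword and the `:parallel`/`ParallelFinalReduction` values introduced in this diff; `slow_merge` and `data` are made-up placeholders standing in for an expensive reducing operator and its input.

```julia
using OhMyThreads

# Hypothetical expensive reducing operator: combining two per-chunk results takes
# much longer than spawning/fetching a task (typically tens of microseconds).
slow_merge(a, b) = (sleep(0.001); a .+ b)

data = [rand(1_000) for _ in 1:100]

# Default behaviour: per-task results are combined one after another (SerialFinalReduction).
res_serial = tmapreduce(identity, slow_merge, data)

# New in this commit: combine per-task results with a parallel tree reduction.
sched = DynamicScheduler(; final_reduction_mode = :parallel)  # or ParallelFinalReduction()
res_parallel = tmapreduce(identity, slow_merge, data; scheduler = sched)
```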

src/OhMyThreads.jl

Lines changed: 3 additions & 1 deletion

@@ -25,12 +25,14 @@ include("macros.jl")
 include("tools.jl")
 include("schedulers.jl")
 using .Schedulers: Scheduler, DynamicScheduler, StaticScheduler, GreedyScheduler,
-                   SerialScheduler
+                   SerialScheduler, FinalReductionMode, SerialFinalReduction,
+                   ParallelFinalReduction
 include("implementation.jl")
 include("experimental.jl")

 export @tasks, @set, @local, @one_by_one, @only_one, @allow_boxed_captures, @disallow_boxed_captures, @localize
 export treduce, tmapreduce, treducemap, tmap, tmap!, tforeach, tcollect
 export Scheduler, DynamicScheduler, StaticScheduler, GreedyScheduler, SerialScheduler
+export FinalReductionMode, SerialFinalReduction, ParallelFinalReduction

 end # module OhMyThreads

src/implementation.jl

Lines changed: 36 additions & 11 deletions

@@ -10,7 +10,9 @@ using OhMyThreads.Schedulers: chunking_enabled,
     nchunks, chunksize, chunksplit, minchunksize, has_chunksplit,
     chunking_mode, ChunkingMode, NoChunking,
     FixedSize, FixedCount, scheduler_from_symbol, NotGiven,
-    isgiven
+    isgiven,
+    FinalReductionMode,
+    SerialFinalReduction, ParallelFinalReduction
 using Base: @propagate_inbounds
 using Base.Threads: nthreads, @threads
 using BangBang: append!!
@@ -86,7 +88,6 @@ function has_multiple_chunks(scheduler, coll)
     end
 end

-
 function tmapreduce(f, op, Arrs...;
         scheduler::MaybeScheduler = NotGiven(),
         outputtype::Type = Any,
@@ -115,6 +116,35 @@ end
 treducemap(op, f, A...; kwargs...) = tmapreduce(f, op, A...; kwargs...)


+function tree_mapreduce(f, op, v)
+    if length(v) == 1
+        f(only(v))
+    elseif length(v) == 2
+        op(f(v[1]), f(v[2]))
+    else
+        l, r = v[begin:(end-begin)÷2], v[((end-begin)÷2+1):end]
+        task_r = @spawn tree_mapreduce(f, op, r)
+        result_l = tree_mapreduce(f, op, l)
+        op(result_l, fetch(task_r))
+    end
+end
+
+function final_mapreduce(op, tasks, ::SerialFinalReduction; mapreduce_kwargs...)
+    # Note, calling `promise_task_local` here is only safe because we're assuming that
+    # Base.mapreduce isn't going to magically try to do multithreading on us...
+    mapreduce(fetch, promise_task_local(op), tasks; mapreduce_kwargs...)
+end
+function final_mapreduce(op, tasks, ::ParallelFinalReduction; mapreduce_kwargs...)
+    if isempty(tasks)
+        # Note, calling `promise_task_local` here is only safe because we're assuming that
+        # Base.mapreduce isn't going to magically try to do multithreading on us...
+        mapreduce(fetch, promise_task_local(op), tasks; mapreduce_kwargs...)
+    else
+        tree_mapreduce(fetch, op, tasks; mapreduce_kwargs...)
+    end
+end
+
+
 # DynamicScheduler: AbstractArray/Generic
 function _tmapreduce(f,
         op,
@@ -134,14 +164,13 @@ function _tmapreduce(f,
             @spawn threadpool mapreduce(promise_task_local(f), promise_task_local(op),
                 args...; $mapreduce_kwargs...)
         end
-        mapreduce(fetch, promise_task_local(op), tasks)
     else
         tasks = map(eachindex(first(Arrs))) do i
             args = map(A -> @inbounds(A[i]), Arrs)
             @spawn threadpool promise_task_local(f)(args...)
         end
-        mapreduce(fetch, promise_task_local(op), tasks; mapreduce_kwargs...)
     end
+    final_mapreduce(op, tasks, FinalReductionMode(scheduler); mapreduce_kwargs...)
 end

 # DynamicScheduler: AbstractChunks
@@ -156,7 +185,7 @@ function _tmapreduce(f,
     tasks = map(only(Arrs)) do idcs
         @spawn threadpool promise_task_local(f)(idcs)
     end
-    mapreduce(fetch, promise_task_local(op), tasks; mapreduce_kwargs...)
+    final_mapreduce(op, tasks, FinalReductionMode(scheduler); mapreduce_kwargs...)
 end

 # StaticScheduler: AbstractArray/Generic
@@ -284,9 +313,7 @@ function _tmapreduce(f,
             true
         end
     end
-    # Note, calling `promise_task_local` here is only safe because we're assuming that
-    # Base.mapreduce isn't going to magically try to do multithreading on us...
-    mapreduce(fetch, promise_task_local(op), filtered_tasks; mapreduce_kwargs...)
+    final_mapreduce(op, filtered_tasks, FinalReductionMode(scheduler); mapreduce_kwargs...)
 end

 # GreedyScheduler w/ chunking
@@ -332,9 +359,7 @@ function _tmapreduce(f,
             true
         end
     end
-    # Note, calling `promise_task_local` here is only safe because we're assuming that
-    # Base.mapreduce isn't going to magically try to do multithreading on us...
-    mapreduce(fetch, promise_task_local(op), filtered_tasks; mapreduce_kwargs...)
+    final_mapreduce(op, filtered_tasks, FinalReductionMode(scheduler); mapreduce_kwargs...)
 end

 function check_all_have_same_indices(Arrs)
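The key idea behind the new `tree_mapreduce` above is that, instead of folding the task results left to right, the vector of results is halved recursively so that independent `op` calls can run concurrently, shortening the critical path from roughly `n - 1` sequential `op` calls to about `log2(n)`. The snippet below is a standalone illustration of that pattern, not part of the diff; `tree_reduce_demo` and `slow_op` are made-up names used only for this example.

```julia
using Base.Threads: @spawn

# Divide-and-conquer reduction: the right half is reduced on a spawned task while
# the current task reduces the left half, so slow `op` calls overlap in time.
function tree_reduce_demo(op, v)
    if length(v) == 1
        return only(v)
    end
    mid = (length(v) + 1) ÷ 2
    right = @spawn tree_reduce_demo(op, view(v, (mid + 1):lastindex(v)))
    left = tree_reduce_demo(op, view(v, firstindex(v):mid))
    return op(left, fetch(right))
end

slow_op(a, b) = (sleep(0.01); a + b)      # stand-in for an expensive reducing operator
tree_reduce_demo(slow_op, collect(1:16))  # == 136; critical path is ~4 `slow_op` calls, not 15
```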

src/schedulers.jl

Lines changed: 74 additions & 15 deletions

@@ -170,6 +170,57 @@ kind of scheduler.
 function default_nchunks end
 default_nchunks(::Type{<:Scheduler}) = nthreads(:default)

+
+
+tree_str = raw"""
+```
+t1  t2    t3  t4    t5  t6
+  \  |     |  /      |  /
+   \ |     | /       | /
+    op     op        op
+      \    /         /
+       \  /         /
+        op          /
+          \        /
+           op
+```
+"""
+
+"""
+    FinalReductionMode
+
+A trait type to decide how the final reduction is performed. Essentially,
+OhMyThreads.jl will turn a `tmapreduce(f, op, v)` call into something of
+the form
+```julia
+tasks = map(chunks(v; chunking_kwargs...)) do chunk
+    @spawn mapreduce(f, op, chunk)
+end
+final_reduction(op, tasks, ReductionMode)
+```
+where the options for `ReductionMode` are currently
+
+* `SerialFinalReduction` is the default option that should be preferred whenever `op` is not the bottleneck in your reduction. In this mode, we use a simple `mapreduce` over the tasks vector, fetching each one, i.e.
+```julia
+function final_reduction(op, tasks, ::SerialFinalReduction)
+    mapreduce(fetch, op, tasks)
+end
+```
+
+* `ParallelFinalReduction` should be opted into when `op` takes a long time relative to the time it takes to `@spawn` and `fetch` tasks (typically tens of microseconds). In this mode, the vector of tasks is split up and `op` is applied in parallel using a recursive tree-based approach.
+$tree_str
+"""
+abstract type FinalReductionMode end
+struct SerialFinalReduction <: FinalReductionMode end
+struct ParallelFinalReduction <: FinalReductionMode end
+
+FinalReductionMode(s::Scheduler) = s.final_reduction_mode
+
+FinalReductionMode(s::Symbol) = FinalReductionMode(Val(s))
+FinalReductionMode(::Val{:serial}) = SerialFinalReduction()
+FinalReductionMode(::Val{:parallel}) = ParallelFinalReduction()
+FinalReductionMode(m::FinalReductionMode) = m
+
 """
     DynamicScheduler (aka :dynamic)

@@ -202,16 +253,17 @@ with other multithreaded code.
 - `threadpool::Symbol` (default `:default`):
     * Possible options are `:default` and `:interactive`.
     * The high-priority pool `:interactive` should be used very carefully since tasks on this threadpool should not be allowed to run for a long time without `yield`ing as it can interfere with [heartbeat](https://en.wikipedia.org/wiki/Heartbeat_(computing)) processes.
+- `final_reduction_mode` (default `SerialFinalReduction`). Switch this to `ParallelFinalReduction` or `:parallel` if your reducing operator `op` is significantly slower than the time to `@spawn` and `fetch` tasks (typically tens of microseconds).
 """
-struct DynamicScheduler{C <: ChunkingMode, S <: Split} <: Scheduler
+struct DynamicScheduler{C <: ChunkingMode, S <: Split, FRM <: FinalReductionMode} <: Scheduler
     threadpool::Symbol
     chunking_args::ChunkingArgs{C, S}
-
-    function DynamicScheduler(threadpool::Symbol, ca::ChunkingArgs)
+    final_reduction_mode::FRM
+    function DynamicScheduler(threadpool::Symbol, ca::ChunkingArgs, frm=SerialFinalReduction())
         if !(threadpool in (:default, :interactive))
             throw(ArgumentError("threadpool must be either :default or :interactive"))
         end
-        new{chunking_mode(ca), typeof(ca.split)}(threadpool, ca)
+        new{chunking_mode(ca), typeof(ca.split), typeof(frm)}(threadpool, ca, frm)
     end
 end

@@ -222,15 +274,17 @@ function DynamicScheduler(;
         chunksize::MaybeInteger = NotGiven(),
         chunking::Bool = true,
         split::Union{Split, Symbol} = Consecutive(),
-        minchunksize::Union{Nothing, Int}=nothing)
+        minchunksize::Union{Nothing, Int}=nothing,
+        final_reduction_mode::Union{Symbol, FinalReductionMode}=SerialFinalReduction())
     if isgiven(ntasks)
         if isgiven(nchunks)
             throw(ArgumentError("For the dynamic scheduler, nchunks and ntasks are aliases and only one may be provided"))
         end
        nchunks = ntasks
     end
     ca = ChunkingArgs(DynamicScheduler, nchunks, chunksize, split; chunking, minsize=minchunksize)
-    return DynamicScheduler(threadpool, ca)
+    frm = FinalReductionMode(final_reduction_mode)
+    return DynamicScheduler(threadpool, ca, frm)
 end
 from_symbol(::Val{:dynamic}) = DynamicScheduler
 chunking_args(sched::DynamicScheduler) = sched.chunking_args
@@ -239,7 +293,8 @@ function Base.show(io::IO, mime::MIME{Symbol("text/plain")}, s::DynamicScheduler
     print(io, "DynamicScheduler", "\n")
     cstr = _chunkingstr(s.chunking_args)
     println(io, "├ Chunking: ", cstr)
-    print(io, "└ Threadpool: ", s.threadpool)
+    println(io, "├ Threadpool: ", s.threadpool)
+    print(io, "└ FinalReductionMode: ", FinalReductionMode(s))
 end

 """
@@ -336,14 +391,16 @@ some additional overhead.
 - `split::Union{Symbol, OhMyThreads.Split}` (default `OhMyThreads.RoundRobin()`):
     * Determines how the collection is divided into chunks (if chunking=true).
     * See [ChunkSplitters.jl](https://github.com/JuliaFolds2/ChunkSplitters.jl) for more details and available options. We also allow users to pass `:consecutive` in place of `Consecutive()`, and `:roundrobin` in place of `RoundRobin()`
+- `final_reduction_mode` (default `SerialFinalReduction`). Switch this to `ParallelFinalReduction` or `:parallel` if your reducing operator `op` is significantly slower than the time to `@spawn` and `fetch` tasks (typically tens of microseconds).
 """
-struct GreedyScheduler{C <: ChunkingMode, S <: Split} <: Scheduler
+struct GreedyScheduler{C <: ChunkingMode, S <: Split, FRM <: FinalReductionMode} <: Scheduler
     ntasks::Int
     chunking_args::ChunkingArgs{C, S}
+    final_reduction_mode::FRM

-    function GreedyScheduler(ntasks::Integer, ca::ChunkingArgs)
+    function GreedyScheduler(ntasks::Integer, ca::ChunkingArgs, frm=SerialFinalReduction())
         ntasks > 0 || throw(ArgumentError("ntasks must be a positive integer"))
-        return new{chunking_mode(ca), typeof(ca.split)}(ntasks, ca)
+        return new{chunking_mode(ca), typeof(ca.split), typeof(frm)}(ntasks, ca, frm)
     end
 end

@@ -353,12 +410,14 @@ function GreedyScheduler(;
         chunksize::MaybeInteger = NotGiven(),
         chunking::Bool = false,
         split::Union{Split, Symbol} = RoundRobin(),
-        minchunksize::Union{Nothing, Int} = nothing)
+        minchunksize::Union{Nothing, Int} = nothing,
+        final_reduction_mode::Union{Symbol,FinalReductionMode} = SerialFinalReduction())
     if isgiven(nchunks) || isgiven(chunksize)
         chunking = true
     end
     ca = ChunkingArgs(GreedyScheduler, nchunks, chunksize, split; chunking, minsize=minchunksize)
-    return GreedyScheduler(ntasks, ca)
+    frm = FinalReductionMode(final_reduction_mode)
+    return GreedyScheduler(ntasks, ca, frm)
 end
 from_symbol(::Val{:greedy}) = GreedyScheduler
 chunking_args(sched::GreedyScheduler) = sched.chunking_args
@@ -367,9 +426,9 @@ default_nchunks(::Type{GreedyScheduler}) = 10 * nthreads(:default)
 function Base.show(io::IO, mime::MIME{Symbol("text/plain")}, s::GreedyScheduler)
     print(io, "GreedyScheduler", "\n")
     println(io, "├ Num. tasks: ", s.ntasks)
-    cstr = _chunkingstr(s)
-    println(io, "├ Chunking: ", cstr)
-    print(io, "└ Threadpool: default")
+    println(io, "├ Chunking: ", _chunkingstr(s))
+    println(io, "├ Threadpool: default")
+    print(io, "└ FinalReductionMode: ", FinalReductionMode(s))
 end

 """

test/runtests.jl

Lines changed: 8 additions & 5 deletions

@@ -740,14 +740,16 @@ end
           """
           DynamicScheduler
           ├ Chunking: fixed count ($nt), split :consecutive
-          └ Threadpool: default"""
+          ├ Threadpool: default
+          └ FinalReductionMode: SerialFinalReduction()"""

     @test repr(
-        "text/plain", DynamicScheduler(; chunking = false, threadpool = :interactive)) ==
+        "text/plain", DynamicScheduler(; chunking = false, threadpool = :interactive, final_reduction_mode=:parallel)) ==
           """
           DynamicScheduler
           ├ Chunking: none
-          └ Threadpool: interactive"""
+          ├ Threadpool: interactive
+          └ FinalReductionMode: ParallelFinalReduction()"""

     @test repr("text/plain", StaticScheduler()) ==
           """StaticScheduler
@@ -764,8 +766,9 @@ end
           """
           GreedyScheduler
           ├ Num. tasks: $nt
-          ├ Chunking: fixed count ($(10 * nt)), split :roundrobin
-          └ Threadpool: default"""
+          ├ Chunking: fixed count ($(10*nt)), split :roundrobin
+          ├ Threadpool: default
+          └ FinalReductionMode: SerialFinalReduction()"""
 end

 @testset "Boxing detection and error" begin
