Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 37 additions & 12 deletions src/implementation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ using OhMyThreads.Schedulers: chunksplitter_mode, chunking_enabled,
has_minchunksize, chunkingargs_to_kwargs,
chunking_mode, ChunkingMode, NoChunking,
FixedSize, FixedCount, scheduler_from_symbol, NotGiven,
isgiven, threadpool as get_threadpool
isgiven, threadpool as get_threadpool, allow_migration
using Base: @propagate_inbounds
using Base.Threads: nthreads, @threads
using BangBang: append!!
Expand Down Expand Up @@ -255,11 +255,23 @@ function _tmapreduce(f,
put!(ch, args)
end
end
tasks = map(1:ntasks) do _
# Note, calling `promise_task_local` here is only safe because we're assuming that
# Base.mapreduce isn't going to magically try to do multithreading on us...
@spawn mapreduce(promise_task_local(op), ch; mapreduce_kwargs...) do args
promise_task_local(f)(args...)
if allow_migration(scheduler)
tasks = map(1:ntasks) do _
# Note, calling `promise_task_local` here is only safe because we're assuming that
# Base.mapreduce isn't going to magically try to do multithreading on us...
@spawn mapreduce(promise_task_local(op), ch; mapreduce_kwargs...) do args
promise_task_local(f)(args...)
end
end
else
nt = nthreads()
tasks = map(1:ntasks) do c
tid = @inbounds nthtid(mod1(c, nt))
# Note, calling `promise_task_local` here is only safe because we're assuming that
# Base.mapreduce isn't going to magically try to do multithreading on us...
@spawnat tid mapreduce(promise_task_local(op), ch; mapreduce_kwargs...) do args
promise_task_local(f)(args...)
end
end
end
# Doing this because of https://github.com/JuliaFolds2/OhMyThreads.jl/issues/82
Expand Down Expand Up @@ -302,12 +314,25 @@ function _tmapreduce(f,
# ChunkSplitters.IndexChunks support everything needed for ChannelLike
ch = ChannelLike(chnks)

tasks = map(1:ntasks) do _
# Note, calling `promise_task_local` here is only safe because we're assuming that
# Base.mapreduce isn't going to magically try to do multithreading on us...
@spawn mapreduce(promise_task_local(op), ch; mapreduce_kwargs...) do inds
args = map(A -> view(A, inds), Arrs)
mapreduce(promise_task_local(f), promise_task_local(op), args...)
if allow_migration(scheduler)
tasks = map(1:ntasks) do _
# Note, calling `promise_task_local` here is only safe because we're assuming that
# Base.mapreduce isn't going to magically try to do multithreading on us...
@spawn mapreduce(promise_task_local(op), ch; mapreduce_kwargs...) do inds
args = map(A -> view(A, inds), Arrs)
mapreduce(promise_task_local(f), promise_task_local(op), args...)
end
end
else
nt = nthreads()
tasks = map(1:ntasks) do c
tid = @inbounds nthtid(mod1(c, nt))
# Note, calling `promise_task_local` here is only safe because we're assuming that
# Base.mapreduce isn't going to magically try to do multithreading on us...
@spawnat tid mapreduce(promise_task_local(op), ch; mapreduce_kwargs...) do inds
args = map(A -> view(A, inds), Arrs)
mapreduce(promise_task_local(f), promise_task_local(op), args...)
end
end
end
# Doing this because of https://github.com/JuliaFolds2/OhMyThreads.jl/issues/82
Expand Down
26 changes: 18 additions & 8 deletions src/schedulers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,8 @@ kind of scheduler.
function default_nchunks end
# Fallback: by default use one chunk per thread in the :default pool.
default_nchunks(::Type{<:Scheduler}) = nthreads(:default)

# Whether tasks spawned by this scheduler are allowed to migrate between
# threads. Conservative fallback is `false`; schedulers that spawn ordinary
# non-sticky tasks opt in by overloading this (e.g. `DynamicScheduler`).
allow_migration(::Scheduler) = false

"""
DynamicScheduler (aka :dynamic)

Expand Down Expand Up @@ -216,7 +218,7 @@ function DynamicScheduler(;
chunksize = nothing,
split::Union{Split, Symbol} = Consecutive(),
minchunksize = nothing,
chunking::Bool = true
chunking::Bool = true,
)
if !isnothing(ntasks)
if !isnothing(nchunks)
Expand All @@ -231,6 +233,7 @@ end
# Map the `:dynamic` scheduler symbol to its concrete scheduler type.
from_symbol(::Val{:dynamic}) = DynamicScheduler
# Accessor for the chunking configuration carried by the scheduler.
chunking_args(sched::DynamicScheduler) = sched.chunking_args
# The threadpool is encoded in the third type parameter `T` (a Symbol, e.g. `:default`).
threadpool(::DynamicScheduler{C, S, T}) where {C, S, T} = T
# Dynamic tasks are non-sticky, so they may migrate between threads.
allow_migration(::DynamicScheduler) = true

function Base.show(io::IO, mime::MIME{Symbol("text/plain")}, s::DynamicScheduler)
print(io, "DynamicScheduler", "\n")
Expand Down Expand Up @@ -305,9 +308,9 @@ end
"""
GreedyScheduler (aka :greedy)

A greedy dynamic scheduler. The elements are put into a shared workqueue and dynamic,
non-sticky, tasks are spawned to process the elements of the queue with each task taking a new
element from the queue as soon as the previous one is done.
A greedy dynamic scheduler. The elements are put into a shared workqueue and tasks (by
default dynamic, non-sticky) are spawned to process the elements of the queue with each
task taking a new element from the queue as soon as the previous one is done.

Note that elements are processed in a non-deterministic order, and thus a potential reducing
function **must** be [commutative](https://en.wikipedia.org/wiki/Commutative_property) in
Expand Down Expand Up @@ -335,14 +338,18 @@ some additional overhead.
- `split::Union{Symbol, OhMyThreads.Split}` (default `OhMyThreads.RoundRobin()`):
* Determines how the collection is divided into chunks (if chunking=true).
* See [ChunkSplitters.jl](https://github.com/JuliaFolds2/ChunkSplitters.jl) for more details and available options. We also allow users to pass `:consecutive` in place of `Consecutive()`, and `:roundrobin` in place of `RoundRobin()`
- `migration::Bool` (default `true`):
* Controls whether tasks are allowed to migrate between threads (`true`) or not (`false`).
* For `migration=false`, each task is executed by a single, fixed thread for its entire lifetime. Note, however, that the operating system may still move that thread between CPU cores; to pin threads to cores, see the `ThreadPinning.jl` package.
"""
struct GreedyScheduler{C <: ChunkingMode, S <: Split} <: Scheduler
    # Number of worker tasks spawned to drain the shared work queue.
    ntasks::Int
    # Chunking configuration; mode `C` and split strategy `S` live in the type.
    chunking_args::ChunkingArgs{C, S}
    # Whether the spawned tasks are allowed to migrate between threads.
    migration::Bool

    function GreedyScheduler(ntasks::Integer, ca::ChunkingArgs, migration::Bool)
        if ntasks <= 0
            throw(ArgumentError("ntasks must be a positive integer"))
        end
        cmode = chunking_mode(ca)
        stype = typeof(ca.split)
        return new{cmode, stype}(ntasks, ca, migration)
    end
end

Expand All @@ -352,24 +359,27 @@ function GreedyScheduler(;
chunksize = nothing,
minchunksize = nothing,
split::Union{Split, Symbol} = RoundRobin(),
chunking::Bool = false
chunking::Bool = false,
migration::Bool = true
)
if !(isnothing(nchunks) && isnothing(chunksize))
chunking = true
end
ca = ChunkingArgs(GreedyScheduler;
n = nchunks, size = chunksize, minsize = minchunksize, split, chunking)
return GreedyScheduler(ntasks, ca)
return GreedyScheduler(ntasks, ca, migration)
end
# Map the `:greedy` scheduler symbol to its concrete scheduler type.
from_symbol(::Val{:greedy}) = GreedyScheduler
# Accessor for the chunking configuration carried by the scheduler.
chunking_args(sched::GreedyScheduler) = sched.chunking_args
# Default to many more chunks than threads so the shared work queue can
# load-balance — presumably a heuristic oversubscription factor; see the
# GreedyScheduler docstring.
default_nchunks(::Type{GreedyScheduler}) = 10 * nthreads(:default)
# Migration is a runtime field here (user-configurable), not a type-level property.
allow_migration(sched::GreedyScheduler) = sched.migration

function Base.show(io::IO, mime::MIME{Symbol("text/plain")}, s::GreedyScheduler)
    # Multi-line, tree-style REPL summary of the scheduler configuration.
    println(io, "GreedyScheduler")
    println(io, "├ Num. tasks: ", s.ntasks)
    println(io, "├ Chunking: ", _chunkingstr(s))
    println(io, "├ Task migration: ", allow_migration(s))
    # Last line intentionally has no trailing newline.
    print(io, "└ Threadpool: default")
end

Expand Down
17 changes: 14 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,16 @@ ChunkedGreedy(; kwargs...) = GreedyScheduler(; kwargs...)
for (; ~, f, op, itrs, init) in sets_to_test
@testset "f=$f, op=$op, itrs::$(typeof(itrs))" begin
@testset for sched in (
StaticScheduler, DynamicScheduler, GreedyScheduler,
StaticScheduler, DynamicScheduler, GreedyScheduler, :GreedySchedulerNoMig,
DynamicScheduler{OhMyThreads.Schedulers.NoChunking},
SerialScheduler, ChunkedGreedy)
@testset for split in (Consecutive(), RoundRobin(), :consecutive, :roundrobin)
for nchunks in (1, 2, 6)
for minchunksize ∈ (nothing, 1, 3)
if sched == GreedyScheduler
scheduler = sched(; ntasks = nchunks, minchunksize)
elseif sched == :GreedySchedulerNoMig
scheduler = GreedyScheduler(; ntasks = nchunks, minchunksize, migration = false)
elseif sched == DynamicScheduler{OhMyThreads.Schedulers.NoChunking}
scheduler = DynamicScheduler(; chunking = false)
elseif sched == SerialScheduler
Expand All @@ -40,7 +42,7 @@ ChunkedGreedy(; kwargs...) = GreedyScheduler(; kwargs...)
end
kwargs = (; scheduler)
if (split in (RoundRobin(), :roundrobin) ||
sched ∈ (GreedyScheduler, ChunkedGreedy)) || op ∉ (vcat, *)
sched ∈ (GreedyScheduler, :GreedySchedulerNoMig, ChunkedGreedy)) || op ∉ (vcat, *)
# scatter and greedy only works for commutative operators!
else
mapreduce_f_op_itr = mapreduce(f, op, itrs...)
Expand All @@ -61,7 +63,7 @@ ChunkedGreedy(; kwargs...) = GreedyScheduler(; kwargs...)
@test tcollect(RT, (f(x...) for x in collect(zip(itrs...))); kwargs...) ~ map_f_itr
@test tcollect(RT, f.(itrs...); kwargs...) ~ map_f_itr

if sched ∉ (GreedyScheduler, ChunkedGreedy)
if sched ∉ (GreedyScheduler, :GreedySchedulerNoMig, ChunkedGreedy)
@test tmap(f, itrs...; kwargs...) ~ map_f_itr
@test tcollect((f(x...) for x in collect(zip(itrs...))); kwargs...) ~ map_f_itr
@test tcollect(f.(itrs...); kwargs...) ~ map_f_itr
Expand Down Expand Up @@ -760,6 +762,15 @@ end
GreedyScheduler
├ Num. tasks: $nt
├ Chunking: fixed count ($(10 * nt)), split :roundrobin
├ Task migration: true
└ Threadpool: default"""

@test repr("text/plain", GreedyScheduler(; chunking = true, migration = false)) ==
"""
GreedyScheduler
├ Num. tasks: $nt
├ Chunking: fixed count ($(10 * nt)), split :roundrobin
├ Task migration: false
└ Threadpool: default"""
end

Expand Down
Loading