diff --git a/src/dtask.jl b/src/dtask.jl index b7477428..8ae79f63 100644 --- a/src/dtask.jl +++ b/src/dtask.jl @@ -1,19 +1,64 @@ export DTask +""" + LocalFuture + +A fast, shared-memory alternative to Distributed's Future. +""" +mutable struct LocalFuture + const ready::Base.Event + errored::Bool + value::Union{Some{Any}, Nothing} + + LocalFuture() = new(Base.Event(), false, nothing) +end +Base.isready(f::LocalFuture) = f.ready.set # FIXME: Use isready(f.ready) +function Base.wait(f::LocalFuture) + wait(f.ready) + return +end + "A future holding the result of a `Thunk`." -struct ThunkFuture - future::Future +mutable struct ThunkFuture + const from::Int + local_future::Union{LocalFuture, Nothing} + remote_future::Union{Future, Nothing} +end +function ThunkFuture(from::Int=myid()) + if from == myid() + return ThunkFuture(from, LocalFuture(), nothing) + else + return ThunkFuture(from, nothing, Future()) + end +end +function Base.isready(t::ThunkFuture) + if t.local_future !== nothing + return isready(t.local_future::LocalFuture) + else + return isready(t.remote_future::Future)::Bool + end end -ThunkFuture(x::Integer) = ThunkFuture(Future(x)) -ThunkFuture() = ThunkFuture(Future()) -Base.isready(t::ThunkFuture) = isready(t.future) Base.wait(t::ThunkFuture) = Dagger.Sch.thunk_yield() do - wait(t.future) + if t.from == myid() + wait(t.local_future) + else + wait(t.remote_future) + end return end function Base.fetch(t::ThunkFuture; proc=OSProc(), raw=false) - error, value = Dagger.Sch.thunk_yield() do - fetch(t.future) + if t.from == myid() + if !isready(t.local_future) + Dagger.Sch.thunk_yield() do + wait(t.local_future) + end + end + value = something(t.local_future.value) + error = t.local_future.errored + else + error, value = Dagger.Sch.thunk_yield() do + fetch(t.remote_future) + end end if error throw(value) @@ -24,7 +69,43 @@ function Base.fetch(t::ThunkFuture; proc=OSProc(), raw=false) return move(proc, value) end end -Base.put!(t::ThunkFuture, x; error=false) = put!(t.future, (error, x)) +function Base.put!(t::ThunkFuture, x; error=false) + if isready(t) + throw(ConcurrencyViolationError("ThunkFuture can't be set twice")) + end + + # Notify either or both futures + if t.local_future !== nothing + t.local_future.value = Some{Any}(x) + t.local_future.errored = error + notify(t.local_future.ready) + end + if t.remote_future !== nothing + put!(t.remote_future, (error, x)) + end + + return x +end +function Serialization.serialize(io::AbstractSerializer, t::ThunkFuture) + if t.remote_future === nothing + # Add a Future + t.remote_future = Future() + end + + # Serialize normally + return invoke(serialize, Tuple{typeof(io), Any}, io, t) +end +function Serialization.deserialize(io::AbstractSerializer, ::Type{ThunkFuture}) + # Deserialize normally + t = invoke(deserialize, Tuple{AbstractSerializer, DataType}, io, ThunkFuture) + + if t.local_future !== nothing + # Remove the (now useless) LocalFuture + t.local_future = nothing + end + + return t +end """ DTaskMetadata @@ -45,9 +126,9 @@ executing. May be `fetch`'d or `wait`'d on at any time. See `Dagger.@spawn` for more details. """ mutable struct DTask - uid::UInt + const uid::UInt future::ThunkFuture - metadata::DTaskMetadata + const metadata::DTaskMetadata thunk_ref::DRef DTask(uid, future, metadata) = new(uid, future, metadata) @@ -55,14 +136,16 @@ end const EagerThunk = DTask -Base.isready(t::DTask) = isready(t.future) +Base.isready(t::DTask) = isready(t.future)::Bool Base.istaskdone(t::DTask) = isready(t.future) Base.istaskstarted(t::DTask) = isdefined(t, :thunk_ref) function Base.wait(t::DTask) if !istaskstarted(t) throw(ConcurrencyViolationError("Cannot `wait` on an unlaunched `DTask`")) end - wait(t.future) + if !isready(t) + wait(t.future) + end return end function Base.fetch(t::DTask; raw=false) diff --git a/src/sch/Sch.jl b/src/sch/Sch.jl index f7ee904d..61ef9804 100644 --- a/src/sch/Sch.jl +++ b/src/sch/Sch.jl @@ -1101,8 +1101,7 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re end task, occupancy = peek(queue) scope = task.scope - if !isa(constrain(scope, Dagger.ExactScope(to_proc)), - InvalidScope) && + if Dagger.proc_in_scope(to_proc, scope) typemax(UInt32) - proc_occupancy_cached >= occupancy # Compatible, steal this task return dequeue_pair!(queue) diff --git a/src/sch/util.jl b/src/sch/util.jl index 1d947c23..9141f9a3 100644 --- a/src/sch/util.jl +++ b/src/sch/util.jl @@ -447,8 +447,7 @@ function can_use_proc(state, task, gproc, proc, opts, scope) end # Check against scope - proc_scope = Dagger.ExactScope(proc) - if constrain(scope, proc_scope) isa Dagger.InvalidScope + if !Dagger.proc_in_scope(proc, scope) @dagdebug task :scope "Rejected $proc: Not contained in task scope ($scope)" return false, scope end diff --git a/src/scopes.jl b/src/scopes.jl index ecb3e5e0..ba291bc2 100644 --- a/src/scopes.jl +++ b/src/scopes.jl @@ -4,8 +4,11 @@ abstract type AbstractScope end "Widest scope that contains all processors." struct AnyScope <: AbstractScope end +proc_in_scope(::Processor, ::AnyScope) = true abstract type AbstractScopeTaint end +proc_in_scope(proc::Processor, scope::AbstractScope) = + !isa(constrain(scope, ExactScope(proc)), InvalidScope) "Taints a scope for later evaluation." struct TaintScope <: AbstractScope @@ -44,6 +47,8 @@ UnionScope(scopes...) = UnionScope((scopes...,)) UnionScope(scopes::Vector{<:AbstractScope}) = UnionScope((scopes...,)) UnionScope(s::AbstractScope) = UnionScope((s,)) UnionScope() = UnionScope(()) +proc_in_scope(proc::Processor, scope::UnionScope) = + any(subscope->proc_in_scope(proc, subscope), scope.scopes) function Base.:(==)(us1::UnionScope, us2::UnionScope) if length(us1.scopes) != length(us2.scopes) @@ -78,6 +83,8 @@ function ProcessScope(wid::Integer) end ProcessScope(p::OSProc) = ProcessScope(p.pid) ProcessScope() = ProcessScope(myid()) +proc_in_scope(proc::Processor, scope::ProcessScope) = + root_worker_id(proc) == scope.wid struct ProcessorTypeTaint{T} <: AbstractScopeTaint end @@ -92,12 +99,14 @@ struct ExactScope <: AbstractScope processor::Processor end ExactScope(proc) = ExactScope(ProcessScope(get_parent(proc).pid), proc) +proc_in_scope(proc::Processor, scope::ExactScope) = proc == scope.processor "Indicates that the applied scopes `x` and `y` are incompatible." struct InvalidScope <: AbstractScope x::AbstractScope y::AbstractScope end +proc_in_scope(::Processor, ::InvalidScope) = false # Show methods