diff --git a/docs/src/api-dagger/types.md b/docs/src/api-dagger/types.md index a36ee3a0f..4e6d0460e 100644 --- a/docs/src/api-dagger/types.md +++ b/docs/src/api-dagger/types.md @@ -16,7 +16,6 @@ DTask ## Task Options Types ```@docs Options -Sch.ThunkOptions Sch.SchedulerOptions ``` diff --git a/docs/src/checkpointing.md b/docs/src/checkpointing.md index 97dd8e8ed..fa798185d 100644 --- a/docs/src/checkpointing.md +++ b/docs/src/checkpointing.md @@ -71,7 +71,7 @@ z = collect(Z) ``` Two changes were made: first, we `enumerate(X.chunks)` so that we can get a -unique index to identify each `chunk`; second, we specify a `ThunkOptions` to +unique index to identify each `chunk`; second, we specify options to `delayed` with a `checkpoint` and `restore` function that is specialized to write or read the given chunk to or from a file on disk, respectively. Notice the usage of `collect` in the `checkpoint` function, and the use of diff --git a/docs/src/task-spawning.md b/docs/src/task-spawning.md index 53d7e28bc..d44a27ad8 100644 --- a/docs/src/task-spawning.md +++ b/docs/src/task-spawning.md @@ -12,9 +12,9 @@ or `spawn` if it's more convenient: `Dagger.spawn(f, Dagger.Options(options), args...; kwargs...)` -When called, it creates an [`DTask`](@ref) (also known as a "thunk" or -"task") object representing a call to function `f` with the arguments `args` and -keyword arguments `kwargs`. If it is called with other thunks as args/kwargs, +When called, it creates an [`DTask`](@ref) (also known as a "task" or +"thunk") object representing a call to function `f` with the arguments `args` and +keyword arguments `kwargs`. If it is called with other tasks as args/kwargs, such as in `Dagger.@spawn f(Dagger.@spawn g())`, then, in this example, the function `f` gets passed the results of executing `g()`, once that result is available. If `g()` isn't yet finished executing, then the execution of `f` @@ -29,9 +29,17 @@ it'll be passed as-is to the function `f` (with some exceptions). !!! note "Task / thread occupancy" By default, `Dagger` assumes that tasks saturate the thread they are running on and does not try to schedule other tasks on the thread. - This default can be controlled by specifying [`Sch.ThunkOptions`](@ref) (more details can be found under [Scheduler and Thunk options](@ref)). + This default can be controlled by specifying [`Options`](@ref) (more details can be found under [Task and Scheduler options](@ref)). The section [Changing the thread occupancy](@ref) shows a runnable example of how to achieve this. +## Options + +The [`Options`](@ref Dagger.Options) struct in the second argument position is +optional; if provided, it is passed to the scheduler to control its +behavior. [`Options`](@ref Dagger.Options) contains option +key-value pairs, which can be any field in [`Options`](@ref) +(see [Task and Scheduler options](@ref)). + ## Simple example Let's see a very simple directed acyclic graph (or DAG) constructed with Dagger: @@ -51,7 +59,7 @@ s = Dagger.@spawn combine(p, q, r) @assert fetch(s) == 16 ``` -The thunks `p`, `q`, `r`, and `s` have the following structure: +The tasks `p`, `q`, `r`, and `s` have the following structure: ![graph](https://user-images.githubusercontent.com/25916/26920104-7b9b5fa4-4c55-11e7-97fb-fe5b9e73cae6.png) @@ -108,7 +116,8 @@ x::DTask @assert fetch(x) == 3 # fetch the result of `@spawn` ``` -This is useful for nested execution, where an `@spawn`'d thunk calls `@spawn`. This is detailed further in [Dynamic Scheduler Control](@ref). 
+This is useful for nested execution, where an `@spawn`'d task calls `@spawn`. +This is detailed further in [Dynamic Scheduler Control](@ref). ## Options @@ -116,7 +125,7 @@ The [`Options`](@ref Dagger.Options) struct in the second argument position is optional; if provided, it is passed to the scheduler to control its behavior. [`Options`](@ref Dagger.Options) contains a `NamedTuple` of option key-value pairs, which can be any of: -- Any field in [`Sch.ThunkOptions`](@ref) (see [Scheduler and Thunk options](@ref)) +- Any field in [`Options`](@ref) (see [Task and Scheduler options](@ref)) - `meta::Bool` -- Pass the input [`Chunk`](@ref) objects themselves to `f` and not the value contained in them. @@ -127,19 +136,19 @@ There are also some extra options that can be passed, although they're considere ## Errors -If a thunk errors while running under the eager scheduler, it will be marked as -having failed, all dependent (downstream) thunks will be marked as failed, and -any future thunks that use a failed thunk as input will fail. Failure can be +If a task errors while running under the eager scheduler, it will be marked as +having failed, all dependent (downstream) tasks will be marked as failed, and +any future tasks that use a failed task as input will fail. Failure can be determined with `fetch`, which will re-throw the error that the -originally-failing thunk threw. `wait` and `isready` will *not* check whether a -thunk or its upstream failed; they only check if the thunk has completed, error +originally-failing task threw. `wait` and `isready` will *not* check whether a +task or its upstream failed; they only check if the task has completed, error or not. This failure behavior is not the default for lazy scheduling ([Lazy API](@ref)), -but can be enabled by setting the scheduler/thunk option ([Scheduler and Thunk options](@ref)) +but can be enabled by setting the scheduler/task option ([Task and Scheduler options](@ref)) `allow_error` to `true`. However, this option isn't terribly useful for -non-dynamic usecases, since any thunk failure will propagate down to the output -thunk regardless of where it occurs. +non-dynamic usecases, since any task failure will propagate down to the output +task regardless of where it occurs. ## Cancellation @@ -198,7 +207,7 @@ end ``` Alternatively, if you want to compute but not fetch the result of a lazy -operation, you can call `compute` on the thunk. This will return a `Chunk` +operation, you can call `compute` on the task. This will return a `Chunk` object which references the result (see [Chunks](@ref) for more details): ```julia @@ -215,16 +224,14 @@ Note that, as a legacy API, usage of the lazy API is generally discouraged for m - Distinct schedulers don't share runtime metrics or learned parameters, thus causing the scheduler to act less intelligently - Distinct schedulers can't share work or data directly -## Scheduler and Thunk options +## Task and Scheduler options While Dagger generally "just works", sometimes one needs to exert some more fine-grained control over how the scheduler allocates work. There are two -parallel mechanisms to achieve this: Scheduler options (from -[`Sch.SchedulerOptions`](@ref)) and Thunk options (from -[`Sch.ThunkOptions`](@ref)). These two options structs contain many shared -options, with the difference being that Scheduler options operate -globally across an entire DAG, and Thunk options operate on a thunk-by-thunk -basis. 
+parallel mechanisms to achieve this: Task options (from [`Options`](@ref)) and +Scheduler options (from [`Sch.SchedulerOptions`](@ref)). Scheduler +options operate globally across an entire DAG, and Task options operate on a +task-by-task basis. Scheduler options can be constructed and passed to `collect()` or `compute()` as the keyword argument `options` for lazy API usage: @@ -238,7 +245,7 @@ compute(t; options=opts) collect(t; options=opts) ``` -Thunk options can be passed to `@spawn/spawn`, `@par`, and `delayed` similarly: +Task options can be passed to `@spawn/spawn`, `@par`, and `delayed` similarly: ```julia # Execute on worker 1 @@ -251,8 +258,9 @@ delayed(+; single=1)(1, 2) ## Changing the thread occupancy -One of the supported [`Sch.ThunkOptions`](@ref) is the `occupancy` keyword. -This keyword can be used to communicate that a task is not expected to fully saturate a CPU core (e.g. due to being IO-bound). +One of the supported [`Options`](@ref) is the `occupancy` keyword. +This keyword can be used to communicate that a task is not expected to fully +saturate a CPU core (e.g. due to being IO-bound). The basic usage looks like this: ```julia diff --git a/ext/GraphVizExt.jl b/ext/GraphVizExt.jl index d701ade54..2a43773e4 100644 --- a/ext/GraphVizExt.jl +++ b/ext/GraphVizExt.jl @@ -21,6 +21,7 @@ Requires the `all_task_deps` event enabled in `enable_logging!` Options: - `disconnected`: If `true`, render disconnected vertices (tasks or arguments without upstream/downstream dependencies) +- `show_data`: If `true`, show the data dependencies in the graph - `color_by`: How to color tasks; if `:fn`, then color by unique function name, if `:proc`, then color by unique processor - `layout_engine`: The layout engine to use for GraphViz rendering - `times`: If `true`, annotate each task with its start and finish times @@ -28,12 +29,14 @@ Options: - `colors`: A list of colors to use for coloring tasks - `name_to_color`: A function that maps task names to colors """ -function Dagger.render_logs(logs::Dict, ::Val{:graphviz}; disconnected=false, +function Dagger.render_logs(logs::Dict, ::Val{:graphviz}; + disconnected=false, show_data::Bool=true, color_by=:fn, layout_engine="dot", times::Bool=true, times_digits::Integer=3, colors=Dagger.Viz.default_colors, name_to_color=Dagger.Viz.name_to_color) - dot = Dagger.Viz.logs_to_dot(logs; disconnected, times, times_digits, + dot = Dagger.Viz.logs_to_dot(logs; disconnected, show_data, + times, times_digits, color_by, colors, name_to_color) gv = GraphViz.Graph(dot) GraphViz.layout!(gv; engine=layout_engine) diff --git a/src/Dagger.jl b/src/Dagger.jl index 0c3761c44..c0cb23526 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -17,9 +17,9 @@ import Random: AbstractRNG import UUIDs: UUID, uuid4 if !isdefined(Base, :ScopedValues) - import ScopedValues: ScopedValue, with + import ScopedValues: ScopedValue, @with, with else - import Base.ScopedValues: ScopedValue, with + import Base.ScopedValues: ScopedValue, @with, with end import TaskLocalValues: TaskLocalValue @@ -32,7 +32,6 @@ import TimespanLogging: timespan_start, timespan_finish import Adapt -# Preferences import Preferences: @load_preference, @set_preferences! 
if @load_preference("distributed-package") == "DistributedNext" @@ -43,29 +42,35 @@ else import Distributed: Future, RemoteChannel, myid, workers, nworkers, procs, remotecall, remotecall_wait, remotecall_fetch, check_same_host end +import MacroTools: @capture, prewalk + include("lib/util.jl") include("utils/dagdebug.jl") # Distributed data include("utils/locked-object.jl") include("utils/tasks.jl") - -import MacroTools: @capture, prewalk - -include("options.jl") +include("utils/reuse.jl") include("processor.jl") include("threadproc.jl") +include("sch_options.jl") include("context.jl") include("utils/processors.jl") +include("scopes.jl") +include("utils/scopes.jl") +include("chunks.jl") +include("utils/signature.jl") +include("options.jl") include("dtask.jl") include("cancellation.jl") include("task-tls.jl") -include("scopes.jl") -include("utils/scopes.jl") +include("argument.jl") include("queue.jl") include("thunk.jl") +include("utils/fetch.jl") +include("utils/chunks.jl") +include("utils/logging.jl") include("submission.jl") -include("chunks.jl") include("memory-spaces.jl") # Task scheduling @@ -85,15 +90,15 @@ include("stream.jl") include("stream-buffers.jl") include("stream-transfer.jl") +# File IO +include("file-io.jl") + # Array computations include("array/darray.jl") include("array/alloc.jl") include("array/map-reduce.jl") include("array/copy.jl") - -# File IO -include("file-io.jl") - +include("array/random.jl") include("array/operators.jl") include("array/indexing.jl") include("array/setindex.jl") @@ -104,19 +109,19 @@ include("array/linalg.jl") include("array/mul.jl") include("array/cholesky.jl") include("array/lu.jl") -include("array/random.jl") import KernelAbstractions, Adapt # GPU include("gpu.jl") -# Logging and Visualization +# Logging +include("utils/logging-events.jl") + +# Visualization include("visualization.jl") include("ui/gantt-common.jl") include("ui/gantt-text.jl") -include("utils/logging-events.jl") -include("utils/logging.jl") include("utils/viz.jl") """ diff --git a/src/argument.jl b/src/argument.jl new file mode 100644 index 000000000..94246a75e --- /dev/null +++ b/src/argument.jl @@ -0,0 +1,46 @@ +mutable struct ArgPosition + positional::Bool + idx::Int + kw::Symbol +end +ArgPosition() = ArgPosition(true, 0, :NULL) +ArgPosition(pos::ArgPosition) = ArgPosition(pos.positional, pos.idx, pos.kw) +ispositional(pos::ArgPosition) = pos.positional +iskw(pos::ArgPosition) = !pos.positional +raw_position(pos::ArgPosition) = ispositional(pos) ? 
pos.idx : pos.kw +function pos_idx(pos::ArgPosition) + @assert pos.positional + @assert pos.idx > 0 + @assert pos.kw == :NULL + return pos.idx +end +function pos_kw(pos::ArgPosition) + @assert !pos.positional + @assert pos.idx == 0 + @assert pos.kw != :NULL + return pos.kw +end +mutable struct Argument + pos::ArgPosition + value +end +Argument(pos::Integer, value) = Argument(ArgPosition(true, pos, :NULL), value) +Argument(kw::Symbol, value) = Argument(ArgPosition(false, 0, kw), value) +ispositional(arg::Argument) = ispositional(arg.pos) +iskw(arg::Argument) = iskw(arg.pos) +pos_idx(arg::Argument) = pos_idx(arg.pos) +pos_kw(arg::Argument) = pos_kw(arg.pos) +raw_position(arg::Argument) = raw_position(arg.pos) +value(arg::Argument) = arg.value +valuetype(arg::Argument) = typeof(arg.value) +Base.iterate(arg::Argument) = (arg.pos, true) +function Base.iterate(arg::Argument, state::Bool) + if state + return (arg.value, false) + else + return nothing + end +end + +Base.copy(arg::Argument) = Argument(ArgPosition(arg.pos), arg.value) +chunktype(arg::Argument) = chunktype(value(arg)) diff --git a/src/array/alloc.jl b/src/array/alloc.jl index 921e1dd03..a95e070ae 100644 --- a/src/array/alloc.jl +++ b/src/array/alloc.jl @@ -86,7 +86,7 @@ function stage(ctx, a::AllocateArray) args = a.want_index ? (i, size(x)) : (size(x),) if isnothing(a.procgrid) - scope = get_options(:compute_scope, get_options(:scope, DefaultScope())) + scope = get_compute_scope() else scope = ExactScope(a.procgrid[CartesianIndex(mod1.(Tuple(I), size(a.procgrid))...)]) end diff --git a/src/array/darray.jl b/src/array/darray.jl index 4a655d6ce..af51fb6af 100644 --- a/src/array/darray.jl +++ b/src/array/darray.jl @@ -173,7 +173,7 @@ domainchunks(d::DArray) = d.subdomains size(x::DArray) = size(domain(x)) stage(ctx, c::DArray) = c -function Base.collect(d::DArray; tree=false) +function Base.collect(d::DArray{T,N}; tree=false, copyto=false) where {T,N} a = fetch(d) if isempty(d.chunks) return Array{eltype(d)}(undef, size(d)...) @@ -183,6 +183,13 @@ function Base.collect(d::DArray; tree=false) return fetch(a.chunks[1]) end + if copyto + C = Array{T,N}(undef, size(a)) + DC = view(C, Blocks(size(a)...)) + copyto!(DC, a) + return C + end + dimcatfuncs = [(x...) 
-> d.concat(x..., dims=i) for i in 1:ndims(d)] if tree collect(fetch(treereduce_nd(map(x -> ((args...,) -> Dagger.@spawn x(args...)) , dimcatfuncs), a.chunks))) @@ -458,7 +465,7 @@ function stage(ctx::Context, d::Distribute) # TODO: fix hashing #hash = uhash(idx, Base.hash(Distribute, Base.hash(d.data))) if isnothing(d.procgrid) - scope = get_options(:compute_scope, get_options(:scope, DefaultScope())) + scope = get_compute_scope() else scope = ExactScope(d.procgrid[CartesianIndex(mod1.(Tuple(I), size(d.procgrid))...)]) end @@ -478,7 +485,7 @@ function stage(ctx::Context, d::Distribute) #hash = uhash(c, Base.hash(Distribute, Base.hash(d.data))) c = d.domainchunks[I] if isnothing(d.procgrid) - scope = get_options(:compute_scope, get_options(:scope, DefaultScope())) + scope = get_compute_scope() else scope = ExactScope(d.procgrid[CartesianIndex(mod1.(Tuple(I), size(d.procgrid))...)]) end diff --git a/src/cancellation.jl b/src/cancellation.jl index 63993a0e0..748c2aacd 100644 --- a/src/cancellation.jl +++ b/src/cancellation.jl @@ -98,21 +98,32 @@ function _cancel!(state, tid, force, graceful, halt_sch) for task in state.ready tid !== nothing && task.id != tid && continue @dagdebug tid :cancel "Cancelling ready task" - state.cache[task] = DTaskFailedException(task, task, InterruptException()) - state.errored[task] = true - Sch.set_failed!(state, task) + ex = DTaskFailedException(task, task, InterruptException()) + Sch.store_result!(state, task, ex; error=true) + Sch.finish_failed!(state, task, task) + end + if tid === nothing + empty!(state.ready) + else + idx = findfirst(t->t.id == tid, state.ready) + idx !== nothing && deleteat!(state.ready, idx) end - empty!(state.ready) # Cancel waiting tasks for task in keys(state.waiting) tid !== nothing && task.id != tid && continue @dagdebug tid :cancel "Cancelling waiting task" - state.cache[task] = DTaskFailedException(task, task, InterruptException()) - state.errored[task] = true - Sch.set_failed!(state, task) + ex = DTaskFailedException(task, task, InterruptException()) + Sch.store_result!(state, task, ex; error=true) + Sch.finish_failed!(state, task, task) + end + if tid === nothing + empty!(state.waiting) + else + if haskey(state.waiting, tid) + delete!(state.waiting, tid) + end end - empty!(state.waiting) # Cancel running tasks at the processor level wids = unique(map(root_worker_id, values(state.running_on))) @@ -126,7 +137,7 @@ function _cancel!(state, tid, force, graceful, halt_sch) for (tid, task) in istate.tasks _tid !== nothing && tid != _tid && continue task_spec = istate.task_specs[tid] - Tf = task_spec[6] + Tf = task_spec.Tf Tf === typeof(Sch.eager_thunk) && continue istaskdone(task) && continue any_cancelled = true @@ -136,13 +147,13 @@ function _cancel!(state, tid, force, graceful, halt_sch) else @dagdebug tid :cancel "Cancelling running task ($Tf)" # Tell the processor to just drop this task - task_occupancy = task_spec[4] - time_util = task_spec[2] + task_occupancy = task_spec.est_occupancy + time_util = task_spec.est_time_util istate.proc_occupancy[] -= task_occupancy istate.time_pressure[] -= time_util push!(istate.cancelled, tid) to_proc = istate.proc - put!(istate.return_queue, (myid(), to_proc, tid, (InterruptException(), nothing))) + put!(istate.return_queue, Sch.TaskResult(myid(), to_proc, tid, InterruptException(), nothing)) cancel!(istate.cancel_tokens[tid]; graceful) end end @@ -155,6 +166,7 @@ function _cancel!(state, tid, force, graceful, halt_sch) return end end + put!(state.chan, Sch.RescheduleSignal()) if halt_sch 
unlock(state.lock) @@ -165,7 +177,7 @@ function _cancel!(state, tid, force, graceful, halt_sch) # Halt the scheduler @dagdebug nothing :cancel "Halting the scheduler" notify(state.halt) - put!(state.chan, (1, nothing, nothing, (Sch.SchedulerHaltedException(), nothing))) + put!(state.chan, Sch.TaskResult(1, OSProc(), 0, Sch.SchedulerHaltedException(), nothing)) # Wait for the scheduler to halt @dagdebug nothing :cancel "Waiting for scheduler to halt" diff --git a/src/chunks.jl b/src/chunks.jl index 1eb56714e..03bdfb65d 100644 --- a/src/chunks.jl +++ b/src/chunks.jl @@ -50,13 +50,10 @@ mutable struct Chunk{T, H, P<:Processor, S<:AbstractScope} handle::H processor::P scope::S - persist::Bool end domain(c::Chunk) = c.domain chunktype(c::Chunk) = c.chunktype -persist!(t::Chunk) = (t.persist=true; t) -shouldpersist(p::Chunk) = t.persist processor(c::Chunk) = c.processor affinity(c::Chunk) = affinity(c.handle) @@ -65,8 +62,6 @@ is_task_or_chunk(c::Chunk) = true Base.:(==)(c1::Chunk, c2::Chunk) = c1.handle == c2.handle Base.hash(c::Chunk, x::UInt64) = hash(c.handle, hash(Chunk, x)) -Adapt.adapt_storage(::FetchAdaptor, x::Chunk) = fetch(x) - collect_remote(chunk::Chunk) = move(chunk.processor, OSProc(), poolget(chunk.handle)) @@ -113,181 +108,6 @@ affinity(r::DRef) = OSProc(r.owner)=>r.size # see #295 affinity(r::FileRef) = OSProc(1)=>r.size -### Mutation - -function _mutable_inner(@nospecialize(f), proc, scope) - result = f() - return Ref(Dagger.tochunk(result, proc, scope)) -end - -""" - mutable(f::Base.Callable; worker, processor, scope) -> Chunk - -Calls `f()` on the specified worker or processor, returning a `Chunk` -referencing the result with the specified scope `scope`. -""" -function mutable(@nospecialize(f); worker=nothing, processor=nothing, scope=nothing) - if processor === nothing - if worker === nothing - processor = OSProc() - else - processor = OSProc(worker) - end - else - @assert worker === nothing "mutable: Can't mix worker and processor" - end - if scope === nothing - scope = processor isa OSProc ? ProcessScope(processor) : ExactScope(processor) - end - return fetch(Dagger.@spawn scope=scope _mutable_inner(f, processor, scope))[] -end - -""" - @mutable [worker=1] [processor=OSProc()] [scope=ProcessorScope()] f() - -Helper macro for [`mutable()`](@ref). -""" -macro mutable(exs...) - opts = esc.(exs[1:end-1]) - ex = exs[end] - quote - let f = @noinline ()->$(esc(ex)) - $mutable(f; $(opts...)) - end - end -end - -""" -Maps a value to one of multiple distributed "mirror" values automatically when -used as a thunk argument. Construct using `@shard` or `shard`. -""" -struct Shard - chunks::Dict{Processor,Chunk} -end - -""" - shard(f; kwargs...) -> Chunk{Shard} - -Executes `f` on all workers in `workers`, wrapping the result in a -process-scoped `Chunk`, and constructs a `Chunk{Shard}` containing all of these -`Chunk`s on the current worker. - -Keyword arguments: -- `procs` -- The list of processors to create pieces on. May be any iterable container of `Processor`s. -- `workers` -- The list of workers to create pieces on. May be any iterable container of `Integer`s. -- `per_thread::Bool=false` -- If `true`, creates a piece per each thread, rather than a piece per each worker. 
-""" -function shard(@nospecialize(f); procs=nothing, workers=nothing, per_thread=false) - if procs === nothing - if workers !== nothing - procs = [OSProc(w) for w in workers] - else - procs = lock(Sch.eager_context()) do - copy(Sch.eager_context().procs) - end - end - if per_thread - _procs = ThreadProc[] - for p in procs - append!(_procs, filter(p->p isa ThreadProc, get_processors(p))) - end - procs = _procs - end - else - if workers !== nothing - throw(ArgumentError("Cannot combine `procs` and `workers`")) - elseif per_thread - throw(ArgumentError("Cannot combine `procs` and `per_thread=true`")) - end - end - isempty(procs) && throw(ArgumentError("Cannot create empty Shard")) - shard_running_dict = Dict{Processor,DTask}() - for proc in procs - scope = proc isa OSProc ? ProcessScope(proc) : ExactScope(proc) - thunk = Dagger.@spawn scope=scope _mutable_inner(f, proc, scope) - shard_running_dict[proc] = thunk - end - shard_dict = Dict{Processor,Chunk}() - for proc in procs - shard_dict[proc] = fetch(shard_running_dict[proc])[] - end - return Shard(shard_dict) -end - -"Creates a `Shard`. See [`Dagger.shard`](@ref) for details." -macro shard(exs...) - opts = esc.(exs[1:end-1]) - ex = exs[end] - quote - let f = @noinline ()->$(esc(ex)) - $shard(f; $(opts...)) - end - end -end - -function move(from_proc::Processor, to_proc::Processor, shard::Shard) - # Match either this proc or some ancestor - # N.B. This behavior may bypass the piece's scope restriction - proc = to_proc - if haskey(shard.chunks, proc) - return move(from_proc, to_proc, shard.chunks[proc]) - end - parent = Dagger.get_parent(proc) - while parent != proc - proc = parent - parent = Dagger.get_parent(proc) - if haskey(shard.chunks, proc) - return move(from_proc, to_proc, shard.chunks[proc]) - end - end - - throw(KeyError(to_proc)) -end -Base.iterate(s::Shard) = iterate(values(s.chunks)) -Base.iterate(s::Shard, state) = iterate(values(s.chunks), state) -Base.length(s::Shard) = length(s.chunks) - -### Core Stuff - -""" - tochunk(x, proc::Processor, scope::AbstractScope; device=nothing, kwargs...) -> Chunk - -Create a chunk from data `x` which resides on `proc` and which has scope -`scope`. - -`device` specifies a `MemPool.StorageDevice` (which is itself wrapped in a -`Chunk`) which will be used to manage the reference contained in the `Chunk` -generated by this function. If `device` is `nothing` (the default), the data -will be inspected to determine if it's safe to serialize; if so, the default -MemPool storage device will be used; if not, then a `MemPool.CPURAMDevice` will -be used. - -All other kwargs are passed directly to `MemPool.poolset`. -""" -function tochunk(x::X, proc::P=OSProc(), scope::S=AnyScope(); persist=false, cache=false, device=nothing, kwargs...) where {X,P,S} - if device === nothing - device = if Sch.walk_storage_safe(x) - MemPool.GLOBAL_DEVICE[] - else - MemPool.CPURAMDevice() - end - end - ref = poolset(x; device, kwargs...) - Chunk{X,typeof(ref),P,S}(X, domain(x), ref, proc, scope, persist) -end -tochunk(x::Union{Chunk, Thunk}, proc=nothing, scope=nothing; kwargs...) 
= x - -function savechunk(data, dir, f) - sz = open(joinpath(dir, f), "w") do io - serialize(io, MemPool.MMWrap(data)) - return position(io) - end - fr = FileRef(f, sz) - proc = OSProc() - scope = AnyScope() # FIXME: Scoped to this node - Chunk{typeof(data),typeof(fr),typeof(proc),typeof(scope)}(typeof(data), domain(data), fr, proc, scope, true) -end - struct WeakChunk wid::Int id::Int @@ -302,12 +122,10 @@ function unwrap_weak_checked(c::WeakChunk) @assert cw !== nothing "WeakChunk expired: ($(c.wid), $(c.id))" return cw end +wrap_weak(c::Chunk) = WeakChunk(c) +isweak(c::WeakChunk) = true +isweak(c::Chunk) = false is_task_or_chunk(c::WeakChunk) = true Serialization.serialize(io::AbstractSerializer, wc::WeakChunk) = error("Cannot serialize a WeakChunk") - -Base.@deprecate_binding AbstractPart Union{Chunk, Thunk} -Base.@deprecate_binding Part Chunk -Base.@deprecate parts(args...) chunks(args...) -Base.@deprecate part(args...) tochunk(args...) -Base.@deprecate parttype(args...) chunktype(args...) +chunktype(c::WeakChunk) = chunktype(unwrap_weak_checked(c)) diff --git a/src/compute.jl b/src/compute.jl index 093b527f4..f655cbac4 100644 --- a/src/compute.jl +++ b/src/compute.jl @@ -2,40 +2,31 @@ export compute, debug_compute ###### Scheduler ####### -compute(x; options=nothing) = compute(Context(global_context()), x; options=options) -compute(ctx, c::Chunk; options=nothing) = c +compute(x; options::Union{SchedulerOptions,Nothing}=nothing) = + compute(global_context(), x; options) +compute(ctx, c::Chunk; options::Union{SchedulerOptions,Nothing}=nothing) = c -collect(ctx::Context, t::Thunk; options=nothing) = - collect(ctx, compute(ctx, t; options=options); options=options) -collect(d::Union{Chunk,Thunk}; options=nothing) = - collect(Context(global_context()), d; options=options) +collect(ctx::Context, t::Thunk; options::Union{SchedulerOptions,Nothing}=nothing) = + collect(ctx, compute(ctx, t; options); options) +collect(d::Union{Chunk,Thunk}; options::Union{SchedulerOptions,Nothing}=nothing) = + collect(global_context(), d; options) abstract type Computation end """ - compute(ctx::Context, d::Thunk; options=nothing) -> Chunk + compute(ctx::Context, d::Thunk; options::Union{SchedulerOptions,Nothing}=nothing) -> Chunk Compute a Thunk - creates the DAG, assigns ranks to nodes for tie breaking and runs the scheduler with the specified options. Returns a Chunk which references the result. 
""" -compute(ctx::Context, d::Thunk; options=nothing) = - Sch.compute_dag(ctx, d; options=options) - -function debug_compute(ctx::Context, args...; profile=false, options=nothing) - @time result = compute(ctx, args...; options=options) - get_logs!(ctx.log_sink), result -end - -function debug_compute(arg; profile=false, options=nothing) - ctx = Context(global_context()) - dbgctx = Context(procs(ctx), LocalEventLog(), profile) - debug_compute(dbgctx, arg; options=options) +function compute(ctx::Context, d::Thunk; options::Union{SchedulerOptions,Nothing}=nothing) + if options === nothing + options = SchedulerOptions() + end + return Sch.compute_dag(ctx, d, options) end -Base.@deprecate gather(ctx, x) collect(ctx, x) -Base.@deprecate gather(x) collect(x) - function get_type(s::String) local T for t in split(s, ".") @@ -67,7 +58,7 @@ function dependents(node::Thunk) if !haskey(deps, next) deps[next] = Set{Thunk}() end - for inp in next.syncdeps + for inp in next.options.syncdeps if istask(inp) || (inp isa Chunk) s = get!(()->Set{Thunk}(), deps, inp) push!(s, next) @@ -135,7 +126,7 @@ function order(node::Thunk, ndeps) haskey(output, next) && continue s += 1 output[next] = s - parents = collect(filter(istask, next.syncdeps)) + parents = collect(filter(istask, next.options.syncdeps)) if !isempty(parents) # If parents is empty, sort! should be a no-op, but raises an ambiguity error # when InlineStrings.jl is loaded (at least, version 1.1.0), because InlineStrings diff --git a/src/context.jl b/src/context.jl index f9f16fbe4..e4752d4ed 100644 --- a/src/context.jl +++ b/src/context.jl @@ -18,7 +18,6 @@ mutable struct Context proc_notify::Threads.Condition log_sink::Any profile::Bool - options end function Context(procs::Vector{P}=Processor[OSProc(w) for w in procs()]; @@ -28,13 +27,15 @@ function Context(procs::Vector{P}=Processor[OSProc(w) for w in procs()]; if log_file !== nothing @warn "`log_file` is no longer supported\nPlease instead load `GraphViz.jl` and use `render_logs(logs, :graphviz)`." end - Context(procs, proc_lock, proc_notify, log_sink, - profile, options) + if options !== nothing + @warn "`options` is no longer supported\nPlease instead pass the options to `compute`/`collect` as a keyword argument." + end + return Context(procs, proc_lock, proc_notify, + log_sink, profile) end Context(xs::Vector{Int}; kwargs...) = Context(map(OSProc, xs); kwargs...) Context(ctx::Context, xs::Vector=copy(procs(ctx))) = # make a copy - Context(xs; log_sink=ctx.log_sink, - profile=ctx.profile, options=ctx.options) + Context(xs; log_sink=ctx.log_sink, profile=ctx.profile) const GLOBAL_CONTEXT = Ref{Context}() function global_context() diff --git a/src/datadeps.jl b/src/datadeps.jl index a4f64d2bd..eee2a6ece 100644 --- a/src/datadeps.jl +++ b/src/datadeps.jl @@ -268,9 +268,9 @@ function populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask) dependencies_to_add = Vector{Tuple{Bool,Bool,AliasingWrapper,<:Any,<:Any}}() # Track the task's arguments and access patterns - for (idx, (pos, arg)) in enumerate(spec.args) + for (idx, _arg) in enumerate(spec.fargs) # Unwrap In/InOut/Out wrappers and record dependencies - arg, deps = unwrap_inout(arg) + arg, deps = unwrap_inout(value(_arg)) # Unwrap the Chunk underlying any DTask arguments arg = arg isa DTask ? 
fetch(arg; raw=true) : arg @@ -471,8 +471,8 @@ function generate_slot!(state::DataDepsState, dest_space, data) to_w = root_worker_id(dest_space) ctx = Sch.eager_context() id = rand(Int) - timespan_start(ctx, :move, (;thunk_id=0, id, position=0, processor=to_proc), (;f=nothing, data)) dest_space_args[data] = remotecall_fetch(to_w, from_proc, to_proc, data) do from_proc, to_proc, data + timespan_start(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data)) data_converted = move(from_proc, to_proc, data) data_chunk = tochunk(data_converted, to_proc) @assert processor(data_chunk) in processors(dest_space) @@ -482,7 +482,7 @@ function generate_slot!(state::DataDepsState, dest_space, data) end return data_chunk end - timespan_finish(ctx, :move, (;thunk_id=0, id, position=0, processor=to_proc), (;f=nothing, data=dest_space_args[data])) + timespan_finish(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data=dest_space_args[data])) end return dest_space_args[data] end @@ -516,7 +516,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue) # Get the set of all processors to be scheduled on all_procs = Processor[] - scope = get_options(:scope, DefaultScope()) + scope = get_compute_scope() for w in procs() append!(all_procs, get_processors(OSProc(w))) end @@ -596,7 +596,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue) scheduler = queue.scheduler if scheduler == :naive - raw_args = map(arg->tochunk(last(arg)), spec.args) + raw_args = map(arg->tochunk(value(arg)), spec.fargs) our_proc = remotecall_fetch(1, all_procs, raw_args) do all_procs, raw_args Sch.init_eager() sch_state = Sch.EAGER_STATE[] @@ -611,13 +611,13 @@ function distribute_tasks!(queue::DataDepsTaskQueue) end end elseif scheduler == :smart - raw_args = map(filter(arg->haskey(astate.data_locality, arg), spec.args)) do arg + raw_args = map(filter(arg->haskey(astate.data_locality, value(arg)), spec.fargs)) do arg arg_chunk = tochunk(last(arg)) # Only the owned slot is valid # FIXME: Track up-to-date copies and pass all of those return arg_chunk => data_locality[arg] end - f_chunk = tochunk(spec.f) + f_chunk = tochunk(value(f)) our_proc, task_pressure = remotecall_fetch(1, all_procs, pressures, f_chunk, raw_args) do all_procs, pressures, f, chunks_locality Sch.init_eager() sch_state = Sch.EAGER_STATE[] @@ -658,7 +658,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue) # FIXME: Pressure should be decreased by pressure of syncdeps on same processor pressures[our_proc] = get(pressures, our_proc, UInt64(0)) + task_pressure elseif scheduler == :ultra - args = Base.mapany(spec.args) do arg + args = Base.mapany(spec.fargs) do arg pos, data = arg data, _ = unwrap_inout(data) if data isa DTask @@ -666,7 +666,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue) end return pos => tochunk(data) end - f_chunk = tochunk(spec.f) + f_chunk = tochunk(value(f)) task_time = remotecall_fetch(1, f_chunk, args) do f, args Sch.init_eager() sch_state = Sch.EAGER_STATE[] @@ -724,33 +724,32 @@ function distribute_tasks!(queue::DataDepsTaskQueue) @assert our_proc in all_procs our_space = only(memory_spaces(our_proc)) our_procs = filter(proc->proc in all_procs, collect(processors(our_space))) - task_scope = get(spec.options, :scope, AnyScope()) + task_scope = @something(spec.options.scope, AnyScope()) our_scope = constrain(UnionScope(map(ExactScope, our_procs)...), task_scope) if our_scope isa InvalidScope throw(Sch.SchedulingException("Scopes are not compatible: 
$(our_scope.x), $(our_scope.y)")) end - spec.f = move(ThreadProc(myid(), 1), our_proc, spec.f) - @dagdebug nothing :spawn_datadeps "($(repr(spec.f))) Scheduling: $our_proc ($our_space)" + f = spec.fargs[1] + f.value = move(ThreadProc(myid(), 1), our_proc, value(f)) + @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) Scheduling: $our_proc ($our_space)" # Copy raw task arguments for analysis - task_args = copy(spec.args) + task_args = map(copy, spec.fargs) # Copy args from local to remote - for (idx, (pos, arg)) in enumerate(task_args) - # Is the data written previously or now? - arg, deps = unwrap_inout(arg) + for (idx, _arg) in enumerate(task_args) + # Is the data writeable? + arg, deps = unwrap_inout(value(_arg)) arg = arg isa DTask ? fetch(arg; raw=true) : arg if !type_may_alias(typeof(arg)) - @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] Skipped copy-to (immutable)" - spec.args[idx] = pos => arg + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Skipped copy-to (unwritten)" + spec.fargs[idx].value = arg continue end - - # Is the data writeable? if !supports_inplace_move(state, arg) - @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] Skipped copy-to (non-writeable)" - spec.args[idx] = pos => arg + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Skipped copy-to (non-writeable)" + spec.fargs[idx].value = arg continue end @@ -765,20 +764,20 @@ function distribute_tasks!(queue::DataDepsTaskQueue) nonlocal = our_space != data_space if nonlocal # Add copy-to operation (depends on latest owner of arg) - @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx][$dep_mod] Enqueueing copy-to: $data_space => $our_space" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] Enqueueing copy-to: $data_space => $our_space" arg_local = get!(get!(IdDict{Any,Any}, state.remote_args, data_space), arg) do generate_slot!(state, data_space, arg) end copy_to_scope = our_scope copy_to_syncdeps = Set{Any}() get_write_deps!(state, ainfo, task, write_num, copy_to_syncdeps) - @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx][$dep_mod] $(length(copy_to_syncdeps)) syncdeps" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] $(length(copy_to_syncdeps)) syncdeps" copy_to = Dagger.@spawn scope=copy_to_scope syncdeps=copy_to_syncdeps meta=true Dagger.move!(dep_mod, our_space, data_space, arg_remote, arg_local) add_writer!(state, ainfo, copy_to, write_num) astate.data_locality[ainfo] = our_space else - @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx][$dep_mod] Skipped copy-to (local): $data_space" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] Skipped copy-to (local): $data_space" end end else @@ -786,40 +785,44 @@ function distribute_tasks!(queue::DataDepsTaskQueue) nonlocal = our_space != data_space if nonlocal # Add copy-to operation (depends on latest owner of arg) - @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] Enqueueing copy-to: $data_space => $our_space" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Enqueueing copy-to: $data_space => $our_space" arg_local = get!(get!(IdDict{Any,Any}, state.remote_args, data_space), arg) do generate_slot!(state, data_space, arg) end copy_to_scope = our_scope copy_to_syncdeps = Set{Any}() get_write_deps!(state, arg, task, write_num, copy_to_syncdeps) - @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] $(length(copy_to_syncdeps)) syncdeps" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] 
$(length(copy_to_syncdeps)) syncdeps" copy_to = Dagger.@spawn scope=copy_to_scope syncdeps=copy_to_syncdeps Dagger.move!(identity, our_space, data_space, arg_remote, arg_local) add_writer!(state, arg, copy_to, write_num) astate.data_locality[arg] = our_space else - @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] Skipped copy-to (local): $data_space" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Skipped copy-to (local): $data_space" end end - spec.args[idx] = pos => arg_remote + spec.fargs[idx].value = arg_remote end write_num += 1 # Validate that we're not accidentally performing a copy - for (idx, (_, arg)) in enumerate(spec.args) - _, deps = unwrap_inout(task_args[idx][2]) + for (idx, _arg) in enumerate(spec.fargs) + _, deps = unwrap_inout(value(task_args[idx])) # N.B. We only do this check when the argument supports in-place # moves, because for the moment, we are not guaranteeing updates or # write-back of results + arg = value(_arg) if is_writedep(arg, deps, task) && supports_inplace_move(state, arg) arg_space = memory_space(arg) - @assert arg_space == our_space "($(repr(spec.f)))[$idx] Tried to pass $(typeof(arg)) from $arg_space to $our_space" + @assert arg_space == our_space "($(repr(value(f))))[$idx] Tried to pass $(typeof(arg)) from $arg_space to $our_space" end end # Calculate this task's syncdeps - syncdeps = get(Set{Any}, spec.options, :syncdeps) + if spec.options.syncdeps === nothing + spec.options.syncdeps = Set{Any}() + end + syncdeps = spec.options.syncdeps for (idx, (_, arg)) in enumerate(task_args) arg, deps = unwrap_inout(arg) arg = arg isa DTask ? fetch(arg; raw=true) : arg @@ -829,28 +832,27 @@ function distribute_tasks!(queue::DataDepsTaskQueue) for (dep_mod, _, writedep) in deps ainfo = aliasing(astate, arg, dep_mod) if writedep - @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx][$dep_mod] Syncing as writer" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] Syncing as writer" get_write_deps!(state, ainfo, task, write_num, syncdeps) else - @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx][$dep_mod] Syncing as reader" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] Syncing as reader" get_read_deps!(state, ainfo, task, write_num, syncdeps) end end else if is_writedep(arg, deps, task) - @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] Syncing as writer" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Syncing as writer" get_write_deps!(state, arg, task, write_num, syncdeps) else - @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] Syncing as reader" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Syncing as reader" get_read_deps!(state, arg, task, write_num, syncdeps) end end end - @dagdebug nothing :spawn_datadeps "($(repr(spec.f))) $(length(syncdeps)) syncdeps" + @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) $(length(syncdeps)) syncdeps" # Launch user's task - task_scope = our_scope - spec.options = merge(spec.options, (;syncdeps, scope=task_scope)) + spec.options.scope = our_scope enqueue!(upper_queue, spec=>task) # Update read/write tracking for arguments @@ -862,7 +864,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue) for (dep_mod, _, writedep) in deps ainfo = aliasing(astate, arg, dep_mod) if writedep - @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx][$dep_mod] Set as owner" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] Set as owner" add_writer!(state, ainfo, task, 
write_num) else add_reader!(state, ainfo, task, write_num) @@ -870,7 +872,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue) end else if is_writedep(arg, deps, task) - @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] Set as owner" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Set as owner" add_writer!(state, arg, task, write_num) else add_reader!(state, arg, task, write_num) diff --git a/src/dtask.jl b/src/dtask.jl index b597db5fa..b74774287 100644 --- a/src/dtask.jl +++ b/src/dtask.jl @@ -26,19 +26,6 @@ function Base.fetch(t::ThunkFuture; proc=OSProc(), raw=false) end Base.put!(t::ThunkFuture, x; error=false) = put!(t.future, (error, x)) -""" - Options(::NamedTuple) - Options(; kwargs...) - -Options for thunks and the scheduler. See [Task Spawning](@ref) for more -information. -""" -struct Options - options::NamedTuple -end -Options(;options...) = Options((;options...)) -Options(options...) = Options((;options...)) - """ DTaskMetadata @@ -61,10 +48,9 @@ mutable struct DTask uid::UInt future::ThunkFuture metadata::DTaskMetadata - finalizer_ref::DRef thunk_ref::DRef - DTask(uid, future, metadata, finalizer_ref) = new(uid, future, metadata, finalizer_ref) + DTask(uid, future, metadata) = new(uid, future, metadata) end const EagerThunk = DTask @@ -121,16 +107,6 @@ function Base.show(io::IO, t::DTask) end istask(t::DTask) = true -"When finalized, cleans-up the associated `DTask`." -mutable struct DTaskFinalizer - uid::UInt - function DTaskFinalizer(uid) - x = new(uid) - finalizer(Sch.eager_cleanup, x) - x - end -end - const EAGER_ID_COUNTER = Threads.Atomic{UInt64}(1) function eager_next_id() if myid() == 1 diff --git a/src/file-io.jl b/src/file-io.jl index efe80c59f..1407fbced 100644 --- a/src/file-io.jl +++ b/src/file-io.jl @@ -74,296 +74,3 @@ Base.show(io::IO, file::File) = print(io, "Dagger.File(path=\"$(file.path)\")") function move(from_proc::Processor, to_proc::Processor, file::File) return move(from_proc, to_proc, file.chunk) end - -""" - FileReader - -Used as a `Chunk` handle for reading a file, starting at a given offset. -""" -mutable struct FileReader{T} - file::AbstractString - chunktype::Type{T} - data_offset::Int - mmap::Bool -end - -""" - save(io::IO, val) - -Save a value into the IO buffer. In the case of arrays and sparse -matrices, this will save it in a memory-mappable way. - -`load(io::IO, t::Type, domain)` will load the object given its domain -""" -function save(ctx, io::IO, val) - error("Save method for $(typeof(val)) not defined") -end - - -###### Save chunks ###### - -const PARTSPEC = 0x00 -const CAT = 0x01 - -# subparts are saved as Parts - -""" - save(ctx, chunk::Union{Chunk, Thunk}, file_path::AbsractString) - -Save a chunk to a file at `file_path`. -""" -function save(ctx, chunk::Union{Chunk, Thunk}, file_path::AbstractString) - open(file_path, "w") do io - save(ctx, io, chunk, file_path) - end -end - -""" - save(ctx, chunk, file_path) - -Special case distmem writing - write to disk on the process with the chunk. 
-""" -function save(ctx, chunk::Chunk{X,DRef}, file_path::AbstractString) where X - pid = chunk.handle.where - - remotecall_fetch(pid, file_path, chunk.handle) do path, rref - open(path, "w") do io - save(ctx, io, chunk, file_path) - end - end -end - -function save(ctx, io::IO, chunk::Chunk, file_path) - meta_io = IOBuffer() - - serialize(meta_io, (chunktype(chunk), domain(chunk))) - meta = take!(meta_io) - - write(io, PARTSPEC) - write(io, length(meta)) - write(io, meta) - data_offset = position(io) - - save(ctx, io, collect(ctx, chunk)) - - Chunk(chunktype(chunk), domain(chunk), FileReader(file_path, chunktype(chunk), data_offset, false), false) -end - -function save(ctx, io::IO, chunk::DArray, file_path::AbstractString, saved_parts::AbstractArray) - - metadata = (chunktype(chunk), domain(chunk), saved_parts) - - # save yourself - write(io, CAT) - serialize(io, metadata) - - DArray(metadata...) - # write each child -end - - -function save(ctx, io::IO, chunk::DArray, file_path) - dir_path = file_path*"_data" - if !isdir(dir_path) - mkdir(dir_path) - end - - # save the chunks - saved_parts = [save(ctx, c, joinpath(dir_path, lpad(i, 4, "0"))) - for (i, c) in enumerate(chunks(chunk))] - - save(ctx, io, chunk, file_path, saved_parts) - # write each child -end - -function save(ctx, chunk::Chunk{X, FileReader}, file_path::AbstractString) where X - if abspath(file_path) == abspath(chunk.reader.file) - chunk - else - cp(chunk.reader.file, file_path) - Chunk(chunktype(chunk), domain(chunk), - FileReader(file_path, chunktype(chunk), - chunk.reader.data_offset, false), false) - end -end - -save(chunk::Union{Chunk, Thunk}, file_path::AbstractString) = save(Context(global_context()), chunk, file_path) - - - -###### Load chunks ###### - -""" - load(ctx::Context, file_path) - -Load an Union{Chunk, Thunk} from a file. -""" -function load(ctx::Context, file_path::AbstractString; mmap=false) - - open(file_path) do f - part_typ = read(f, UInt8) - if part_typ == PARTSPEC - c = load(ctx, Chunk, file_path, mmap, f) - elseif part_typ == CAT - c = load(ctx, DArray, file_path, mmap, f) - else - error("Could not determine chunk type") - end - end - c -end - -""" - load(ctx::Context, ::Type{Chunk}, fpath, io) - -Load a Chunk object from a file, the file path -is required for creating a FileReader object -""" -function load(ctx::Context, ::Type{Chunk}, fname, mmap, io) - meta_len = read(io, Int) - io = IOBuffer(read(io, meta_len)) - - (T, dmn, sz) = deserialize(io) - - DArray(Chunk(T, dmn, sz, - FileReader(fname, T, meta_len+1, mmap), false)) -end - -function load(ctx::Context, ::Type{DArray}, file_path, mmap, io) - dir_path = file_path*"_data" - - metadata = deserialize(io) - c = DArray(metadata...) - for p in chunks(c) - if isa(p.handle, FileReader) - p.handle.mmap = mmap - end - end - DArray(c) -end - - -###### Save and Load for actual data ##### - -function save(ctx::Context, io::IO, m::Array) - write(io, reinterpret(UInt8, m, (sizeof(m),))) - m -end - -function save(ctx::Context, io::IO, m::BitArray) - save(ctx, io, convert(Array{Bool}, m)) -end - -function collect(ctx::Context, c::Chunk{X,FileReader{T}}) where {X,T<:Array} - h = c.handle - io = open(h.file, "r+") - seek(io, h.data_offset) - arr = h.mmap ? Mmap.mmap(io, h.chunktype, size(c.domain)) : - reshape(reinterpret(eltype(T), read(io)), size(c.domain)) - close(io) - arr -end - -function collect(ctx::Context, c::Chunk{X, FileReader{T}}) where {X,T<:BitArray} - h = c.handle - io = open(h.file, "r+") - seek(io, h.data_offset) - - arr = h.mmap ? 
Mmap.mmap(io, Bool, size(c.domain)) : - reshape(reinterpret(Bool, read(io)), size(c.domain)) - close(io) - arr -end - -function save(ctx::Context, io::IO, m::SparseMatrixCSC{Tv,Ti}) where {Tv, Ti} - write(io, m.m) - write(io, m.n) - write(io, length(m.nzval)) - - typ_io = IOBuffer() - serialize(typ_io, (Tv, Ti)) - buf = take!(typ_io) - write(io, sizeof(buf)) - write(io, buf) - - write(io, reinterpret(UInt8, m.colptr, (sizeof(m.colptr),))) - write(io, reinterpret(UInt8, m.rowval, (sizeof(m.rowval),))) - write(io, reinterpret(UInt8, m.nzval, (sizeof(m.nzval),))) - m -end - -function collect(ctx::Context, c::Chunk{X, FileReader{T}}) where {X, T<:SparseMatrixCSC} - h = c.handle - io = open(h.file, "r+") - seek(io, h.data_offset) - - m = read(io, Int) - n = read(io, Int) - nnz = read(io, Int) - - typ_len = read(io, Int) - typ_bytes = read(io, typ_len) - (Tv, Ti) = deserialize(IOBuffer(typ_bytes)) - - pos = position(io) - colptr = Mmap.mmap(io, Vector{Ti}, (n+1,), pos) - - pos += sizeof(Ti)*(n+1) - rowval = Mmap.mmap(io, Vector{Ti}, (nnz,), pos) - - pos += sizeof(Ti)*nnz - nnzval = Mmap.mmap(io, Vector{Tv}, (nnz,), pos) - close(io) - - SparseMatrixCSC(m, n, colptr, rowval, nnzval) -end - -function getsub(ctx::Context, c::Chunk{X,FileReader{T}}, d) where {X,T<:AbstractArray} - Chunk(collect(ctx, c)[d]) -end - - -#### Save computation - -struct Save <: Computation - input - name::AbstractString -end - -function save(p::Computation, name::AbstractString) - Save(p, name) -end - -function stage(ctx::Context, s::Save) - x = stage(ctx, s.input) - dir_path = s.name * "_data" - if !isdir(dir_path) - mkdir(dir_path) - end - function save_part(idx, data) - p = tochunk(data) - path = joinpath(dir_path, lpad(idx, 4, "0")) - saved = save(ctx, p, path) - - # release reference created for the purpose of save - release_token(p.handle) - saved - end - - saved_parts = similar(chunks(x), Thunk) - for i=1:length(chunks(x)) - saved_parts[i] = Thunk(save_part, i, chunks(x)[i]) - end - - sz = size(chunks(x)) - function save_cat_meta(chunks...) - f = open(s.name, "w") - saved_parts = reshape(Union{Chunk, Thunk}[c for c in chunks], sz) - res = save(ctx, f, x, s.name, saved_parts) - close(f) - res - end - - # The DAG has to block till saving is complete. - res = Thunk(save_cat_meta, saved_parts...; meta=true) -end diff --git a/src/options.jl b/src/options.jl index 00196dd59..12e8c3dcf 100644 --- a/src/options.jl +++ b/src/options.jl @@ -1,3 +1,192 @@ +# Task options + +""" + Options + +Stores per-task options to be passed to the scheduler. + +# Arguments +- `propagates::Vector{Symbol}`: The set of option names that will be propagated by this task to tasks that it spawns. +- `processor::Processor`: The processor associated with this task's function. Generally ignored by the scheduler. +- `compute_scope::AbstractScope`: The execution scope of the task, which determines where the task can be scheduled and executed. `scope` is another name for this option. +- `result_scope::AbstractScope`: The data scope of the task's result, which determines where the task's result can be accessed from. +- `single::Int=0`: (Deprecated) Force task onto worker with specified id. `0` disables this option. +- `proclist=nothing`: (Deprecated) Force task to use one or more processors that are instances/subtypes of a contained type. Alternatively, a function can be supplied, and the function will be called with a processor as the sole argument and should return a `Bool` result to indicate whether or not to use the given processor. 
`nothing` enables all default processors. +- `get_result::Bool=false`: Whether the worker should store the result directly (`true`) or as a `Chunk` (`false`) +- `meta::Bool=false`: When `true`, values are not `move`d, and are passed directly as `Chunk`, if they are not immediate values +- `syncdeps::Set{Any}`: Contains any additional tasks to synchronize with +- `time_util::Dict{Type,Any}`: Indicates the maximum expected time utilization for this task. Each keypair maps a processor type to the utilization, where the value can be a real (approximately the number of nanoseconds taken), or `MaxUtilization()` (utilizes all processors of this type). By default, the scheduler assumes that this task only uses one processor. +- `alloc_util::Dict{Type,UInt64}`: Indicates the maximum expected memory utilization for this task. Each keypair maps a processor type to the utilization, where the value is an integer representing approximately the maximum number of bytes allocated at any one time. +- `occupancy::Dict{Type,Real}`: Indicates the maximum expected processor occupancy for this task. Each keypair maps a processor type to the utilization, where the value can be a real between 0 and 1 (the occupancy ratio, where 1 is full occupancy). By default, the scheduler assumes that this task has full occupancy. +- `checkpoint=nothing`: If not `nothing`, uses the provided function to save the result of the task to persistent storage, for later retrieval by `restore`. +- `restore=nothing`: If not `nothing`, uses the provided function to return the (cached) result of this task, were it to execute. If this returns a `Chunk`, this task will be skipped, and its result will be set to the `Chunk`. If `nothing` is returned, restoring is skipped, and the task will execute as usual. If this function throws an error, restoring will be skipped, and the error will be displayed. +- `storage::Union{Chunk,Nothing}=nothing`: If not `nothing`, references a `MemPool.StorageDevice` which will be passed to `MemPool.poolset` internally when constructing `Chunk`s (such as when constructing the return value). The device must support `MemPool.CPURAMResource`. When `nothing`, uses `MemPool.GLOBAL_DEVICE[]`. +- `storage_root_tag::Any=nothing`: If not `nothing`, specifies the MemPool storage leaf tag to associate with the task's result. This tag can be used by MemPool's storage devices to manipulate their behavior, such as the file name used to store data on disk." +- `storage_leaf_tag::Union{MemPool.Tag,Nothing}=nothing`: If not `nothing`, specifies the MemPool storage leaf tag to associate with the task's result. This tag can be used by MemPool's storage devices to manipulate their behavior, such as the file name used to store data on disk." +- `storage_retain::Union{Bool,Nothing}=nothing`: The value of `retain` to pass to `MemPool.poolset` when constructing the result `Chunk`. `nothing` defaults to `false`. +- `name::Union{String,Nothing}=nothing`: If not `nothing`, annotates the task with a name for logging purposes. +- `stream_input_buffer_amount::Union{Int,Nothing}=nothing`: (Streaming only) Specifies the amount of slots to allocate for the input buffer of the task. Defaults to 1. +- `stream_output_buffer_amount::Union{Int,Nothing}=nothing`: (Streaming only) Specifies the amount of slots to allocate for the output buffer of the task. Defaults to 1. +- `stream_buffer_type::Union{Type,Nothing}=nothing`: (Streaming only) Specifies the type of buffer to use for the input and output buffers of the task. 
Defaults to `Dagger.ProcessRingBuffer`. +- `stream_max_evals::Union{Int,Nothing}=nothing`: (Streaming only) Specifies the maximum number of times the task will be evaluated before returning a result. Defaults to infinite evaluations. +""" +Base.@kwdef mutable struct Options + propagates::Union{Vector{Symbol},Nothing} = nothing + + processor::Union{Processor,Nothing} = nothing + scope::Union{AbstractScope,Nothing} = nothing + compute_scope::Union{AbstractScope,Nothing} = scope + result_scope::Union{AbstractScope,Nothing} = nothing + single::Union{Int,Nothing} = nothing + proclist = nothing + + get_result::Union{Bool,Nothing} = nothing + meta::Union{Bool,Nothing} = nothing + + syncdeps::Union{Set{Any},Nothing} = nothing + + time_util::Union{Dict{Type,Any},Nothing} = nothing + alloc_util::Union{Dict{Type,UInt64},Nothing} = nothing + occupancy::Union{Dict{Type,Real},Nothing} = nothing + + checkpoint = nothing + restore = nothing + + storage::Union{Chunk,Nothing} = nothing + storage_root_tag = nothing + storage_leaf_tag::Union{MemPool.Tag,Nothing} = nothing + storage_retain::Union{Bool,Nothing} = nothing + + name::Union{String,Nothing} = nothing + + stream_input_buffer_amount::Union{Int,Nothing} = nothing + stream_output_buffer_amount::Union{Int,Nothing} = nothing + stream_buffer_type::Union{Type, Nothing} = nothing + stream_max_evals::Union{Int,Nothing} = nothing +end +Options(::Nothing) = Options() +function Options(old_options::NamedTuple) + new_options = Options() + options_merge!(new_options, old_options) + return new_options +end +function Base.copy(old_options::Options) + new_options = Options() + options_merge!(new_options, old_options) + return new_options +end +# Merge b -> a, where b takes precedence +function options_merge!(options::Options, source; override=true) + _options_merge!(options, source, override) + return options +end +_has_option(options::Union{Options,NamedTuple}, field) = hasproperty(options, field) +_get_option(options::Union{Options,NamedTuple}, field) = getproperty(options, field) +_set_option!(options::Union{Options,NamedTuple}, field, value) = setproperty!(options, field, value) +_has_option(options::Base.Pairs, field) = haskey(options, field) +_get_option(options::Base.Pairs, field) = options[field] +_set_option!(options::Base.Pairs, field, value) = error("Cannot set option in Base.Pairs") +@generated function _options_merge!(options, source, override) + ex = Expr(:block) + for field in fieldnames(Options) + push!(ex.args, quote + if _has_option(source, $(QuoteNode(field))) && _get_option(source, $(QuoteNode(field))) !== nothing + if override || _get_option(options, $(QuoteNode(field))) === nothing + _set_option!(options, + $(QuoteNode(field)), + _get_option(source, $(QuoteNode(field)))) + end + end + end) + end + return ex +end + +""" + populate_defaults!(opts::Options, sig::Vector{DataType}) -> Options + +Returns a `Options` with default values filled in for a function call with +signature `sig`, if the option was previously unspecified in `opts`. 
+""" +function populate_defaults!(opts::Options, sig) + maybe_default!(opts, Val{:propagates}(), sig) + maybe_default!(opts, Val{:processor}(), sig) + maybe_default!(opts, Val{:compute_scope}(), sig) + maybe_default!(opts, Val{:result_scope}(), sig) + maybe_default!(opts, Val{:single}(), sig) + maybe_default!(opts, Val{:proclist}(), sig) + maybe_default!(opts, Val{:get_result}(), sig) + maybe_default!(opts, Val{:meta}(), sig) + maybe_default!(opts, Val{:syncdeps}(), sig) + maybe_default!(opts, Val{:time_util}(), sig) + maybe_default!(opts, Val{:alloc_util}(), sig) + maybe_default!(opts, Val{:occupancy}(), sig) + maybe_default!(opts, Val{:checkpoint}(), sig) + maybe_default!(opts, Val{:restore}(), sig) + maybe_default!(opts, Val{:storage}(), sig) + maybe_default!(opts, Val{:storage_root_tag}(), sig) + maybe_default!(opts, Val{:storage_leaf_tag}(), sig) + maybe_default!(opts, Val{:storage_retain}(), sig) + maybe_default!(opts, Val{:name}(), sig) + maybe_default!(opts, Val{:stream_input_buffer_amount}(), sig) + maybe_default!(opts, Val{:stream_output_buffer_amount}(), sig) + maybe_default!(opts, Val{:stream_buffer_type}(), sig) + maybe_default!(opts, Val{:stream_max_evals}(), sig) + return opts +end +function maybe_default!(opts::Options, ::Val{opt}, sig::Signature) where opt + if getfield(opts, opt) === nothing + default_opt = get!(SIGNATURE_DEFAULT_CACHE[], (sig.hash_nokw, opt)) do + Dagger.default_option(Val{opt}(), sig.sig_nokw...) + end + setfield!(opts, opt, default_opt) + end +end + +struct BasicLFUCache{K,V} + cache::Dict{K,V} + freq::Dict{K,Int} + max_size::Int + + BasicLFUCache{K,V}(max_size::Int) where {K,V} = new(Dict{K,V}(), Dict{K,Int}(), max_size) +end +function Base.get!(f, cache::BasicLFUCache{K,V}, key::K) where {K,V} + if haskey(cache.cache, key) + cache.freq[key] += 1 + return cache.cache[key] + end + val = f()::V + cache.cache[key] = val + cache.freq[key] = 1 + if length(cache.cache) > cache.max_size + # Find the least frequently used key + _, lfu_key::K = findmin(cache.freq) + delete!(cache.cache, lfu_key) + delete!(cache.freq, lfu_key) + end + return val +end + +const SIGNATURE_DEFAULT_CACHE = TaskLocalValue{BasicLFUCache{Tuple{UInt,Symbol},Any}}(()->BasicLFUCache{Tuple{UInt,Symbol},Any}(256)) + +# SchedulerOptions integration + +function Dagger.options_merge!(topts::Options, sopts::SchedulerOptions) + function field_merge!(field) + if getfield(topts, field) === nothing && getfield(sopts, field) !== nothing + setfield!(topts, field, getfield(sopts, field)) + end + end + field_merge!(:single) + field_merge!(:proclist) + return topts +end +function Options(sopts::SchedulerOptions) + new_options = Options() + Dagger.options_merge!(new_options, sopts) + return new_options +end + # Scoped Options const options_context = ScopedValue{NamedTuple}(NamedTuple()) @@ -6,11 +195,12 @@ const options_context = ScopedValue{NamedTuple}(NamedTuple()) with_options(f, options::NamedTuple) -> Any with_options(f; options...) -> Any -Sets one or more options to the given values, executes `f()`, resets the +Sets one or more scoped options to the given values, executes `f()`, resets the options to their previous values, and returns the result of `f()`. This is the -recommended way to set options, as it only affects tasks spawned within its -scope. Note that setting an option here will propagate its value across Julia -or Dagger tasks spawned by `f()` or its callees (i.e. the options propagate). +recommended way to set scoped options, as it only affects tasks spawned within +its scope. 
Note that setting an option here will propagate its value across +Julia or Dagger tasks spawned by `f()` or its callees (i.e. the options +propagate). """ function with_options(f, options::NamedTuple) prev_options = options_context[] @@ -30,11 +220,13 @@ end get_options(key::Symbol, default) -> Any get_options(key::Symbol) -> Any -Returns the value of the option named `key`. If `option` does not have a value set, then an error will be thrown, unless `default` is set, in which case it will be returned instead of erroring. +Returns the value of the scoped option named `key`. If `option` does not have a +value set, then an error will be thrown, unless `default` is set, in which case +it will be returned instead of erroring. get_options() -> NamedTuple -Returns a `NamedTuple` of all option key-value pairs. +Returns a `NamedTuple` of all scoped option key-value pairs. """ get_options() = options_context[] get_options(key::Symbol) = getproperty(get_options(), key) diff --git a/src/precompile.jl b/src/precompile.jl index 21a78d99f..874e70de5 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -23,7 +23,7 @@ # Halt scheduler notify(state.halt) - put!(state.chan, (1, nothing, nothing, (Sch.SchedulerHaltedException(), nothing))) + put!(state.chan, Sch.TaskResult(1, OSProc(), 0, Sch.SchedulerHaltedException(), nothing)) state = nothing # Wait for halt diff --git a/src/queue.jl b/src/queue.jl index 71789c6fb..b0b0ea45d 100644 --- a/src/queue.jl +++ b/src/queue.jl @@ -1,7 +1,6 @@ mutable struct DTaskSpec - f - args::Vector{Pair{Union{Symbol,Nothing},Any}} - options::NamedTuple + fargs::Vector{Argument} + options::Options end abstract type AbstractTaskQueue end @@ -41,17 +40,15 @@ end struct InOrderTaskQueue <: AbstractTaskQueue upper_queue::AbstractTaskQueue prev_tasks::Set{DTask} - InOrderTaskQueue(upper_queue) = new(upper_queue, - Set{DTask}()) + InOrderTaskQueue(upper_queue) = new(upper_queue, Set{DTask}()) end function _add_prev_deps!(queue::InOrderTaskQueue, spec::DTaskSpec) # Add previously-enqueued task(s) to this task's syncdeps opts = spec.options - syncdeps = get(Set{Any}, opts, :syncdeps) + syncdeps = opts.syncdeps = @something(opts.syncdeps, Set()) for task in queue.prev_tasks push!(syncdeps, task) end - spec.options = merge(opts, (;syncdeps,)) end function enqueue!(queue::InOrderTaskQueue, spec::Pair{DTaskSpec,DTask}) if length(queue.prev_tasks) > 0 diff --git a/src/sch/Sch.jl b/src/sch/Sch.jl index 0335fe3e0..f7ee904d5 100644 --- a/src/sch/Sch.jl +++ b/src/sch/Sch.jl @@ -10,16 +10,22 @@ end import MemPool import MemPool: DRef, StorageResource import MemPool: poolset, storage_capacity, storage_utilized -import Random: randperm +import Random: randperm, randperm! 
import Base: @invokelatest import ..Dagger -import ..Dagger: Context, Processor, Thunk, WeakThunk, ThunkFuture, DTaskFailedException, Chunk, WeakChunk, OSProc, AnyScope, DefaultScope, InvalidScope, LockedObject -import ..Dagger: order, dependents, noffspring, istask, inputs, unwrap_weak_checked, affinity, tochunk, timespan_start, timespan_finish, procs, move, chunktype, processor, get_processors, get_parent, execute!, rmprocs!, task_processor, constrain, cputhreadtime -import ..Dagger: @dagdebug, @safe_lock_spin1 +import ..Dagger: Context, Processor, SchedulerOptions, Options, Thunk, WeakThunk, ThunkFuture, DTaskFailedException, Chunk, WeakChunk, OSProc, AnyScope, DefaultScope, InvalidScope, LockedObject, Argument, Signature +import ..Dagger: order, dependents, noffspring, istask, inputs, unwrap_weak_checked, wrap_weak, affinity, tochunk, timespan_start, timespan_finish, procs, move, chunktype, default_enabled, processor, get_processors, get_parent, execute!, rmprocs!, task_processor, constrain, cputhreadtime, maybe_take_or_alloc! +import ..Dagger: @dagdebug, @safe_lock_spin1, @maybelog, @take_or_alloc! import DataStructures: PriorityQueue, enqueue!, dequeue_pair!, peek -import ..Dagger +import ..Dagger: ReusableCache, ReusableLinkedList, ReusableDict +import ..Dagger: @reusable, @reusable_dict, @reusable_vector, @reusable_tasks, @reuse_scope, @reuse_defer_cleanup + +import TimespanLogging + +import TaskLocalValues: TaskLocalValue +import ScopedValues: @with const OneToMany = Dict{Thunk, Set{Thunk}} @@ -27,26 +33,15 @@ include("util.jl") include("fault-handler.jl") include("dynamic.jl") -mutable struct ProcessorCacheEntry - gproc::OSProc +struct TaskResult + pid::Int proc::Processor - next::ProcessorCacheEntry - - ProcessorCacheEntry(gproc::OSProc, proc::Processor) = new(gproc, proc) -end -Base.isequal(p1::ProcessorCacheEntry, p2::ProcessorCacheEntry) = - p1.proc === p2.proc -function Base.show(io::IO, entry::ProcessorCacheEntry) - entries = 1 - next = entry.next - while next !== entry - entries += 1 - next = next.next - end - print(io, "ProcessorCacheEntry(pid $(entry.gproc.pid), $(entry.proc), $entries entries)") + thunk_id::Int + result::Any + metadata::Union{NamedTuple,Nothing} end -const Signature = Vector{Any} +const AnyTaskResult = Union{RescheduleSignal, TaskResult} """ ComputeState @@ -64,12 +59,12 @@ Fields: - `running_on::Dict{Thunk,OSProc}` - Map from `Thunk` to the OS process executing it - `thunk_dict::Dict{Int, WeakThunk}` - Maps from thunk IDs to a `Thunk` - `node_order::Any` - Function that returns the order of a thunk +- `equiv_chunks::WeakKeyDict{DRef,Chunk}` - Cache mapping from `DRef` to a `Chunk` which contains it - `worker_time_pressure::Dict{Int,Dict{Processor,UInt64}}` - Maps from worker ID to processor pressure - `worker_storage_pressure::Dict{Int,Dict{Union{StorageResource,Nothing},UInt64}}` - Maps from worker ID to storage resource pressure - `worker_storage_capacity::Dict{Int,Dict{Union{StorageResource,Nothing},UInt64}}` - Maps from worker ID to storage resource capacity - `worker_loadavg::Dict{Int,NTuple{3,Float64}}` - Worker load average - `worker_chans::Dict{Int, Tuple{RemoteChannel,RemoteChannel}}` - Communication channels between the scheduler and each worker -- `procs_cache_list::Base.RefValue{Union{ProcessorCacheEntry,Nothing}}` - Cached linked list of processors ready to be used - `signature_time_cost::Dict{Signature,UInt64}` - Cache of estimated CPU time (in nanoseconds) required to compute calls with the given signature - 
`signature_alloc_cost::Dict{Signature,UInt64}` - Cache of estimated CPU RAM (in bytes) required to compute calls with the given signature - `transfer_rate::Ref{UInt64}` - Estimate of the network transfer rate in bytes per second @@ -78,7 +73,7 @@ Fields: - `futures::Dict{Thunk, Vector{ThunkFuture}}` - Futures registered for waiting on the result of a thunk. - `errored::WeakKeyDict{Thunk,Bool}` - Indicates if a thunk's result is an error. - `thunks_to_delete::Set{Thunk}` - The list of `Thunk`s ready to be deleted upon completion. -- `chan::RemoteChannel{Channel{Any}}` - Channel for receiving completed thunks. +- `chan::RemoteChannel{Channel{AnyTaskResult}}` - Channel for receiving completed thunks. """ struct ComputeState uid::UInt64 @@ -91,12 +86,12 @@ struct ComputeState running_on::Dict{Thunk,OSProc} thunk_dict::Dict{Int, WeakThunk} node_order::Any + equiv_chunks::WeakKeyDict{DRef,Chunk} worker_time_pressure::Dict{Int,Dict{Processor,UInt64}} worker_storage_pressure::Dict{Int,Dict{Union{StorageResource,Nothing},UInt64}} worker_storage_capacity::Dict{Int,Dict{Union{StorageResource,Nothing},UInt64}} worker_loadavg::Dict{Int,NTuple{3,Float64}} worker_chans::Dict{Int, Tuple{RemoteChannel,RemoteChannel}} - procs_cache_list::Base.RefValue{Union{ProcessorCacheEntry,Nothing}} signature_time_cost::Dict{Signature,UInt64} signature_alloc_cost::Dict{Signature,UInt64} transfer_rate::Ref{UInt64} @@ -105,7 +100,7 @@ struct ComputeState futures::Dict{Thunk, Vector{ThunkFuture}} errored::WeakKeyDict{Thunk,Bool} thunks_to_delete::Set{Thunk} - chan::RemoteChannel{Channel{Any}} + chan::RemoteChannel{Channel{AnyTaskResult}} end const UID_COUNTER = Threads.Atomic{UInt64}(1) @@ -121,12 +116,12 @@ function start_state(deps::Dict, node_order, chan) Dict{Thunk,OSProc}(), Dict{Int, WeakThunk}(), node_order, + WeakKeyDict{DRef,Chunk}(), Dict{Int,Dict{Processor,UInt64}}(), Dict{Int,Dict{Union{StorageResource,Nothing},UInt64}}(), Dict{Int,Dict{Union{StorageResource,Nothing},UInt64}}(), Dict{Int,NTuple{3,Float64}}(), Dict{Int, Tuple{RemoteChannel,RemoteChannel}}(), - Ref{Union{ProcessorCacheEntry,Nothing}}(nothing), Dict{Signature,UInt64}(), Dict{Signature,UInt64}(), Ref{UInt64}(1_000_000), @@ -139,7 +134,7 @@ function start_state(deps::Dict, node_order, chan) for k in sort(collect(keys(deps)), by=node_order) if istask(k) - waiting = Set{Thunk}(Iterators.filter(istask, k.syncdeps)) + waiting = Set{Thunk}(Iterators.filter(istask, k.options.syncdeps)) if isempty(waiting) push!(state.ready, k) else @@ -151,170 +146,6 @@ function start_state(deps::Dict, node_order, chan) state end -""" - SchedulerOptions - -Stores DAG-global options to be passed to the Dagger.Sch scheduler. - -# Arguments -- `single::Int=0`: (Deprecated) Force all work onto worker with specified id. - `0` disables this option. -- `proclist=nothing`: (Deprecated) Force scheduler to use one or more - processors that are instances/subtypes of a contained type. Alternatively, a - function can be supplied, and the function will be called with a processor as - the sole argument and should return a `Bool` result to indicate whether or not - to use the given processor. `nothing` enables all default processors. -- `allow_errors::Bool=true`: Allow thunks to error without affecting - non-dependent thunks. -- `checkpoint=nothing`: If not `nothing`, uses the provided function to save - the final result of the current scheduler invocation to persistent storage, for - later retrieval by `restore`. 
-- `restore=nothing`: If not `nothing`, uses the provided function to return the - (cached) final result of the current scheduler invocation, were it to execute. - If this returns a `Chunk`, all thunks will be skipped, and the `Chunk` will be - returned. If `nothing` is returned, restoring is skipped, and the scheduler - will execute as usual. If this function throws an error, restoring will be - skipped, and the error will be displayed. -""" -Base.@kwdef struct SchedulerOptions - single::Union{Int,Nothing} = nothing - proclist = nothing - allow_errors::Union{Bool,Nothing} = false - checkpoint = nothing - restore = nothing -end - -""" - ThunkOptions - -Stores Thunk-local options to be passed to the Dagger.Sch scheduler. - -# Arguments -- `single::Int=0`: (Deprecated) Force thunk onto worker with specified id. `0` - disables this option. -- `proclist=nothing`: (Deprecated) Force thunk to use one or more processors - that are instances/subtypes of a contained type. Alternatively, a function can - be supplied, and the function will be called with a processor as the sole - argument and should return a `Bool` result to indicate whether or not to use - the given processor. `nothing` enables all default processors. -- `time_util::Dict{Type,Any}`: Indicates the maximum expected time utilization - for this thunk. Each keypair maps a processor type to the utilization, where - the value can be a real (approximately the number of nanoseconds taken), or - `MaxUtilization()` (utilizes all processors of this type). By default, the - scheduler assumes that this thunk only uses one processor. -- `alloc_util::Dict{Type,UInt64}`: Indicates the maximum expected memory - utilization for this thunk. Each keypair maps a processor type to the - utilization, where the value is an integer representing approximately the - maximum number of bytes allocated at any one time. -- `occupancy::Dict{Type,Real}`: Indicates the maximum expected processor - occupancy for this thunk. Each keypair maps a processor type to the - utilization, where the value can be a real between 0 and 1 (the occupancy - ratio, where 1 is full occupancy). By default, the scheduler assumes that this - thunk has full occupancy. -- `allow_errors::Bool=true`: Allow this thunk to error without affecting - non-dependent thunks. -- `checkpoint=nothing`: If not `nothing`, uses the provided function to save - the result of the thunk to persistent storage, for later retrieval by - `restore`. -- `restore=nothing`: If not `nothing`, uses the provided function to return the - (cached) result of this thunk, were it to execute. If this returns a `Chunk`, - this thunk will be skipped, and its result will be set to the `Chunk`. If - `nothing` is returned, restoring is skipped, and the thunk will execute as - usual. If this function throws an error, restoring will be skipped, and the - error will be displayed. -- `storage::Union{Chunk,Nothing}=nothing`: If not `nothing`, references a - `MemPool.StorageDevice` which will be passed to `MemPool.poolset` internally - when constructing `Chunk`s (such as when constructing the return value). The - device must support `MemPool.CPURAMResource`. When `nothing`, uses - `MemPool.GLOBAL_DEVICE[]`. -- `storage_root_tag::Any=nothing`: If not `nothing`, - specifies the MemPool storage leaf tag to associate with the thunk's result. - This tag can be used by MemPool's storage devices to manipulate their behavior, - such as the file name used to store data on disk." 
-- `storage_leaf_tag::MemPool.Tag,Nothing}=nothing`: If not `nothing`, - specifies the MemPool storage leaf tag to associate with the thunk's result. - This tag can be used by MemPool's storage devices to manipulate their behavior, - such as the file name used to store data on disk." -- `storage_retain::Bool=false`: The value of `retain` to pass to - `MemPool.poolset` when constructing the result `Chunk`. -""" -Base.@kwdef struct ThunkOptions - single::Union{Int,Nothing} = nothing - proclist = nothing - time_util::Union{Dict{Type,Any},Nothing} = nothing - alloc_util::Union{Dict{Type,UInt64},Nothing} = nothing - occupancy::Union{Dict{Type,Real},Nothing} = nothing - allow_errors::Union{Bool,Nothing} = nothing - checkpoint = nothing - restore = nothing - storage::Union{Chunk,Nothing} = nothing - storage_root_tag = nothing - storage_leaf_tag::Union{MemPool.Tag,Nothing} = nothing - storage_retain::Bool = false -end - -""" - Base.merge(sopts::SchedulerOptions, topts::ThunkOptions) -> ThunkOptions - -Combine `SchedulerOptions` and `ThunkOptions` into a new `ThunkOptions`. -""" -function Base.merge(sopts::SchedulerOptions, topts::ThunkOptions) - select_option = (sopt, topt) -> isnothing(topt) ? sopt : topt - - single = select_option(sopts.single, topts.single) - allow_errors = select_option(sopts.allow_errors, topts.allow_errors) - proclist = select_option(sopts.proclist, topts.proclist) - ThunkOptions(single, - proclist, - topts.time_util, - topts.alloc_util, - topts.occupancy, - allow_errors, - topts.checkpoint, - topts.restore, - topts.storage, - topts.storage_root_tag, - topts.storage_leaf_tag, - topts.storage_retain) -end -Base.merge(sopts::SchedulerOptions, ::Nothing) = - ThunkOptions(sopts.single, - sopts.proclist, - nothing, - nothing, - sopts.allow_errors) -""" - populate_defaults(opts::ThunkOptions, Tf, Targs) -> ThunkOptions - -Returns a `ThunkOptions` with default values filled in for a function of type -`Tf` with argument types `Targs`, if the option was previously unspecified in -`opts`. -""" -function populate_defaults(opts::ThunkOptions, Tf, Targs) - function maybe_default(opt::Symbol) - old_opt = getproperty(opts, opt) - if old_opt !== nothing - return old_opt - else - return Dagger.default_option(Val(opt), Tf, Targs...) 
- end - end - ThunkOptions( - maybe_default(:single), - maybe_default(:proclist), - maybe_default(:time_util), - maybe_default(:alloc_util), - maybe_default(:occupancy), - maybe_default(:allow_errors), - maybe_default(:checkpoint), - maybe_default(:restore), - maybe_default(:storage), - maybe_default(:storage_root_tag), - maybe_default(:storage_leaf_tag), - maybe_default(:storage_retain), - ) -end - # Eager scheduling include("eager.jl") @@ -323,7 +154,7 @@ const WORKER_MONITOR_TASKS = Dict{Int,Task}() const WORKER_MONITOR_CHANS = Dict{Int,Dict{UInt64,RemoteChannel}}() function init_proc(state, p, log_sink) ctx = Context(Int[]; log_sink) - timespan_start(ctx, :init_proc, (;uid=state.uid, worker=p.pid), nothing) + @maybelog ctx timespan_start(ctx, :init_proc, (;uid=state.uid, worker=p.pid), nothing) # Initialize pressure and capacity gproc = OSProc(p.pid) lock(state.lock) do @@ -360,7 +191,7 @@ function init_proc(state, p, log_sink) d = WORKER_MONITOR_CHANS[wid] for uid in keys(d) try - put!(d[uid], (wid, OSProc(wid), nothing, (ProcessExitedException(wid), nothing))) + put!(d[uid], TaskResult(wid, OSProc(wid), 0, ProcessExitedException(wid), nothing)) catch end end @@ -388,7 +219,7 @@ function init_proc(state, p, log_sink) # Setup dynamic listener dynamic_listener!(ctx, state, p.pid) - timespan_finish(ctx, :init_proc, (;uid=state.uid, worker=p.pid), nothing) + @maybelog ctx timespan_finish(ctx, :init_proc, (;uid=state.uid, worker=p.pid), nothing) end function _cleanup_proc(uid, log_sink) empty!(CHUNK_CACHE) # FIXME: Should be keyed on uid! @@ -404,7 +235,7 @@ end function cleanup_proc(state, p, log_sink) ctx = Context(Int[]; log_sink) wid = p.pid - timespan_start(ctx, :cleanup_proc, (;uid=state.uid, worker=wid), nothing) + @maybelog ctx timespan_start(ctx, :cleanup_proc, (;uid=state.uid, worker=wid), nothing) lock(WORKER_MONITOR_LOCK) do if haskey(WORKER_MONITOR_CHANS, wid) delete!(WORKER_MONITOR_CHANS[wid], state.uid) @@ -424,7 +255,7 @@ function cleanup_proc(state, p, log_sink) end end - timespan_finish(ctx, :cleanup_proc, (;uid=state.uid, worker=wid), nothing) + @maybelog ctx timespan_finish(ctx, :cleanup_proc, (;uid=state.uid, worker=wid), nothing) end "Process-local condition variable (and lock) indicating task completion." @@ -445,11 +276,7 @@ Indicates a thunk that uses all processors of a given type. 
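As a hypothetical illustration of how this marker might be supplied through the `time_util` option documented earlier (`my_threaded_kernel` and `data` are placeholders, not names from this patch):

```julia
# Hypothetical: declare that this task is expected to saturate every ThreadProc
# of its assigned worker, rather than occupying a single processor.
Dagger.@spawn time_util=Dict{Type,Any}(Dagger.ThreadProc => Dagger.Sch.MaxUtilization()) my_threaded_kernel(data)
```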
""" struct MaxUtilization end -function compute_dag(ctx, d::Thunk; options=SchedulerOptions()) - if options === nothing - options = SchedulerOptions() - end - ctx.options = options +function compute_dag(ctx::Context, d::Thunk, options=SchedulerOptions()) if options.restore !== nothing try result = options.restore() @@ -463,7 +290,7 @@ function compute_dag(ctx, d::Thunk; options=SchedulerOptions()) end end - chan = RemoteChannel(()->Channel(typemax(Int))) + chan = RemoteChannel(()->Channel{AnyTaskResult}(typemax(Int))) deps = dependents(d) ord = order(d, noffspring(deps)) @@ -472,24 +299,24 @@ function compute_dag(ctx, d::Thunk; options=SchedulerOptions()) master = OSProc(myid()) - timespan_start(ctx, :scheduler_init, (;uid=state.uid), master) + @maybelog ctx timespan_start(ctx, :scheduler_init, (;uid=state.uid), master) try scheduler_init(ctx, state, d, options, deps) finally - timespan_finish(ctx, :scheduler_init, (;uid=state.uid), master) + @maybelog ctx timespan_finish(ctx, :scheduler_init, (;uid=state.uid), master) end value, errored = try scheduler_run(ctx, state, d, options) finally # Always try to tear down the scheduler - timespan_start(ctx, :scheduler_exit, (;uid=state.uid), master) + @maybelog ctx timespan_start(ctx, :scheduler_exit, (;uid=state.uid), master) try scheduler_exit(ctx, state, options) catch err @error "Error when tearing down scheduler" exception=(err,catch_backtrace()) finally - timespan_finish(ctx, :scheduler_exit, (;uid=state.uid), master) + @maybelog ctx timespan_finish(ctx, :scheduler_exit, (;uid=state.uid), master) end end @@ -499,7 +326,7 @@ function compute_dag(ctx, d::Thunk; options=SchedulerOptions()) return value end -function scheduler_init(ctx, state::ComputeState, d::Thunk, options, deps) +function scheduler_init(ctx, state::ComputeState, d::Thunk, options::SchedulerOptions, deps) # setup thunk_dict mappings for node in filter(istask, keys(deps)) state.thunk_dict[node.id] = WeakThunk(node) @@ -509,13 +336,13 @@ function scheduler_init(ctx, state::ComputeState, d::Thunk, options, deps) end # Initialize workers - @sync for p in procs_to_use(ctx) + @sync for p in procs_to_use(ctx, options) Threads.@spawn begin try init_proc(state, p, ctx.log_sink) catch err @error "Error initializing worker $p" exception=(err,catch_backtrace()) - remove_dead_proc!(ctx, state, p) + remove_dead_proc!(ctx, state, p, options) end end end @@ -528,14 +355,14 @@ function scheduler_init(ctx, state::ComputeState, d::Thunk, options, deps) # Listen for new workers Threads.@spawn begin try - monitor_procs_changed!(ctx, state) + monitor_procs_changed!(ctx, state, options) catch err @error "Error assigning workers" exception=(err,catch_backtrace()) end end end -function scheduler_run(ctx, state::ComputeState, d::Thunk, options) +function scheduler_run(ctx, state::ComputeState, d::Thunk, options::SchedulerOptions) @dagdebug nothing :global "Initializing scheduler" uid=state.uid safepoint(state) @@ -544,20 +371,26 @@ function scheduler_run(ctx, state::ComputeState, d::Thunk, options) while !isempty(state.ready) || !isempty(state.running) if !isempty(state.ready) # Nothing running, so schedule up to N thunks, 1 per N workers - schedule!(ctx, state) + @invokelatest schedule!(ctx, state, options) end - check_integrity(ctx) + check_workers_available(ctx, options) isempty(state.running) && continue - timespan_start(ctx, :take, (;uid=state.uid), nothing) + @maybelog ctx timespan_start(ctx, :take, (;uid=state.uid), nothing) @dagdebug nothing :take "Waiting for results" - chan_value = 
take!(state.chan) # get result of completed thunk - timespan_finish(ctx, :take, (;uid=state.uid), nothing) - if chan_value isa RescheduleSignal + tresult = take!(state.chan) # get result of completed thunk + @maybelog ctx timespan_finish(ctx, :take, (;uid=state.uid), nothing) + if tresult isa RescheduleSignal continue end - pid, proc, thunk_id, (res, metadata) = chan_value + + tresult::TaskResult + pid = tresult.pid + proc = tresult.proc + thunk_id = tresult.thunk_id + res = tresult.result + @dagdebug thunk_id :take "Got finished task" gproc = OSProc(pid) safepoint(state) @@ -568,30 +401,30 @@ function scheduler_run(ctx, state::ComputeState, d::Thunk, options) @warn "Worker $(pid) died, rescheduling work" # Remove dead worker from procs list - timespan_start(ctx, :remove_procs, (;uid=state.uid, worker=pid), nothing) - remove_dead_proc!(ctx, state, gproc) - timespan_finish(ctx, :remove_procs, (;uid=state.uid, worker=pid), nothing) + @maybelog ctx timespan_start(ctx, :remove_procs, (;uid=state.uid, worker=pid), nothing) + remove_dead_proc!(ctx, state, gproc, options) + @maybelog ctx timespan_finish(ctx, :remove_procs, (;uid=state.uid, worker=pid), nothing) - timespan_start(ctx, :handle_fault, (;uid=state.uid, worker=pid), nothing) + @maybelog ctx timespan_start(ctx, :handle_fault, (;uid=state.uid, worker=pid), nothing) handle_fault(ctx, state, gproc) - timespan_finish(ctx, :handle_fault, (;uid=state.uid, worker=pid), nothing) + @maybelog ctx timespan_finish(ctx, :handle_fault, (;uid=state.uid, worker=pid), nothing) return # effectively `continue` else - if something(ctx.options.allow_errors, false) || - something(unwrap_weak_checked(state.thunk_dict[thunk_id]).options.allow_errors, false) + if something(options.allow_errors, false) thunk_failed = true else throw(res) end end end - node = unwrap_weak_checked(state.thunk_dict[thunk_id]) + node = unwrap_weak_checked(state.thunk_dict[thunk_id])::Thunk + metadata = tresult.metadata if metadata !== nothing state.worker_time_pressure[pid][proc] = metadata.time_pressure #to_storage = fetch(node.options.storage) #state.worker_storage_pressure[pid][to_storage] = metadata.storage_pressure #state.worker_storage_capacity[pid][to_storage] = metadata.storage_capacity - state.worker_loadavg[pid] = metadata.loadavg + #state.worker_loadavg[pid] = metadata.loadavg sig = signature(state, node) state.signature_time_cost[sig] = (metadata.threadtime + get(state.signature_time_cost, sig, 0)) ÷ 2 state.signature_alloc_cost[sig] = (metadata.gc_allocd + get(state.signature_alloc_cost, sig, 0)) ÷ 2 @@ -599,8 +432,12 @@ function scheduler_run(ctx, state::ComputeState, d::Thunk, options) state.transfer_rate[] = (state.transfer_rate[] + metadata.transfer_rate) ÷ 2 end end - state.cache[node] = res - state.errored[node] = thunk_failed + if res isa Chunk + if !haskey(state.equiv_chunks, res) + state.equiv_chunks[res.handle::DRef] = res + end + end + store_result!(state, node, res; error=thunk_failed) if node.options !== nothing && node.options.checkpoint !== nothing try @invokelatest node.options.checkpoint(node, res) @@ -609,18 +446,20 @@ function scheduler_run(ctx, state::ComputeState, d::Thunk, options) end end - timespan_start(ctx, :finish, (;uid=state.uid, thunk_id), (;thunk_id, result=res)) + @maybelog ctx timespan_start(ctx, :finish, (;uid=state.uid, thunk_id), (;thunk_id, result=res)) finish_task!(ctx, state, node, thunk_failed) - timespan_finish(ctx, :finish, (;uid=state.uid, thunk_id), (;thunk_id, result=res)) - - delete_unused_tasks!(state) + @maybelog ctx 
timespan_finish(ctx, :finish, (;uid=state.uid, thunk_id), (;thunk_id, result=res)) end + # Allow data to be GC'd + tresult = nothing + res = nothing + safepoint(state) end # Final value is ready - value = state.cache[d] + value = load_result(state, d) errored = get(state.errored, d, false) if !errored if options.checkpoint !== nothing @@ -633,10 +472,10 @@ function scheduler_run(ctx, state::ComputeState, d::Thunk, options) end return value, errored end -function scheduler_exit(ctx, state::ComputeState, options) +function scheduler_exit(ctx, state::ComputeState, options::SchedulerOptions) @dagdebug nothing :global "Tearing down scheduler" uid=state.uid - @sync for p in procs_to_use(ctx) + @sync for p in procs_to_use(ctx, options) Threads.@spawn cleanup_proc(state, p, ctx.log_sink) end @@ -661,7 +500,7 @@ function scheduler_exit(ctx, state::ComputeState, options) @dagdebug nothing :global "Tore down scheduler" uid=state.uid end -function procs_to_use(ctx, options=ctx.options) +function procs_to_use(ctx::Context, options::SchedulerOptions) return if options.single !== nothing @assert options.single in vcat(1, workers()) "Sch option `single` must specify an active worker ID." OSProc[OSProc(options.single)] @@ -670,7 +509,7 @@ function procs_to_use(ctx, options=ctx.options) end end -check_integrity(ctx) = @assert !isempty(procs_to_use(ctx)) "No suitable workers available in context." +check_workers_available(ctx, options) = @assert !isempty(procs_to_use(ctx, options)) "No remaining workers available." struct SchedulingException <: Exception reason::String @@ -681,76 +520,88 @@ end const CHUNK_CACHE = Dict{Chunk,Dict{Processor,Any}}() -function schedule!(ctx, state, procs=procs_to_use(ctx)) +struct ScheduleTaskLocation + gproc::OSProc + proc::Processor +end +struct ScheduleTaskSpec + task::Thunk + scope::Dagger.AbstractScope + est_time_util::UInt64 + est_alloc_util::UInt64 + est_occupancy::UInt32 +end +@reuse_scope function schedule!(ctx, state, sch_options, procs=procs_to_use(ctx, sch_options)) lock(state.lock) do safepoint(state) + @assert length(procs) > 0 # Remove processors that aren't yet initialized procs = filter(p -> haskey(state.worker_chans, Dagger.root_worker_id(p)), procs) - populate_processor_cache_list!(state, procs) - # Schedule tasks - to_fire = Dict{Tuple{OSProc,<:Processor},Vector{Tuple{Thunk,<:Any,<:Any,UInt64,UInt32}}}() - failed_scheduling = Thunk[] - + to_fire = @reusable_dict :schedule!_to_fire ScheduleTaskLocation Vector{ScheduleTaskSpec} ScheduleTaskLocation(OSProc(), OSProc()) ScheduleTaskSpec[] 1024 + to_fire_cleanup = @reuse_defer_cleanup empty!(to_fire) + failed_scheduling = @reusable_vector :schedule!_failed_scheduling Union{Thunk,Nothing} nothing 32 + failed_scheduling_cleanup = @reuse_defer_cleanup empty!(failed_scheduling) # Select a new task and get its options task = nothing @label pop_task if task !== nothing - timespan_finish(ctx, :schedule, (;uid=state.uid, thunk_id=task.id), (;thunk_id=task.id)) + @dagdebug task :schedule "Finished scheduling task" + @maybelog ctx timespan_finish(ctx, :schedule, (;uid=state.uid, thunk_id=task.id), (;thunk_id=task.id)) end if isempty(state.ready) @goto fire_tasks end - task = pop!(state.ready) - timespan_start(ctx, :schedule, (;uid=state.uid, thunk_id=task.id), (;thunk_id=task.id)) - if haskey(state.cache, task) + task = popfirst!(state.ready) + @dagdebug task :schedule "Scheduling task" + @maybelog ctx timespan_start(ctx, :schedule, (;uid=state.uid, thunk_id=task.id), (;thunk_id=task.id)) + if has_result(state, task) if 
haskey(state.errored, task) # An error was eagerly propagated to this task + @dagdebug task :schedule "Task received upstream error, finishing" finish_failed!(state, task) else # This shouldn't have happened + @dagdebug task :schedule "Scheduling inconsistency: Task being scheduled is already cached!" iob = IOBuffer() println(iob, "Scheduling inconsistency: Task being scheduled is already cached!") println(iob, " Task: $(task.id)") - println(iob, " Cache Entry: $(typeof(state.cache[task]))") + println(iob, " Cache Entry: $(typeof(something(task.cache_ref)))") ex = SchedulingException(String(take!(iob))) - state.cache[task] = ex - state.errored[task] = true + store_result!(state, task, ex; error=true) end @goto pop_task end - opts = merge(ctx.options, task.options) + + # Load task inputs + collect_task_inputs!(state, task) + + # Calculate signature sig = signature(state, task) + # Merge scheduler options and populate defaults + options = task.options + Dagger.options_merge!(options, sch_options) + Dagger.populate_defaults!(options, sig) + # Calculate scope - scope = constrain(task.compute_scope, task.result_scope) + scope = constrain(@something(options.compute_scope, options.scope, DefaultScope()), + @something(options.result_scope, AnyScope())) if scope isa InvalidScope ex = SchedulingException("compute_scope and result_scope are not compatible: $(scope.x), $(scope.y)") - state.cache[task] = ex - state.errored[task] = true - set_failed!(state, task) + store_result!(state, task, ex; error=true) + finish_failed!(state, task) @goto pop_task end - if task.f isa Chunk - scope = constrain(scope, task.f.scope) - if scope isa InvalidScope - ex = SchedulingException("Current scope and function Chunk Scope are not compatible: $(scope.x), $(scope.y)") - state.cache[task] = ex - state.errored[task] = true - set_failed!(state, task) - @goto pop_task - end - end - - for (_,input) in task.inputs - input = unwrap_weak_checked(input) - chunk = if istask(input) - state.cache[input] - elseif input isa Chunk - input + for arg in task.inputs + value = unwrap_weak_checked(Dagger.value(arg)) + chunk = if istask(value) + load_result(state, task) + elseif value isa Chunk + value else nothing end @@ -758,135 +609,83 @@ function schedule!(ctx, state, procs=procs_to_use(ctx)) scope = constrain(scope, chunk.scope) if scope isa InvalidScope ex = SchedulingException("Current scope and argument Chunk scope are not compatible: $(scope.x), $(scope.y)") - state.cache[task] = ex - state.errored[task] = true - set_failed!(state, task) + store_result!(state, task, ex; error=true) + finish_failed!(state, task) @goto pop_task end end - fallback_threshold = 1024 # TODO: Parameterize this threshold - if length(procs) > fallback_threshold - @goto fallback - end - local_procs = unique(vcat([collect(Dagger.get_processors(gp)) for gp in procs]...)) - if length(local_procs) > fallback_threshold - @goto fallback + input_procs = @reusable_vector :schedule!_input_procs Processor OSProc() 32 + input_procs_cleanup = @reuse_defer_cleanup empty!(input_procs) + for proc in Dagger.compatible_processors(scope, procs) + if !(proc in input_procs) + push!(input_procs, proc) + end end - inputs = map(last, collect_task_inputs(state, task)) - opts = populate_defaults(opts, chunktype(task.f), map(chunktype, inputs)) - local_procs, costs = estimate_task_costs(state, local_procs, task, inputs) + sorted_procs = @reusable_vector :schedule!_sorted_procs Processor OSProc() 32 + sorted_procs_cleanup = @reuse_defer_cleanup empty!(sorted_procs) + 
resize!(sorted_procs, length(input_procs)) + costs = @reusable_dict :schedule!_costs Processor Float64 OSProc() 0.0 32 + costs_cleanup = @reuse_defer_cleanup empty!(costs) + estimate_task_costs!(sorted_procs, costs, state, input_procs, task; sig) + input_procs_cleanup() scheduled = false - # Move our corresponding ThreadProc to be the last considered - if length(local_procs) > 1 + # Move our corresponding ThreadProc to be the last considered, + # if the task is expected to run for longer than the time it takes to + # schedule it onto another worker (estimated at 1ms). + if length(sorted_procs) > 1 sch_threadproc = Dagger.ThreadProc(myid(), Threads.threadid()) - sch_thread_idx = findfirst(proc->proc==sch_threadproc, local_procs) - if sch_thread_idx !== nothing - deleteat!(local_procs, sch_thread_idx) - push!(local_procs, sch_threadproc) + sch_thread_idx = findfirst(proc->proc==sch_threadproc, sorted_procs) + if sch_thread_idx !== nothing && costs[sch_threadproc] > 1_000_000 # 1ms + deleteat!(sorted_procs, sch_thread_idx) + push!(sorted_procs, sch_threadproc) end end - for proc in local_procs + for proc in sorted_procs gproc = get_parent(proc) - can_use, scope = can_use_proc(state, task, gproc, proc, opts, scope) + can_use, scope = can_use_proc(state, task, gproc, proc, options, scope) if can_use has_cap, est_time_util, est_alloc_util, est_occupancy = - has_capacity(state, proc, gproc.pid, opts.time_util, opts.alloc_util, opts.occupancy, sig) + has_capacity(state, proc, gproc.pid, options.time_util, options.alloc_util, options.occupancy, sig) if has_cap # Schedule task onto proc # FIXME: est_time_util = est_time_util isa MaxUtilization ? cap : est_time_util - proc_tasks = get!(to_fire, (gproc, proc)) do - Vector{Tuple{Thunk,<:Any,<:Any,UInt64,UInt32}}() + proc_tasks = get!(to_fire, ScheduleTaskLocation(gproc, proc)) do + #=FIXME:REALLOC_VEC=# + Vector{ScheduleTaskSpec}() end - push!(proc_tasks, (task, scope, est_time_util, est_alloc_util, est_occupancy)) + push!(proc_tasks, ScheduleTaskSpec(task, scope, est_time_util, est_alloc_util, est_occupancy)) state.worker_time_pressure[gproc.pid][proc] = get(state.worker_time_pressure[gproc.pid], proc, 0) + est_time_util - @dagdebug task :schedule "Scheduling to $gproc -> $proc" + @dagdebug task :schedule "Scheduling to $gproc -> $proc (cost: $(costs[proc]), pressure: $(state.worker_time_pressure[gproc.pid][proc]))" + sorted_procs_cleanup() + costs_cleanup() @goto pop_task end end end - state.cache[task] = SchedulingException("No processors available, try widening scope") - state.errored[task] = true - set_failed!(state, task) - @goto pop_task - - # Fast fallback algorithm, used when the smarter cost model algorithm - # would be too expensive - @label fallback - selected_entry = nothing - entry = state.procs_cache_list[] - cap, extra_util = nothing, nothing - procs_found = false - # N.B. 
if we only have one processor, we need to select it now - can_use, scope = can_use_proc(state, task, entry.gproc, entry.proc, opts, scope) - if can_use - has_cap, est_time_util, est_alloc_util, est_occupancy = - has_capacity(state, entry.proc, entry.gproc.pid, opts.time_util, opts.alloc_util, opts.occupancy, sig) - if has_cap - selected_entry = entry - else - procs_found = true - entry = entry.next - end - else - entry = entry.next - end - while selected_entry === nothing - if entry === state.procs_cache_list[] - # Exhausted all procs - if procs_found - push!(failed_scheduling, task) - else - state.cache[task] = SchedulingException("No processors available, try widening scope") - state.errored[task] = true - set_failed!(state, task) - end - @goto pop_task - end - can_use, scope = can_use_proc(state, task, entry.gproc, entry.proc, opts, scope) - if can_use - has_cap, est_time_util, est_alloc_util, est_occupancy = - has_capacity(state, entry.proc, entry.gproc.pid, opts.time_util, opts.alloc_util, opts.occupancy, sig) - if has_cap - # Select this processor - selected_entry = entry - else - # We could have selected it otherwise - procs_found = true - entry = entry.next - end - else - # Try next processor - entry = entry.next - end - end - @assert selected_entry !== nothing - - # Schedule task onto proc - gproc, proc = entry.gproc, entry.proc - est_time_util = est_time_util isa MaxUtilization ? cap : est_time_util - proc_tasks = get!(to_fire, (gproc, proc)) do - Vector{Tuple{Thunk,<:Any,<:Any,UInt64,UInt32}}() - end - push!(proc_tasks, (task, scope, est_time_util, est_alloc_util, est_occupancy)) - - # Proceed to next entry to spread work - state.procs_cache_list[] = state.procs_cache_list[].next + ex = SchedulingException("No processors available, try widening scope") + store_result!(state, task, ex; error=true) + finish_failed!(state, task) + @dagdebug task :schedule "No processors available, skipping" + sorted_procs_cleanup() + costs_cleanup() @goto pop_task # Fire all newly-scheduled tasks @label fire_tasks - for gpp in keys(to_fire) - fire_tasks!(ctx, to_fire[gpp], gpp, state) + for (task_loc, task_spec) in to_fire + fire_tasks!(ctx, task_loc, task_spec, state) end + to_fire_cleanup() append!(state.ready, failed_scheduling) + failed_scheduling_cleanup() end end @@ -894,9 +693,9 @@ end Monitors for workers being added/removed to/from `ctx`, sets up or tears down per-worker state, and notifies the scheduler so that work can be reassigned. 
""" -function monitor_procs_changed!(ctx, state) +function monitor_procs_changed!(ctx, state, options) # Load current set of procs - old_ps = procs_to_use(ctx) + old_ps = procs_to_use(ctx, options) while !state.halt.set # Wait for the notification that procs have changed @@ -904,20 +703,17 @@ function monitor_procs_changed!(ctx, state) wait(ctx.proc_notify) end - timespan_start(ctx, :assign_procs, (;uid=state.uid), nothing) + @maybelog ctx timespan_start(ctx, :assign_procs, (;uid=state.uid), nothing) # Load new set of procs - new_ps = procs_to_use(ctx) + new_ps = procs_to_use(ctx, options) # Initialize new procs diffps = setdiff(new_ps, old_ps) for p in diffps init_proc(state, p, ctx.log_sink) - # Empty the processor cache list and force reschedule - lock(state.lock) do - state.procs_cache_list[] = nothing - end + # Force reschedule put!(state.chan, RescheduleSignal()) end @@ -925,19 +721,14 @@ function monitor_procs_changed!(ctx, state) diffps = setdiff(old_ps, new_ps) for p in diffps cleanup_proc(state, p, ctx.log_sink) - - # Empty the processor cache list - lock(state.lock) do - state.procs_cache_list[] = nothing - end end - timespan_finish(ctx, :assign_procs, (;uid=state.uid), nothing) + @maybelog ctx timespan_finish(ctx, :assign_procs, (;uid=state.uid), nothing) old_ps = new_ps end end -function remove_dead_proc!(ctx, state, proc, options=ctx.options) +function remove_dead_proc!(ctx, state, proc, options) @assert options.single !== proc.pid "Single worker failed, cannot continue." rmprocs!(ctx, [proc]) delete!(state.worker_time_pressure, proc.pid) @@ -945,49 +736,32 @@ function remove_dead_proc!(ctx, state, proc, options=ctx.options) delete!(state.worker_storage_capacity, proc.pid) delete!(state.worker_loadavg, proc.pid) delete!(state.worker_chans, proc.pid) - state.procs_cache_list[] = nothing end function finish_task!(ctx, state, node, thunk_failed) + @dagdebug node :finish "Finishing with $(thunk_failed ? 
"error" : "result")" pop!(state.running, node) delete!(state.running_on, node) if thunk_failed set_failed!(state, node) end - if node.cache - node.cache_ref = state.cache[node] - end schedule_dependents!(state, node, thunk_failed) fill_registered_futures!(state, node, thunk_failed) - to_evict = cleanup_syncdeps!(state, node) - if node.f isa Chunk - # FIXME: Check the graph for matching chunks - push!(to_evict, node.f) - end + #to_evict = cleanup_syncdeps!(state, node) + cleanup_syncdeps!(state, node) if haskey(state.waiting_data, node) && isempty(state.waiting_data[node]) delete!(state.waiting_data, node) end + if !haskey(state.waiting_data, node) + node.sch_accessible = false + delete_unused_task!(state, node) + end #evict_all_chunks!(ctx, to_evict) end -function delete_unused_tasks!(state) - to_delete = Thunk[] - for thunk in state.thunks_to_delete - if task_unused(state, thunk) - # Finished and nobody waiting on us, we can be deleted - push!(to_delete, thunk) - end - end - for thunk in to_delete - # Delete all cached data - task_delete!(state, thunk) - - pop!(state.thunks_to_delete, thunk) - end -end function delete_unused_task!(state, thunk) - if task_unused(state, thunk) + if has_result(state, thunk) && !thunk.eager_accessible && !thunk.sch_accessible # Will not be accessed further, delete all cached data task_delete!(state, thunk) return true @@ -995,18 +769,15 @@ function delete_unused_task!(state, thunk) return false end end -task_unused(state, thunk) = - haskey(state.cache, thunk) && !haskey(state.waiting_data, thunk) function task_delete!(state, thunk) - delete!(state.cache, thunk) - delete!(state.errored, thunk) + clear_result!(state, thunk) delete!(state.valid, thunk) delete!(state.thunk_dict, thunk.id) end -function evict_all_chunks!(ctx, to_evict) +function evict_all_chunks!(ctx, options, to_evict) if !isempty(to_evict) - @sync for w in map(p->p.pid, procs_to_use(ctx)) + @sync for w in map(p->p.pid, procs_to_use(ctx, options)) Threads.@spawn remote_do(evict_chunks!, w, ctx.log_sink, to_evict) end end @@ -1014,46 +785,47 @@ end function evict_chunks!(log_sink, chunks::Set{Chunk}) # Need worker id or else Context might use Processors which user does not want us to use. # In particular workers which have not yet run using Dagger will cause the call below to throw an exception - ctx = Context([myid()];log_sink) + ctx = Context([myid()]; log_sink) for chunk in chunks lock(TASK_SYNC) do - timespan_start(ctx, :evict, (;worker=myid()), (;data=chunk)) + @maybelog ctx timespan_start(ctx, :evict, (;worker=myid()), (;data=chunk)) haskey(CHUNK_CACHE, chunk) && delete!(CHUNK_CACHE, chunk) - timespan_finish(ctx, :evict, (;worker=myid()), (;data=chunk)) + @maybelog ctx timespan_finish(ctx, :evict, (;worker=myid()), (;data=chunk)) end end nothing end -fire_task!(ctx, thunk::Thunk, p, state; scope=AnyScope(), time_util=10^9, alloc_util=10^6, occupancy=typemax(UInt32)) = - fire_task!(ctx, (thunk, scope, time_util, alloc_util, occupancy), p, state) -fire_task!(ctx, (thunk, scope, time_util, alloc_util, occupancy)::Tuple{Thunk,<:Any}, p, state) = - fire_tasks!(ctx, [(thunk, scope, time_util, alloc_util, occupancy)], p, state) -function fire_tasks!(ctx, thunks::Vector{<:Tuple}, (gproc, proc), state) - to_send = [] - for (thunk, scope, time_util, alloc_util, occupancy) in thunks +"A serializable description of a `Thunk` to be executed." 
+struct TaskSpec + thunk_id::Int + est_time_util::UInt64 + est_alloc_util::UInt64 + est_occupancy::UInt32 + scope::Dagger.AbstractScope + Tf::Type + data::Vector{Argument} + options::Options + ctx_vars::NamedTuple + sch_handle::SchedulerHandle + sch_uid::UInt64 +end +Base.hash(task::TaskSpec, h::UInt) = hash(task.thunk_id, hash(TaskSpec, h)) + +@reuse_scope function fire_tasks!(ctx, task_loc::ScheduleTaskLocation, task_specs::Vector{ScheduleTaskSpec}, state) + gproc, proc = task_loc.gproc, task_loc.proc + to_send = @reusable_vector :fire_tasks!_to_send Union{TaskSpec,Nothing} nothing 1024 + to_send_cleanup = @reuse_defer_cleanup empty!(to_send) + for task_spec in task_specs + thunk = task_spec.task push!(state.running, thunk) state.running_on[thunk] = gproc - if thunk.cache && thunk.cache_ref !== nothing - # the result might be already cached - data = thunk.cache_ref - if data !== nothing - # cache hit - state.cache[thunk] = data - thunk_failed = get(state.errored, thunk, false) - finish_task!(ctx, state, thunk, thunk_failed) - continue - else - # cache miss - thunk.cache_ref = nothing - end - end - if thunk.options !== nothing && thunk.options.restore !== nothing + @assert !has_result(state, thunk) + if thunk.options.restore !== nothing try result = @invokelatest thunk.options.restore(thunk) if result isa Chunk - state.cache[thunk] = result - state.errored[thunk] = false + store_result!(state, thunk, result) finish_task!(ctx, state, thunk, false) continue elseif result !== nothing @@ -1064,58 +836,80 @@ function fire_tasks!(ctx, thunks::Vector{<:Tuple}, (gproc, proc), state) end end - ids = Int[0] - data = Any[thunk.f] - positions = Union{Symbol,Int}[0] - arg_ctr = 1 - for (idx, pos_x) in enumerate(thunk.inputs) - pos, x = pos_x - x = unwrap_weak_checked(x) - push!(ids, istask(x) ? x.id : -idx) - push!(data, istask(x) ? state.cache[x] : x) - if pos !== nothing - # Keyword arg - push!(positions, pos) - else - # Positional arg - push!(positions, arg_ctr) - arg_ctr += 1 - end + # Duplicate options and clear un-serializable fields + options = copy(thunk.options) + options.syncdeps = nothing + + # Unwrap any weak arguments + args = map(copy, thunk.inputs) + for arg in args + # TODO: Only for non-delayed: @assert Dagger.isweak(Dagger.value(arg)) "Non-weak argument: $(arg)" + arg.value = unwrap_weak_checked(Dagger.value(arg)) end - toptions = thunk.options !== nothing ? thunk.options : ThunkOptions() - options = merge(ctx.options, toptions) - propagated = get_propagated_options(thunk) + Tf = chunktype(first(args)) + @assert (options.single === nothing) || (gproc.pid == options.single) # TODO: Set `sch_handle.tid.ref` to the right `DRef` sch_handle = SchedulerHandle(ThunkID(thunk.id, nothing), state.worker_chans[gproc.pid]...) # TODO: De-dup common fields (log_sink, uid, etc.) - push!(to_send, Any[thunk.id, time_util, alloc_util, occupancy, - scope, chunktype(thunk.f), data, - thunk.get_result, thunk.persist, thunk.cache, thunk.meta, options, - propagated, ids, positions, - (log_sink=ctx.log_sink, profile=ctx.profile), - sch_handle, state.uid, thunk.result_scope]) + push!(to_send, TaskSpec( + thunk.id, + task_spec.est_time_util, task_spec.est_alloc_util, task_spec.est_occupancy, + task_spec.scope, Tf, args, options, + (log_sink=ctx.log_sink, profile=ctx.profile), + sch_handle, state.uid)) end - # N.B. We don't batch these because we might get a deserialization - # error due to something not being defined on the worker, and then we don't - # know which task failed. 
- tasks = Task[] - for ts in to_send - # TODO: errormonitor - task = Threads.@spawn begin - timespan_start(ctx, :fire, (;uid=state.uid, worker=gproc.pid), nothing) - try - remotecall_wait(do_tasks, gproc.pid, proc, state.chan, [ts]); - catch err - bt = catch_backtrace() - thunk_id = ts[1] - put!(state.chan, (gproc.pid, proc, thunk_id, (CapturedException(err, bt), nothing))) - finally - timespan_finish(ctx, :fire, (;uid=state.uid, worker=gproc.pid), nothing) + + if !isempty(to_send) + if Dagger.root_worker_id(gproc) == myid() + @reusable_tasks :fire_tasks!_task_cache 32 _->nothing "fire_tasks!" FireTaskSpec(proc, state.chan, to_send) + else + # N.B. We don't batch these because we might get a deserialization + # error due to something not being defined on the worker, and then we don't + # know which task failed. + for task_spec in to_send + @reusable_tasks :fire_tasks!_task_cache 32 _->nothing "fire_tasks!" FireTaskSpec(proc, state.chan, task_spec) end end end + to_send_cleanup() +end + +struct FireTaskSpec + init_proc::Processor + return_chan::RemoteChannel + tasks::Vector{TaskSpec} +end +FireTaskSpec(init_proc::Processor, return_chan::RemoteChannel, task::TaskSpec) = + FireTaskSpec(init_proc, return_chan, [task]) +function (ets::FireTaskSpec)() + tasks = ets.tasks + first_task = first(tasks) + ctx_vars = first_task.ctx_vars + ctx = Context(Processor[]; log_sink=ctx_vars.log_sink, profile=ctx_vars.profile) + uid = first_task.sch_uid + + proc = ets.init_proc + chan = ets.return_chan + pid = Dagger.root_worker_id(proc) + + @maybelog ctx timespan_start(ctx, :fire, (;uid, worker=pid), nothing) + try + if pid == myid() + do_tasks(proc, chan, tasks) + else + remotecall_wait(do_tasks, pid, proc, chan, tasks); + end + catch err + bt = catch_backtrace() + # FIXME: Catch the correct task ID + thunk_id = first_task.thunk_id + put!(chan, TaskResult(pid, proc, thunk_id, CapturedException(err, bt), nothing)) + finally + @maybelog ctx timespan_finish(ctx, :fire, (;uid, worker=pid), nothing) + end + return end @static if VERSION >= v"1.9" @@ -1178,22 +972,14 @@ function Base.notify(db::Doorbell) end end -struct TaskSpecKey - task_id::Int - task_spec::Vector{Any} - TaskSpecKey(task_spec::Vector{Any}) = new(task_spec[1], task_spec) -end -Base.getindex(key::TaskSpecKey) = key.task_spec -Base.hash(key::TaskSpecKey, h::UInt) = hash(key.task_id, hash(TaskSpecKey, h)) - struct ProcessorInternalState ctx::Context proc::Processor return_queue::RemoteChannel - queue::LockedObject{PriorityQueue{TaskSpecKey, UInt32, Base.Order.ForwardOrdering}} + queue::LockedObject{PriorityQueue{TaskSpec, UInt32, Base.Order.ForwardOrdering}} reschedule::Doorbell tasks::Dict{Int,Task} - task_specs::Dict{Int,Vector{Any}} + task_specs::Dict{Int,TaskSpec} proc_occupancy::Base.RefValue{UInt32} time_pressure::Base.RefValue{UInt64} cancelled::Set{Int} @@ -1228,9 +1014,12 @@ stealing_permitted(proc::Dagger.ThreadProc) = proc.owner != 1 || proc.tid != 1 proc_has_occupancy(proc_occupancy, task_occupancy) = UInt64(task_occupancy) + UInt64(proc_occupancy) <= typemax(UInt32) -function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, return_queue::RemoteChannel) +function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, return_queue::RemoteChannel, start_event::Base.Event) to_proc = istate.proc proc_run_task = @task begin + # Wait for our ProcessorState to be configured + wait(start_event) + # FIXME: Context changes aren't noticed over time ctx = istate.ctx tasks = istate.tasks @@ -1243,12 +1032,12 @@ 
function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re # Wait for new tasks if !work_to_do @dagdebug nothing :processor "Waiting for tasks" - timespan_start(ctx, :proc_run_wait, (;uid, worker=wid, processor=to_proc), nothing) + @maybelog ctx timespan_start(ctx, :proc_run_wait, (;uid, worker=wid, processor=to_proc), nothing) wait(istate.reschedule) @static if VERSION >= v"1.9" reset(istate.reschedule) end - timespan_finish(ctx, :proc_run_wait, (;uid, worker=wid, processor=to_proc), nothing) + @maybelog ctx timespan_finish(ctx, :proc_run_wait, (;uid, worker=wid, processor=to_proc), nothing) if istate.done[] return end @@ -1256,7 +1045,7 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re # Fetch a new task to execute @dagdebug nothing :processor "Trying to dequeue" - timespan_start(ctx, :proc_run_fetch, (;uid, worker=wid, processor=to_proc), nothing) + @maybelog ctx timespan_start(ctx, :proc_run_fetch, (;uid, worker=wid, processor=to_proc), nothing) work_to_do = false task_and_occupancy = lock(istate.queue) do queue # Only steal if there are multiple queued tasks, to prevent @@ -1275,7 +1064,7 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re return queue_result end if task_and_occupancy === nothing - timespan_finish(ctx, :proc_run_fetch, (;uid, worker=wid, processor=to_proc), nothing) + @maybelog ctx timespan_finish(ctx, :proc_run_fetch, (;uid, worker=wid, processor=to_proc), nothing) @dagdebug nothing :processor "Failed to dequeue" @@ -1290,7 +1079,7 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re @dagdebug nothing :processor "Trying to steal" # Try to steal a task - timespan_start(ctx, :proc_steal_local, (;uid, worker=wid, processor=to_proc), nothing) + @maybelog ctx timespan_start(ctx, :proc_steal_local, (;uid, worker=wid, processor=to_proc), nothing) # Try to steal from local queues randomly # TODO: Prioritize stealing from busiest processors @@ -1310,9 +1099,8 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re if length(queue) == 0 return nothing end - task_spec, occupancy = peek(queue) - task = task_spec[] - scope = task[5] + task, occupancy = peek(queue) + scope = task.scope if !isa(constrain(scope, Dagger.ExactScope(to_proc)), InvalidScope) && typemax(UInt32) - proc_occupancy_cached >= occupancy @@ -1323,14 +1111,14 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re end if task_and_occupancy !== nothing from_proc = other_istate.proc - thunk_id = task[1] + thunk_id = task.thunk_id @dagdebug thunk_id :processor "Stolen from $from_proc by $to_proc" - timespan_finish(ctx, :proc_steal_local, (;uid, worker=wid, processor=to_proc), (;from_proc, thunk_id)) + @maybelog ctx timespan_finish(ctx, :proc_steal_local, (;uid, worker=wid, processor=to_proc), (;from_proc, thunk_id)) # TODO: Keep stealing until we hit full occupancy? 
@goto execute end end - timespan_finish(ctx, :proc_steal_local, (;uid, worker=wid, processor=to_proc), nothing) + @maybelog ctx timespan_finish(ctx, :proc_steal_local, (;uid, worker=wid, processor=to_proc), nothing) # TODO: Try to steal from remote queues @@ -1338,73 +1126,34 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re end @label execute - task_spec, task_occupancy = task_and_occupancy - task = task_spec[] - thunk_id = task[1] - time_util = task[2] - timespan_finish(ctx, :proc_run_fetch, (;uid, worker=wid, processor=to_proc), (;thunk_id, proc_occupancy=proc_occupancy[], task_occupancy)) + task, task_occupancy = task_and_occupancy + thunk_id = task.thunk_id + time_util = task.est_time_util + @maybelog ctx timespan_finish(ctx, :proc_run_fetch, (;uid, worker=wid, processor=to_proc), (;thunk_id, proc_occupancy=proc_occupancy[], task_occupancy)) @dagdebug thunk_id :processor "Dequeued task" - # Execute the task and return its result - t = @task begin - # Set up cancellation - cancel_token = Dagger.CancelToken() - Dagger.DTASK_CANCEL_TOKEN[] = cancel_token - lock(istate.queue) do _ - istate.cancel_tokens[thunk_id] = cancel_token - end - was_cancelled = false - - result = try - do_task(to_proc, task) - catch err - bt = catch_backtrace() - (CapturedException(err, bt), nothing) - finally - lock(istate.queue) do _ - delete!(tasks, thunk_id) - delete!(istate.task_specs, thunk_id) - if !(thunk_id in istate.cancelled) - proc_occupancy[] -= task_occupancy - time_pressure[] -= time_util - else - # Task was cancelled, so occupancy and pressure are - # already reduced - pop!(istate.cancelled, thunk_id) - delete!(istate.cancel_tokens, thunk_id) - was_cancelled = true - end - end - notify(istate.reschedule) - end - if was_cancelled - # A result was already posted to the return queue - return - end - try - put!(return_queue, (myid(), to_proc, thunk_id, result)) - catch err - if unwrap_nested_exception(err) isa InvalidStateException || !isopen(return_queue) - @dagdebug thunk_id :execute "Return queue is closed, failing to put result" chan=return_queue exception=(err, catch_backtrace()) - else - rethrow() - end - finally - # Ensure that any spawned tasks get cleaned up - Dagger.cancel!(cancel_token) - end - end + # Set up cancellation and update task accounting + cancel_token = Dagger.CancelToken() lock(istate.queue) do _ + istate.cancel_tokens[thunk_id] = cancel_token + proc_occupancy[] += task_occupancy + time_pressure[] += time_util + end + + # Launch the task + t = @reusable_tasks :start_processor_runner!_task_cache 32 t->begin tid = task_tid_for_processor(to_proc) if tid !== nothing Dagger.set_task_tid!(t, tid) else t.sticky = false end - tasks[thunk_id] = errormonitor_tracked("thunk $thunk_id", schedule(t)) + end "thunk $thunk_id" DoTaskSpec(to_proc, return_queue, task, cancel_token) + + # Update task accounting + lock(istate.queue) do _ + tasks[thunk_id] = t istate.task_specs[thunk_id] = task - proc_occupancy[] += task_occupancy - time_pressure[] += time_util end end end @@ -1416,6 +1165,80 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re end return errormonitor_tracked("processor $to_proc", schedule(proc_run_task)) end +struct DoTaskSpec + to_proc::Processor + chan::RemoteChannel + task::TaskSpec + cancel_token::Dagger.CancelToken +end +function (dts::DoTaskSpec)() + to_proc = dts.to_proc + task = dts.task + tid = task.thunk_id + Dagger.DTASK_CANCEL_TOKEN[] = dts.cancel_token + + # Execute the task and return its result + 
was_cancelled = false + result, metadata = try + do_task(to_proc, task) + catch err + bt = catch_backtrace() + (CapturedException(err, bt), nothing) + finally + istate = proc_states(task.sch_uid) do states + if haskey(states, to_proc) + return states[to_proc].state + end + # Processor was removed due to scheduler exit + return nothing + end + if istate !== nothing + while true + # Wait until the task has been recorded in the processor state + done = lock(istate.queue) do _ + if haskey(istate.tasks, tid) + delete!(istate.tasks, tid) + delete!(istate.task_specs, tid) + if !(tid in istate.cancelled) + istate.proc_occupancy[] -= task.est_occupancy + istate.time_pressure[] -= task.est_time_util + else + # Task was cancelled, so occupancy and pressure are + # already reduced + pop!(istate.cancelled, tid) + delete!(istate.cancel_tokens, tid) + was_cancelled = true + end + return true + end + return false + end + done && break + sleep(0.1) + end + notify(istate.reschedule) + end + + # Ensure that any spawned tasks get cleaned up + Dagger.cancel!(dts.cancel_token) + end + if was_cancelled + # A result was already posted to the return queue + return + end + + return_queue = dts.chan + try + put!(return_queue, TaskResult(myid(), to_proc, tid, result, metadata)) + catch err + if unwrap_nested_exception(err) isa InvalidStateException || !isopen(return_queue) + @dagdebug tid :execute "Return queue is closed, failing to put result" chan=return_queue exception=(err, catch_backtrace()) + else + rethrow() + end + end + return +end """ do_tasks(to_proc, return_queue, tasks) @@ -1426,36 +1249,41 @@ Executes a batch of tasks on `to_proc`, returning their results through function do_tasks(to_proc, return_queue, tasks) @dagdebug nothing :processor "Enqueuing task batch" batch_size=length(tasks) - # FIXME: This is terrible - ctx_vars = first(tasks)[16] + ctx_vars = first(tasks).ctx_vars ctx = Context(Processor[]; log_sink=ctx_vars.log_sink, profile=ctx_vars.profile) - uid = first(tasks)[18] + uid = first(tasks).sch_uid state = proc_states(uid) do states - get!(states, to_proc) do - queue = PriorityQueue{TaskSpecKey, UInt32}() - queue_locked = LockedObject(queue) - reschedule = Doorbell() - istate = ProcessorInternalState(ctx, to_proc, return_queue, - queue_locked, reschedule, - Dict{Int,Task}(), - Dict{Int,Vector{Any}}(), - Ref(UInt32(0)), Ref(UInt64(0)), - Set{Int}(), - Dict{Int,Dagger.CancelToken}(), - Ref(false)) - runner = start_processor_runner!(istate, uid, return_queue) - @static if VERSION < v"1.9" - reschedule.waiter = runner - end - return ProcessorState(istate, runner) + if haskey(states, to_proc) + return states[to_proc] + end + + # Initialize the processor state and runner + queue = PriorityQueue{TaskSpec, UInt32}() + queue_locked = LockedObject(queue) + reschedule = Doorbell() + istate = ProcessorInternalState(ctx, to_proc, return_queue, + queue_locked, reschedule, + Dict{Int,Task}(), + Dict{Int,Vector{Any}}(), + Ref(UInt32(0)), Ref(UInt64(0)), + Set{Int}(), + Dict{Int,Dagger.CancelToken}(), + Ref(false)) + start_event = Base.Event() + runner = start_processor_runner!(istate, uid, return_queue, start_event) + @static if VERSION < v"1.9" + reschedule.waiter = runner end + state = states[to_proc] = ProcessorState(istate, runner) + notify(start_event) + return state end istate = state.state lock(istate.queue) do queue for task in tasks - thunk_id = task[1] - occupancy = task[4] - timespan_start(ctx, :enqueue, (;uid, processor=to_proc, thunk_id), nothing) + thunk_id = task.thunk_id + occupancy = 
task.est_occupancy + @maybelog ctx timespan_start(ctx, :enqueue, (;uid, processor=to_proc, thunk_id), nothing) should_launch = lock(TASK_SYNC) do # Already running; don't try to re-launch if !(thunk_id in TASKS_RUNNING) @@ -1466,8 +1294,8 @@ function do_tasks(to_proc, return_queue, tasks) end end should_launch || continue - enqueue!(queue, TaskSpecKey(task), occupancy) - timespan_finish(ctx, :enqueue, (;uid, processor=to_proc, thunk_id), nothing) + enqueue!(queue, task, occupancy) + @maybelog ctx timespan_finish(ctx, :enqueue, (;uid, processor=to_proc, thunk_id), nothing) @dagdebug thunk_id :processor "Enqueued task" end end @@ -1488,42 +1316,42 @@ function do_tasks(to_proc, return_queue, tasks) end """ - do_task(to_proc, task_desc) -> Any + do_task(to_proc, task::TaskSpec) -> Any -Executes a single task specified by `task_desc` on `to_proc`. +Executes a single task specified by `task` on `to_proc`. """ -function do_task(to_proc, task_desc) - thunk_id, est_time_util, est_alloc_util, est_occupancy, - scope, Tf, data, - send_result, persist, cache, meta, - options, propagated, ids, positions, - ctx_vars, sch_handle, sch_uid, result_scope = task_desc +@reuse_scope function do_task(to_proc, task::TaskSpec) + thunk_id = task.thunk_id + + ctx_vars = task.ctx_vars ctx = Context(Processor[]; log_sink=ctx_vars.log_sink, profile=ctx_vars.profile) from_proc = OSProc() - Tdata = Any[] - for x in data - push!(Tdata, chunktype(x)) - end + data = task.data + Tf = task.Tf f = isdefined(Tf, :instance) ? Tf.instance : nothing # Wait for required resources to become available + options = task.options + propagated = get_propagated_options(options) to_storage = options.storage !== nothing ? fetch(options.storage) : MemPool.GLOBAL_DEVICE[] - to_storage_name = nameof(typeof(to_storage)) - storage_cap = storage_capacity(to_storage) + #to_storage_name = nameof(typeof(to_storage)) + #storage_cap = storage_capacity(to_storage) - timespan_start(ctx, :storage_wait, (;thunk_id, processor=to_proc), (;f, device=typeof(to_storage))) + est_time_util = task.est_time_util + est_alloc_util = task.est_alloc_util real_time_util = Ref{UInt64}(0) real_alloc_util = UInt64(0) + #= FIXME: Serialize on over-memory situation + @maybelog ctx timespan_start(ctx, :storage_wait, (;thunk_id, processor=to_proc), (;f, device=typeof(to_storage))) if !meta # Factor in the memory costs for our lazy arguments for arg in data[2:end] - if arg isa Chunk - est_alloc_util += arg.handle.size + if Dagger.valuetype(arg) <: Chunk + est_alloc_util += Dagger.value(arg).handle.size end end end - debug_storage(msg::String) = @debug begin let est_alloc_util=Base.format_bytes(est_alloc_util), real_alloc_util=Base.format_bytes(real_alloc_util), @@ -1534,7 +1362,7 @@ function do_task(to_proc, task_desc) lock(TASK_SYNC) do while true # Get current time utilization for the selected processor - time_dict = get!(()->Dict{Processor,Ref{UInt64}}(), PROCESSOR_TIME_UTILIZATION, sch_uid) + time_dict = get!(()->Dict{Processor,Ref{UInt64}}(), PROCESSOR_TIME_UTILIZATION, task.sch_uid) real_time_util = get!(()->Ref{UInt64}(UInt64(0)), time_dict, to_proc) # Get current allocation utilization and capacity @@ -1543,48 +1371,51 @@ function do_task(to_proc, task_desc) # Check if we'll go over memory capacity from running this thunk # Waits for free storage, if necessary - #= TODO: Implement a priority queue, ordered by est_alloc_util - if est_alloc_util > storage_cap - debug_storage("WARN: Estimated utilization above storage capacity on $to_storage_name, proceeding anyway") - 
break - end - if est_alloc_util + real_alloc_util > storage_cap - if MemPool.externally_varying(to_storage) - debug_storage("WARN: Insufficient space and allocation behavior is externally varying on $to_storage_name, proceeding anyway") - break - end - if length(TASKS_RUNNING) <= 2 # This task + eager submission task - debug_storage("WARN: Insufficient space and no other running tasks on $to_storage_name, proceeding anyway") - break - end - # Fully utilized, wait and re-check - debug_storage("Waiting for free $to_storage_name") - wait(TASK_SYNC) - else - # Sufficient free storage is available, prepare for execution - debug_storage("Using available $to_storage_name") - break - end - =# + # TODO: Implement a priority queue, ordered by est_alloc_util + #if est_alloc_util > storage_cap + # debug_storage("WARN: Estimated utilization above storage capacity on $to_storage_name, proceeding anyway") + # break + #end + #if est_alloc_util + real_alloc_util > storage_cap + # if MemPool.externally_varying(to_storage) + # debug_storage("WARN: Insufficient space and allocation behavior is externally varying on $to_storage_name, proceeding anyway") + # break + # end + # if length(TASKS_RUNNING) <= 2 # This task + eager submission task + # debug_storage("WARN: Insufficient space and no other running tasks on $to_storage_name, proceeding anyway") + # break + # end + # # Fully utilized, wait and re-check + # debug_storage("Waiting for free $to_storage_name") + # wait(TASK_SYNC) + #else + # # Sufficient free storage is available, prepare for execution + # debug_storage("Using available $to_storage_name") + # break + #end # FIXME break end end - timespan_finish(ctx, :storage_wait, (;thunk_id, processor=to_proc), (;f, device=typeof(to_storage))) + @maybelog ctx timespan_finish(ctx, :storage_wait, (;thunk_id, processor=to_proc), (;f, device=typeof(to_storage))) + =# @dagdebug thunk_id :execute "Moving data" # Initiate data transfers for function and arguments transfer_time = Threads.Atomic{UInt64}(0) transfer_size = Threads.Atomic{UInt64}(0) - _data, _ids, _positions = if meta - (Any[first(data)], Int[first(ids)], Union{Symbol,Int}[first(positions)]) # always fetch function + _data = if something(options.meta, false) + Argument[first(data)] # always fetch function else - (data, ids, positions) + data end - fetch_tasks = map(Iterators.zip(_data, _ids, _positions)) do (x, id, position) + fetch_tasks = map(_data) do arg + #=FIXME:REALLOC_TASKS=# Threads.@spawn begin - timespan_start(ctx, :move, (;thunk_id, id, position, processor=to_proc), (;f, data=x)) + value = Dagger.value(arg) + position = arg.pos + @maybelog ctx timespan_start(ctx, :move, (;thunk_id, position, processor=to_proc), (;f, data=value)) #= FIXME: This isn't valid if x is written to x = if x isa Chunk value = lock(TASK_SYNC) do @@ -1627,33 +1458,32 @@ function do_task(to_proc, task_desc) end else =# - new_x = @invokelatest move(to_proc, x) + new_value = @invokelatest move(to_proc, value) #end - if new_x !== x - @dagdebug thunk_id :move "Moved argument $position to $to_proc: $(typeof(x)) -> $(typeof(new_x))" + if new_value !== value + @dagdebug thunk_id :move "Moved argument @ $position to $to_proc: $(typeof(value)) -> $(typeof(new_value))" end - timespan_finish(ctx, :move, (;thunk_id, id, position, processor=to_proc), (;f, data=new_x); tasks=[Base.current_task()]) - return new_x + @maybelog ctx timespan_finish(ctx, :move, (;thunk_id, position, processor=to_proc), (;f, data=new_value); tasks=[Base.current_task()]) + arg.value = new_value + return end 
end - fetched = Any[] for task in fetch_tasks - push!(fetched, fetch_report(task)) - end - if meta - append!(fetched, data[2:end]) + fetch_report(task) end - f = popfirst!(fetched) + f = Dagger.value(first(data)) @assert !(f isa Chunk) "Failed to unwrap thunk function" - fetched_args = Any[] - fetched_kwargs = Pair{Symbol,Any}[] - for (idx, x) in enumerate(fetched) - pos = positions[idx+1] - if pos isa Int - push!(fetched_args, x) + fetched_args = @reusable_vector :do_task_fetched_args Any nothing 32 + fetched_args_cleanup = @reuse_defer_cleanup empty!(fetched_args) + fetched_kwargs = @reusable_vector :do_task_fetched_kwargs Pair{Symbol,Any} :NULL=>nothing 32 + fetched_kwargs_cleanup = @reuse_defer_cleanup empty!(fetched_kwargs) + for idx in 2:length(data) + arg = data[idx] + if Dagger.ispositional(arg) + push!(fetched_args, Dagger.value(arg)) else - push!(fetched_kwargs, pos => x) + push!(fetched_kwargs, Dagger.pos_kw(arg) => Dagger.value(arg)) end end @@ -1666,8 +1496,7 @@ function do_task(to_proc, task_desc) =# real_time_util[] += est_time_util - timespan_start(ctx, :compute, (;thunk_id, processor=to_proc), (;f)) - res = nothing + @maybelog ctx timespan_start(ctx, :compute, (;thunk_id, processor=to_proc), (;f)) # Start counting time and GC allocations threadtime_start = cputhreadtime() @@ -1676,48 +1505,61 @@ function do_task(to_proc, task_desc) @dagdebug thunk_id :execute "Executing $(typeof(f))" + logging_enabled = !(ctx.log_sink isa TimespanLogging.NoOpLog) + result_meta = try # Set TLS variables - Dagger.set_tls!(( - sch_uid, - sch_handle, + Dagger.set_tls!((; + sch_uid=task.sch_uid, + sch_handle=task.sch_handle, processor=to_proc, - task_spec=task_desc, + task_spec=task, cancel_token=Dagger.DTASK_CANCEL_TOKEN[], + logging_enabled, )) - res = Dagger.with_options(propagated) do + result = Dagger.with_options(propagated) do # Execute execute!(to_proc, f, fetched_args...; fetched_kwargs...) end # Check if result is safe to store + # FIXME: Move here and below *after* timespan_finish for :compute device = nothing - if !(res isa Chunk) - timespan_start(ctx, :storage_safe_scan, (;thunk_id, processor=to_proc), (;T=typeof(res))) - device = if walk_storage_safe(res) + if !(result isa Chunk) + @maybelog ctx timespan_start(ctx, :storage_safe_scan, (;thunk_id, processor=to_proc), (;T=typeof(result))) + device = if walk_storage_safe(result) to_storage else MemPool.CPURAMDevice() end - timespan_finish(ctx, :storage_safe_scan, (;thunk_id, processor=to_proc), (;T=typeof(res))) + @maybelog ctx timespan_finish(ctx, :storage_safe_scan, (;thunk_id, processor=to_proc), (;T=typeof(result))) end # Construct result - # TODO: We should cache this locally - send_result || meta ? res : tochunk(res, to_proc, result_scope; device, persist, cache=persist ? 
true : cache, - tag=options.storage_root_tag, - leaf_tag=something(options.storage_leaf_tag, MemPool.Tag()), - retain=options.storage_retain) + result_meta = if something(options.get_result, false) || something(options.meta, false) + result + else + # TODO: Cache this Chunk locally in CHUNK_CACHE right now + tochunk(result, to_proc, @something(options.result_scope, AnyScope()); + device, + tag=options.storage_root_tag, + leaf_tag=something(options.storage_leaf_tag, MemPool.Tag()), + retain=something(options.storage_retain, false)) + end catch ex bt = catch_backtrace() RemoteException(myid(), CapturedException(ex, bt)) + finally + fetched_args_cleanup() + fetched_kwargs_cleanup() end threadtime = cputhreadtime() - threadtime_start # FIXME: This is not a realistic measure of max. required memory #gc_allocd = min(max(UInt64(Base.gc_num().allocd) - UInt64(gcnum_start.allocd), UInt64(0)), UInt64(1024^4)) - timespan_finish(ctx, :compute, (;thunk_id, processor=to_proc), (;f, result=result_meta)) + @maybelog ctx timespan_finish(ctx, :compute, (;thunk_id, processor=to_proc), (;f, result=result_meta)) + lock(TASK_SYNC) do real_time_util[] -= est_time_util pop!(TASKS_RUNNING, thunk_id) @@ -1729,12 +1571,13 @@ function do_task(to_proc, task_desc) # TODO: debug_storage("Releasing $to_storage_name") metadata = ( time_pressure=real_time_util[], - storage_pressure=real_alloc_util, - storage_capacity=storage_cap, - loadavg=((Sys.loadavg()...,) ./ Sys.CPU_THREADS), + #storage_pressure=real_alloc_util, + #storage_capacity=storage_cap, + #loadavg=((Sys.loadavg()...,) ./ Sys.CPU_THREADS), threadtime=threadtime, # FIXME: Add runtime allocation tracking - gc_allocd=(isa(result_meta, Chunk) ? result_meta.handle.size : 0), + #gc_allocd=(isa(result_meta, Chunk) ? result_meta.handle.size : 0), + gc_allocd=0, transfer_rate=(transfer_size[] > 0 && transfer_time[] > 0) ? round(UInt64, transfer_size[] / (transfer_time[] / 10^9)) : nothing, ) return (result_meta, metadata) diff --git a/src/sch/dynamic.jl b/src/sch/dynamic.jl index 5b917fdb5..0b972bdf1 100644 --- a/src/sch/dynamic.jl +++ b/src/sch/dynamic.jl @@ -96,13 +96,18 @@ function dynamic_listener!(ctx, state, wid) end end end + return end errormonitor_tracked("dynamic_listener! $wid", listener_task) errormonitor_tracked("dynamic_listener! (halt+throw) $wid", Threads.@spawn begin wait(state.halt) # TODO: Not sure why we need the Threads.@spawn here, but otherwise we # don't stop all the listener tasks - Threads.@spawn Base.throwto(listener_task, SchedulerHaltedException()) + Threads.@spawn begin + Base.throwto(listener_task, SchedulerHaltedException()) + return + end + return end) end @@ -124,13 +129,13 @@ end halt!(h::SchedulerHandle) = exec!(_halt, h, nothing) function _halt(ctx, state, task, tid, _) notify(state.halt) - put!(state.chan, (1, nothing, nothing, (SchedulerHaltedException(), nothing))) + put!(state.chan, TaskResult(1, OSProc(), 0, SchedulerHaltedException(), nothing)) Base.throwto(task, SchedulerHaltedException()) end "Waits on a thunk to complete, and fetches its result." 
function Base.fetch(h::SchedulerHandle, id::ThunkID) - future = ThunkFuture(Future(1)) + future = ThunkFuture() exec!(_register_future!, h, future, id, true) fetch(future; proc=task_processor()) end @@ -172,8 +177,8 @@ function _register_future!(ctx, state, task, tid, (future, id, check)::Tuple{Thu end end # TODO: Assert that future will be fulfilled - if haskey(state.cache, thunk) - put!(future, state.cache[thunk]; error=state.errored[thunk]) + if has_result(state, thunk) + put!(future, load_result(state, thunk); error=state.errored[thunk]) else futures = get!(()->ThunkFuture[], state.futures, thunk) push!(futures, future) @@ -184,7 +189,7 @@ end # TODO: Optimize wait() to not serialize a Chunk "Waits on a thunk to complete." -function Base.wait(h::SchedulerHandle, id::ThunkID; future=ThunkFuture(1)) +function Base.wait(h::SchedulerHandle, id::ThunkID; future=ThunkFuture()) register_future!(h, id, future) wait(future) end @@ -208,37 +213,28 @@ function _get_dag_ids(ctx, state, task, tid, _) end "Adds a new Thunk to the DAG." -add_thunk!(f, h::SchedulerHandle, args...; future=nothing, ref=nothing, options...) = - exec!(_add_thunk!, h, f, args, options, future, ref) -function _add_thunk!(ctx, state, task, tid, (f, args, options, future, ref)) - timespan_start(ctx, :add_thunk, (;thunk_id=tid), (;f, args, options)) - _args = map(args) do pos_arg - if pos_arg[2] isa ThunkID - return pos_arg[1] => state.thunk_dict[pos_arg[2].id] - else - return pos_arg[1] => pos_arg[2] - end +function add_thunk!(f, h::SchedulerHandle, args...; future=nothing, ref=nothing, options...) + if ref !== nothing + @warn "`ref` is no longer supported in `add_thunk!`" maxlog=1 end - GC.@preserve _args begin - thunk = Thunk(f, _args...; options...) - # Create a `DRef` to `thunk` so that the caller can preserve it - thunk_ref = poolset(thunk; size=64, device=MemPool.CPURAMDevice()) - thunk_id = ThunkID(thunk.id, thunk_ref) - state.thunk_dict[thunk.id] = WeakThunk(thunk) - reschedule_syncdeps!(state, thunk) - @dagdebug thunk :submit "Added to scheduler" - if future !== nothing - # Ensure we attach a future before the thunk is scheduled - _register_future!(ctx, state, task, tid, (future, thunk_id, false)) - @dagdebug thunk :submit "Registered future" - end - if ref !== nothing - # Preserve the `DTaskFinalizer` through `thunk` - thunk.eager_ref = ref + return exec!(_add_thunk!, h, f, args, options, future) +end +function _add_thunk!(ctx, state, task, tid, (f, args, options, future)) + if future === nothing + future = ThunkFuture() + end + _options = Dagger.Options(;options...) + fargs = Dagger.Argument[] + push!(fargs, Dagger.Argument(Dagger.ArgPosition(true, 0, :NULL), f)) + pos_idx = 1 + for (pos, arg) in args + if pos === nothing + push!(fargs, Dagger.Argument(pos_idx, arg)) + pos_idx += 1 + else + push!(fargs, Dagger.Argument(pos, arg)) end - state.valid[thunk] = nothing - put!(state.chan, RescheduleSignal()) - timespan_finish(ctx, :add_thunk, (;thunk_id=tid), (;f, args, options)) - return thunk_id end + payload = Dagger.PayloadOne(UInt(0), future, fargs, _options, true) + return Dagger.eager_submit_internal!(ctx, state, task, tid, payload) end diff --git a/src/sch/eager.jl b/src/sch/eager.jl index aea0abbf6..bd703f1ee 100644 --- a/src/sch/eager.jl +++ b/src/sch/eager.jl @@ -23,10 +23,15 @@ function init_eager() return end ctx = eager_context() - errormonitor_tracked("eager compute()", Threads.@spawn try + # N.B. 
We use @async here to prevent the scheduler task from running on a + # different thread than the one that is likely submitting work, as otherwise + # the scheduler task might sleep while holding the scheduler lock and + # prevent work submission until it wakes up. Further testing is needed. + errormonitor_tracked("eager compute()", @async try sopts = SchedulerOptions(;allow_errors=true) opts = Dagger.Options((;scope=Dagger.ExactScope(Dagger.ThreadProc(1, 1)), - occupancy=Dict(Dagger.ThreadProc=>0))) + occupancy=Dict(Dagger.ThreadProc=>0), + time_util=Dict(Dagger.ThreadProc=>0))) Dagger.compute(ctx, Dagger._delayed(eager_thunk, opts)(); options=sopts) catch err @@ -76,7 +81,7 @@ function thunk_yield(f) proc_istate = proc_states(tls.sch_uid) do states states[proc].state end - task_occupancy = tls.task_spec[4] + task_occupancy = tls.task_spec.est_occupancy # Decrease our occupancy and inform the processor to reschedule lock(proc_istate.queue) do _ @@ -108,31 +113,6 @@ function thunk_yield(f) end end -eager_cleanup(t::Dagger.DTaskFinalizer) = - errormonitor_tracked("eager_cleanup $(t.uid)", Threads.@spawn eager_cleanup(EAGER_STATE[], t.uid)) -function eager_cleanup(state, uid) - tid = nothing - lock(EAGER_ID_MAP) do id_map - if !haskey(id_map, uid) - return - end - tid = id_map[uid] - delete!(id_map, uid) - end - tid === nothing && return - lock(state.lock) do - # N.B. cache and errored expire automatically - delete!(state.thunk_dict, tid) - end - remotecall_wait(1, uid) do uid - lock(Dagger.EAGER_THUNK_STREAMS) do global_streams - if haskey(global_streams, uid) - delete!(global_streams, uid) - end - end - end -end - function _find_thunk(e::Dagger.DTask) tid = lock(EAGER_ID_MAP) do id_map id_map[e.uid] diff --git a/src/sch/fault-handler.jl b/src/sch/fault-handler.jl index fca184cfa..56ccc3ca1 100644 --- a/src/sch/fault-handler.jl +++ b/src/sch/fault-handler.jl @@ -20,11 +20,13 @@ function handle_fault(ctx, state, deadproc) deadlist = Thunk[] # Evict cache entries that were stored on the worker - for t in keys(state.cache) - v = state.cache[t] + for t in values(state.thunk_dict) + t = unwrap_weak_checked(t) + has_result(state, t) || continue + v = load_result(state, t) if v isa Chunk && v.handle isa DRef && v.handle.owner == deadproc.pid push!(deadlist, t) - pop!(state.cache, t) + clear_result!(state, t) end end # Remove thunks that were running on the worker diff --git a/src/sch/util.jl b/src/sch/util.jl index dd148d336..1d947c23a 100644 --- a/src/sch/util.jl +++ b/src/sch/util.jl @@ -18,6 +18,17 @@ function errormonitor_tracked(name::String, t::Task) end end) end +function errormonitor_tracked_set!(name::String, t::Task) + lock(ERRORMONITOR_TRACKED) do tracked + for idx in 1:length(tracked) + if tracked[idx][2] === t + tracked[idx] = name => t + return + end + end + error("Task not found in tracked list") + end +end const ERRORMONITOR_TRACKED = LockedObject(Pair{String,Task}[]) """ @@ -37,72 +48,108 @@ unwrap_nested_exception(err::LoadError) = unwrap_nested_exception(err.error) unwrap_nested_exception(err) = err -"Gets a `NamedTuple` of options propagated by `thunk`." -function get_propagated_options(thunk) +"Gets a `NamedTuple` of options propagated from `options`." +function get_propagated_options(options::Options) + # FIXME: Just use an Options as output? 
nt = NamedTuple() - for key in thunk.propagates - value = if key == :scope - thunk.compute_scope - elseif key in fieldnames(Thunk) - getproperty(thunk, key) - elseif key in fieldnames(ThunkOptions) - getproperty(thunk.options, key) + if options.propagates === nothing + return nt + end + for key in options.propagates + value = if hasfield(Options, key) + getfield(options, key) else throw(ArgumentError("Can't propagate unknown key: $key")) end nt = merge(nt, (key=>value,)) end - nt + return nt end -"Fills the result for all registered futures of `node`." -function fill_registered_futures!(state, node, failed) - if haskey(state.futures, node) +has_result(state, thunk) = thunk.cache_ref !== nothing +function load_result(state, thunk) + @assert thunk.finished "Thunk[$(thunk.id)] is not yet finished" + return something(thunk.cache_ref) +end +function store_result!(state, thunk, value; error::Bool=false) + @assert islocked(state.lock) + @assert !thunk.finished "Thunk[$(thunk.id)] should not be finished yet" + @assert !has_result(state, thunk) "Thunk[$(thunk.id)] already contains a cached result" + thunk.finished = true + if error && value isa Exception && !(value isa DTaskFailedException) + thunk.cache_ref = Some{Any}(DTaskFailedException(thunk, thunk, value)) + else + thunk.cache_ref = Some{Any}(value) + end + state.errored[thunk] = error +end +function clear_result!(state, thunk) + @assert islocked(state.lock) + thunk.cache_ref = nothing + delete!(state.errored, thunk) +end + +"Fills the result for all registered futures of `thunk`." +function fill_registered_futures!(state, thunk, failed) + if haskey(state.futures, thunk) # Notify any listening thunks - for future in state.futures[node] - put!(future, state.cache[node]; error=failed) + @dagdebug thunk :finish "Notifying $(length(state.futures[thunk])) futures" + for future in state.futures[thunk] + put!(future, load_result(state, thunk); error=failed) end - delete!(state.futures, node) + delete!(state.futures, thunk) end end "Cleans up any syncdeps that aren't needed any longer, and returns a `Set{Chunk}` of all chunks that can now be evicted from workers." -function cleanup_syncdeps!(state, node) - to_evict = Set{Chunk}() - for inp in node.syncdeps +function cleanup_syncdeps!(state, thunk) + #to_evict = Set{Chunk}() + thunk.options.syncdeps === nothing && return + for inp in thunk.options.syncdeps inp = unwrap_weak_checked(inp) - if !istask(inp) && !(inp isa Chunk) - continue - end + @assert istask(inp) if inp in keys(state.waiting_data) w = state.waiting_data[inp] - if node in w - pop!(w, node) + if thunk in w + pop!(w, thunk) end if isempty(w) - if istask(inp) && haskey(state.cache, inp) - _node = state.cache[inp] - if _node isa Chunk - push!(to_evict, _node) + #= FIXME: Worker-side cache is currently disabled + if istask(inp) && has_result(state, inp) + _thunk = load_result(state, inp) + if _thunk isa Chunk + push!(to_evict, _thunk) end elseif inp isa Chunk push!(to_evict, inp) end + =# delete!(state.waiting_data, inp) + inp.sch_accessible = false + delete_unused_task!(state, inp) end end end - return to_evict + #return to_evict end "Schedules any dependents that may be ready to execute." 
-function schedule_dependents!(state, node, failed) - for dep in sort!(collect(get(()->Set{Thunk}(), state.waiting_data, node)), by=state.node_order) +function schedule_dependents!(state, thunk, failed) + @dagdebug thunk :finish "Checking dependents" + if !haskey(state.waiting_data, thunk) || isempty(state.waiting_data[thunk]) + return + end + ctr = 0 + for dep in state.waiting_data[thunk] + @dagdebug dep :schedule "Checking dependent" dep_isready = false if haskey(state.waiting, dep) set = state.waiting[dep] - node in set && pop!(set, node) + thunk in set && pop!(set, thunk) + if length(set) > 0 + @dagdebug dep :schedule "Dependent has $(length(set)) upstreams" + end dep_isready = isempty(set) if dep_isready delete!(state.waiting, dep) @@ -111,83 +158,103 @@ function schedule_dependents!(state, node, failed) dep_isready = true end if dep_isready + ctr += 1 if !failed push!(state.ready, dep) + @dagdebug dep :schedule "Dependent is now ready" + else + set_failed!(state, thunk, dep) + @dagdebug dep :schedule "Dependent has transitively failed" end end end + @dagdebug thunk :finish "Marked $ctr dependents as $(failed ? "failed" : "ready")" end """ Prepares the scheduler to schedule `thunk`. Will mark `thunk` as ready if its inputs are satisfied. """ -function reschedule_syncdeps!(state, thunk, seen=Set{Thunk}()) - to_visit = Thunk[thunk] - while !isempty(to_visit) - thunk = pop!(to_visit) - push!(seen, thunk) - if haskey(state.valid, thunk) - continue - end - if haskey(state.cache, thunk) || (thunk in state.ready) || (thunk in state.running) - continue - end - for (_,input) in thunk.inputs - if input isa WeakChunk - input = unwrap_weak_checked(input) +function reschedule_syncdeps!(state, thunk, seen=nothing) + Dagger.maybe_take_or_alloc!(RESCHEDULE_SYNCDEPS_SEEN_CACHE[], seen) do seen + #=FIXME:REALLOC=# + to_visit = Thunk[thunk] + while !isempty(to_visit) + thunk = pop!(to_visit) + push!(seen, thunk) + if haskey(state.valid, thunk) + continue end - if input isa Chunk - # N.B. Different Chunks with the same DRef handle will hash to the same slot, - # so we just pick an equivalent Chunk as our upstream - if !haskey(state.waiting_data, input) - push!(get!(()->Set{Thunk}(), state.waiting_data, input), thunk) + if thunk.finished || (thunk in state.ready) || (thunk in state.running) + continue + end + for idx in 1:length(thunk.inputs) + input = Dagger.value(thunk.inputs[idx]) + if input isa WeakChunk + input = unwrap_weak_checked(input) + end + if input isa Chunk + # N.B. 
Different Chunks with the same DRef handle will hash to the same slot, + # so we just pick an equivalent Chunk as our upstream + if !haskey(state.waiting_data, input) + push!(get!(()->Set{Thunk}(), state.waiting_data, input), thunk) + end end end - end - w = get!(()->Set{Thunk}(), state.waiting, thunk) - for input in thunk.syncdeps - input = unwrap_weak_checked(input) - istask(input) && input in seen && continue + w = get!(()->Set{Thunk}(), state.waiting, thunk) + if thunk.options.syncdeps !== nothing + for input in thunk.options.syncdeps + input = unwrap_weak_checked(input) + istask(input) && input in seen && continue - # Unseen - push!(get!(()->Set{Thunk}(), state.waiting_data, input), thunk) - istask(input) || continue + # Unseen + push!(get!(()->Set{Thunk}(), state.waiting_data, input), thunk) + istask(input) || continue - # Unseen task - if get(state.errored, input, false) - set_failed!(state, input, thunk) - end - haskey(state.cache, input) && continue + # Unseen task + if get(state.errored, input, false) + set_failed!(state, input, thunk) + end + input.finished && continue - # Unseen and unfinished task - push!(w, input) - if !((input in state.running) || (input in state.ready)) - push!(to_visit, input) + # Unseen and unfinished task + push!(w, input) + if !((input in state.running) || (input in state.ready)) + push!(to_visit, input) + end + end end - end - if isempty(w) - # Inputs are ready - delete!(state.waiting, thunk) - if !get(state.errored, thunk, false) - push!(state.ready, thunk) + if isempty(w) + # Inputs are ready + delete!(state.waiting, thunk) + if !get(state.errored, thunk, false) + push!(state.ready, thunk) + end end end end end +const RESCHEDULE_SYNCDEPS_SEEN_CACHE = TaskLocalValue{ReusableCache{Set{Thunk},Nothing}}(()->ReusableCache(Set{Thunk}, nothing, 1)) "Marks `thunk` and all dependent thunks as failed." function set_failed!(state, origin, thunk=origin) - filter!(x->x!==thunk, state.ready) - ex = state.cache[origin] - if ex isa RemoteException - ex = ex.captured + @assert islocked(state.lock) + has_result(state, thunk) && return + @dagdebug thunk :finish "Setting as failed" + filter!(x -> x !== thunk, state.ready) + # N.B. If origin === thunk, we assume that the caller has already set the error + if origin !== thunk + origin_ex = load_result(state, origin) + if origin_ex isa RemoteException + origin_ex = origin_ex.captured + end + ex = DTaskFailedException(thunk, origin, origin_ex) + store_result!(state, thunk, ex; error=true) end - state.cache[thunk] = DTaskFailedException(thunk, origin, ex) - state.errored[thunk] = true finish_failed!(state, thunk, origin) end function finish_failed!(state, thunk, origin=nothing) + @assert islocked(state.lock) fill_registered_futures!(state, thunk, true) if haskey(state.waiting_data, thunk) for dep in state.waiting_data[thunk] @@ -198,6 +265,8 @@ function finish_failed!(state, thunk, origin=nothing) origin !== nothing && set_failed!(state, origin, dep) end delete!(state.waiting_data, thunk) + thunk.sch_accessible = false + delete_unused_task!(state, thunk) end if haskey(state.waiting, thunk) delete!(state.waiting, thunk) @@ -221,7 +290,7 @@ function print_sch_status(io::IO, state, thunk; offset=0, limit=5, max_inputs=3) status *= "r" elseif node in state.running status *= "R" - elseif haskey(state.cache, node) + elseif has_result(state, node) status *= "C" else status *= "?" 
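The helpers introduced above (`has_result`, `store_result!`, `load_result`, `clear_result!`) replace the old `state.cache` dictionary with a result slot stored on the `Thunk` itself. Below is a minimal, self-contained sketch of that pattern; `MiniThunk` and the stand-in helpers are hypothetical simplifications for illustration, not Dagger's actual types or the scheduler's locking discipline:

```julia
# Simplified model of the per-thunk result cache: a result is wrapped in
# `Some` so a legitimately-`nothing` result is distinguishable from "no
# result yet", mirroring the `cache_ref` field used in the patch above.
mutable struct MiniThunk
    id::Int
    finished::Bool
    cache_ref::Union{Some{Any},Nothing}
end
MiniThunk(id) = MiniThunk(id, false, nothing)

has_result(t::MiniThunk) = t.cache_ref !== nothing
function store_result!(t::MiniThunk, value)
    @assert !t.finished "result already stored"
    t.finished = true
    t.cache_ref = Some{Any}(value)
end
load_result(t::MiniThunk) = (@assert t.finished; something(t.cache_ref))
clear_result!(t::MiniThunk) = (t.cache_ref = nothing)  # drop once unused

t = MiniThunk(1)
store_result!(t, 42)
@assert has_result(t) && load_result(t) == 42
clear_result!(t)
```

In the real scheduler the store/clear operations additionally assert that `state.lock` is held and record error status in `state.errored`, as shown in the diff.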
@@ -234,7 +303,7 @@ function print_sch_status(io::IO, state, thunk; offset=0, limit=5, max_inputs=3) print(io, "($(status_string(thunk))) ") end println(io, "$(thunk.id): $(thunk.f)") - for (idx, input) in enumerate(thunk.syncdeps) + for (idx, input) in enumerate(thunk.options.syncdeps) if input isa WeakThunk input = Dagger.unwrap_weak(input) if input === nothing @@ -294,36 +363,47 @@ end chunktype(x) = typeof(x) signature(state, task::Thunk) = - signature(task.f, collect_task_inputs(state, task.inputs)) + signature(task.inputs[1], @view task.inputs[2:end]) function signature(f, args) - sig = DataType[chunktype(f)] + n_pos = count(Dagger.ispositional, args) + any_kw = any(!Dagger.ispositional, args) + kw_extra = any_kw ? 2 : 0 + sig = Vector{Any}(undef, 1+n_pos+kw_extra) + sig[1+kw_extra] = chunktype(f) + #=FIXME:REALLOC_N=# sig_kwarg_names = Symbol[] sig_kwarg_types = [] - for (pos, arg) in args - if arg isa Dagger.DTask + for idx in 1:length(args) + arg = args[idx] + value = Dagger.value(arg) + if value isa Dagger.DTask # Only occurs via manual usage of signature - arg = fetch(arg; raw=true) + value = fetch(value; raw=true) + end + if istask(value) + throw(ConcurrencyViolationError("Must call `collect_task_inputs!(state, task)` before calling `signature`")) end - T = chunktype(arg) - if pos === nothing - push!(sig, T) + T = chunktype(value) + if Dagger.ispositional(arg) + sig[1+idx+kw_extra] = T else - push!(sig_kwarg_names, pos) + push!(sig_kwarg_names, Dagger.pos_kw(arg)) push!(sig_kwarg_types, T) end end - if !isempty(sig_kwarg_names) + if any_kw NT = NamedTuple{(sig_kwarg_names...,), Base.to_tuple_type(sig_kwarg_types)} - pushfirst!(sig, NT) + sig[2] = NT @static if isdefined(Core, :kwcall) - pushfirst!(sig, typeof(Core.kwcall)) + sig[1] = typeof(Core.kwcall) else f_instance = chunktype(f).instance kw_f = Core.kwfunc(f_instance) - pushfirst!(sig, typeof(kw_f)) + sig[1] = typeof(kw_f) end end - return sig + #=FIXME:UNIQUE=# + return Signature(sig) end function can_use_proc(state, task, gproc, proc, opts, scope) @@ -373,18 +453,18 @@ function can_use_proc(state, task, gproc, proc, opts, scope) return false, scope end - # Check against f/args + # Check against function and arguments Tf = chunktype(task.f) if !Dagger.iscompatible_func(proc, opts, Tf) @dagdebug task :scope "Rejected $proc: Not compatible with function type ($Tf)" return false, scope end - for (_, arg) in task.inputs - arg = unwrap_weak_checked(arg) - if arg isa Thunk - arg = state.cache[arg] + for arg in task.inputs[2:end] + value = unwrap_weak_checked(Dagger.value(arg)) + if value isa Thunk + value = load_result(state, value) end - Targ = chunktype(arg) + Targ = chunktype(value) if !Dagger.iscompatible_arg(proc, opts, Targ) @dagdebug task :scope "Rejected $proc: Not compatible with argument type ($Targ)" return false, scope @@ -404,7 +484,7 @@ function has_capacity(state, p, gp, time_util, alloc_util, occupancy, sig) time_util[T] * 1000^3 else get(state.signature_time_cost, sig, 1000^3) - end) + end)::UInt64 est_alloc_util = if alloc_util !== nothing && haskey(alloc_util, T) alloc_util[T] else @@ -434,27 +514,6 @@ function has_capacity(state, p, gp, time_util, alloc_util, occupancy, sig) return true, est_time_util, est_alloc_util, est_occupancy end -function populate_processor_cache_list!(state, procs) - # Populate the cache if empty - if state.procs_cache_list[] === nothing - current = nothing - for p in map(x->x.pid, procs) - for proc in get_processors(OSProc(p)) - next = ProcessorCacheEntry(OSProc(p), proc) - if 
current === nothing - current = next - current.next = current - state.procs_cache_list[] = current - else - current.next = next - current = next - current.next = state.procs_cache_list[] - end - end - end - end -end - "Like `sum`, but replaces `nothing` entries with the average of non-`nothing` entries." function impute_sum(xs) total = 0 @@ -474,15 +533,16 @@ function impute_sum(xs) end "Collects all arguments for `task`, converting Thunk inputs to Chunks." -collect_task_inputs(state, task::Thunk) = - collect_task_inputs(state, task.inputs) -function collect_task_inputs(state, inputs) - new_inputs = Pair{Union{Symbol,Nothing},Any}[] - for (pos, input) in inputs - input = unwrap_weak_checked(input) - push!(new_inputs, pos => (istask(input) ? state.cache[input] : input)) - end - return new_inputs +collect_task_inputs!(state, task::Thunk) = + collect_task_inputs!(state, task.inputs) +function collect_task_inputs!(state, inputs) + for idx in 1:length(inputs) + input = unwrap_weak_checked(Dagger.value(inputs[idx])) + if istask(input) + inputs[idx].value = wrap_weak(load_result(state, input)) + end + end + return end """ @@ -491,20 +551,34 @@ current estimated per-processor compute pressure, and transfer costs for each `Chunk` argument to `task`. Returns `(procs, costs)`, with `procs` sorted in order of ascending cost. """ -function estimate_task_costs(state, procs, task, inputs) +function estimate_task_costs(state, procs, task; sig=nothing) + sorted_procs = Vector{Processor}(undef, length(procs)) + costs = Dict{Processor,Float64}() + estimate_task_costs!(sorted_procs, costs, state, procs, task; sig) + return sorted_procs, costs +end +@reuse_scope function estimate_task_costs!(sorted_procs, costs, state, procs, task; sig=nothing) tx_rate = state.transfer_rate[] # Find all Chunks - chunks = Chunk[] - for input in inputs - if input isa Chunk - push!(chunks, input) + chunks = @reusable_vector :estimate_task_costs_chunks Union{Chunk,Nothing} nothing 32 + chunks_cleanup = @reuse_defer_cleanup empty!(chunks) + for input in task.inputs + if Dagger.valuetype(input) <: Chunk + push!(chunks, Dagger.value(input)::Chunk) end end - costs = Dict{Processor,Float64}() + # Estimate the cost of executing the task itself + if sig === nothing + sig = signature(task.f, task.inputs) + end + est_time_util = get(state.signature_time_cost, sig, 1000^3) + + # Estimate total cost for executing this task on each candidate processor for proc in procs - chunks_filt = Iterators.filter(c->get_parent(processor(c))!=get_parent(proc), chunks) + gproc = get_parent(proc) + chunks_filt = Iterators.filter(c->get_parent(processor(c)) != gproc, chunks) # Estimate network transfer costs based on data size # N.B. `affinity(x)` really means "data size of `x`" @@ -513,19 +587,27 @@ function estimate_task_costs(state, procs, task, inputs) # TODO: Measure and model processor move overhead tx_cost = impute_sum(affinity(chunk)[2] for chunk in chunks_filt) - # Estimate total cost to move data and get task running after currently-scheduled tasks - est_time_util = get(state.worker_time_pressure[get_parent(proc).pid], proc, 0) - costs[proc] = est_time_util + (tx_cost/tx_rate) + # Add fixed cost for cross-worker task transfer (esimated at 1ms) + # TODO: Actually estimate/benchmark this + task_xfer_cost = gproc.pid != myid() ? 
1_000_000 : 0 # 1ms + + # Compute final cost + costs[proc] = est_time_util + (tx_cost/tx_rate) + task_xfer_cost end + chunks_cleanup() # Shuffle procs around, so equally-costly procs are equally considered - P = randperm(length(procs)) - procs = getindex.(Ref(procs), P) + np = length(procs) + @reusable :estimate_task_costs_P Vector{Int} 0 4 np P begin + copyto!(P, 1:np) + randperm!(P) + for idx in 1:np + sorted_procs[idx] = procs[P[idx]] + end + end # Sort by lowest cost first - sort!(procs, by=p->costs[p]) - - return procs, costs + sort!(sorted_procs, by=p->costs[p]) end """ diff --git a/src/sch_options.jl b/src/sch_options.jl new file mode 100644 index 000000000..bdf308cf3 --- /dev/null +++ b/src/sch_options.jl @@ -0,0 +1,32 @@ +""" + SchedulerOptions + +Stores DAG-global options to be passed to the Dagger.Sch scheduler. + +# Arguments +- `single::Int=0`: (Deprecated) Force all work onto worker with specified id. + `0` disables this option. +- `proclist=nothing`: (Deprecated) Force scheduler to use one or more + processors that are instances/subtypes of a contained type. Alternatively, a + function can be supplied, and the function will be called with a processor as + the sole argument and should return a `Bool` result to indicate whether or not + to use the given processor. `nothing` enables all default processors. +- `allow_errors::Bool=false`: Allow thunks to error without affecting + non-dependent thunks. +- `checkpoint=nothing`: If not `nothing`, uses the provided function to save + the final result of the current scheduler invocation to persistent storage, for + later retrieval by `restore`. +- `restore=nothing`: If not `nothing`, uses the provided function to return the + (cached) final result of the current scheduler invocation, were it to execute. + If this returns a `Chunk`, all thunks will be skipped, and the `Chunk` will be + returned. If `nothing` is returned, restoring is skipped, and the scheduler + will execute as usual. If this function throws an error, restoring will be + skipped, and the error will be displayed. +""" +Base.@kwdef struct SchedulerOptions + single::Union{Int,Nothing} = nothing + proclist = nothing + allow_errors::Union{Bool,Nothing} = false + checkpoint = nothing + restore = nothing +end diff --git a/src/stream.jl b/src/stream.jl index 885a9e931..bf1ea4537 100644 --- a/src/stream.jl +++ b/src/stream.jl @@ -298,7 +298,7 @@ end function (dpm::DestPostMigration)(store, unsent) STREAM_THUNK_ID[] = dpm.thunk_id @assert !in_task() - tls = DTaskTLS(OSProc(), typemax(UInt64), nothing, [], dpm.cancel_token) + tls = DTaskTLS(OSProc(), typemax(UInt64), nothing, [], dpm.cancel_token, false) set_tls!(tls) return dpm.f(store, unsent) end @@ -391,7 +391,7 @@ function enqueue!(queue::StreamingTaskQueue, specs::Vector{Pair{DTaskSpec,DTask} end function initialize_streaming!(self_streams, spec, task) - @assert !isa(spec.f, StreamingFunction) "Task is already in streaming form" + @assert !isa(value(spec.fargs[1]), StreamingFunction) "Task is already in streaming form" # Calculate the return type of the called function T_old = Base.uniontypes(task.metadata.return_type) @@ -401,33 +401,34 @@ function initialize_streaming!(self_streams, spec, task) T = task.metadata.return_type = !isempty(T_old) ? 
Union{T_old...} : Any # Get input buffer configuration - input_buffer_amount = get(spec.options, :stream_input_buffer_amount, 1) + input_buffer_amount = something(spec.options.stream_input_buffer_amount, 1) if input_buffer_amount <= 0 throw(ArgumentError("Input buffering is required; please specify a `stream_input_buffer_amount` greater than 0")) end # Get output buffer configuration - output_buffer_amount = get(spec.options, :stream_output_buffer_amount, 1) + output_buffer_amount = something(spec.options.stream_output_buffer_amount, 1) if output_buffer_amount <= 0 throw(ArgumentError("Output buffering is required; please specify a `stream_output_buffer_amount` greater than 0")) end # Create the Stream - buffer_type = get(spec.options, :stream_buffer_type, ProcessRingBuffer) + buffer_type = something(spec.options.stream_buffer_type, ProcessRingBuffer) stream = Stream{T,buffer_type}(task.uid, input_buffer_amount, output_buffer_amount) self_streams[task.uid] = stream # Get max evaluation count - max_evals = get(spec.options, :stream_max_evals, -1) + max_evals = something(spec.options.stream_max_evals, -1) if max_evals == 0 throw(ArgumentError("stream_max_evals cannot be 0")) end # Wrap the function in a StreamingFunction - spec.f = StreamingFunction(spec.f, stream, max_evals) + spec.fargs[1].value = StreamingFunction(value(spec.fargs[1]), stream, max_evals) # Mark the task as non-blocking - spec.options = merge(spec.options, (;occupancy=Dict(Any=>0))) + spec.options.occupancy = @something(spec.options.occupancy, Dict()) + spec.options.occupancy[Any] = 0 # Register Stream globally remotecall_wait(1, task.uid, stream) do uid, stream @@ -670,7 +671,8 @@ function finalize_streaming!(tasks::Vector{Pair{DTaskSpec,DTask}}, self_streams) our_stream = self_streams[task.uid] # Adapt args to accept Stream output of other streaming tasks - for (idx, (pos, arg)) in enumerate(spec.args) + for (idx, pos_arg) in enumerate(spec.fargs) + arg = value(pos_arg) if arg isa DTask # Check if this is a streaming task if haskey(self_streams, arg.uid) @@ -684,7 +686,7 @@ function finalize_streaming!(tasks::Vector{Pair{DTaskSpec,DTask}}, self_streams) # FIXME: Be configurable input_fetcher = RemoteChannelFetcher() other_stream_handle = Stream(other_stream) - spec.args[idx] = pos => other_stream_handle + pos_arg.value = other_stream_handle our_stream.store.input_streams[arg.uid] = other_stream_handle our_stream.store.input_fetchers[arg.uid] = input_fetcher @@ -696,18 +698,6 @@ function finalize_streaming!(tasks::Vector{Pair{DTaskSpec,DTask}}, self_streams) end end end - - # Filter out all streaming options - to_filter = (:stream_buffer_type, - :stream_input_buffer_amount, :stream_output_buffer_amount, - :stream_max_evals) - spec.options = NamedTuple(filter(opt -> !(opt[1] in to_filter), - Base.pairs(spec.options))) - if haskey(spec.options, :propagates) - propagates = filter(opt -> !(opt in to_filter), - spec.options.propagates) - spec.options = merge(spec.options, (;propagates)) - end end # Notify Streams of any new waiters diff --git a/src/submission.jl b/src/submission.jl index f23539271..32cdc6d05 100644 --- a/src/submission.jl +++ b/src/submission.jl @@ -1,240 +1,317 @@ +mutable struct PayloadOne + uid::UInt + future::ThunkFuture + fargs::Vector{Argument} + options::Options + reschedule::Bool + + PayloadOne() = new() + PayloadOne(uid::UInt, future::ThunkFuture, + fargs::Vector{Argument}, options::Options, reschedule::Bool) = + new(uid, future, fargs, options, reschedule) +end +function unset!(p::PayloadOne, _) + 
p.uid = 0 + p.future = EMPTY_PAYLOAD_ONE.future + p.fargs = EMPTY_PAYLOAD_ONE.fargs + p.options = EMPTY_PAYLOAD_ONE.options + p.reschedule = false +end +const EMPTY_PAYLOAD_ONE = PayloadOne(UInt(0), ThunkFuture(), Argument[], Options(), false) +mutable struct PayloadMulti + ntasks::Int + uid::Vector{UInt} + future::Vector{ThunkFuture} + fargs::Vector{Vector{Argument}} + options::Vector{Options} + reschedule::Bool +end +const AnyPayload = Union{PayloadOne, PayloadMulti} +function payload_extract(f, payload::PayloadMulti, i::Integer) + take_or_alloc!(PAYLOAD_ONE_CACHE[]) do p1 + p1.uid = payload.uid[i] + p1.future = payload.future[i] + p1.fargs = payload.fargs[i] + p1.options = payload.options[i] + p1.reschedule = true + return f(p1) + end +end +const PAYLOAD_ONE_CACHE = TaskLocalValue{ReusableCache{PayloadOne,Nothing}}(()->ReusableCache(PayloadOne, nothing, 1)) + +const THUNK_SPEC_CACHE = TaskLocalValue{ReusableCache{ThunkSpec,Nothing}}(()->ReusableCache(ThunkSpec, nothing, 1)) + # Remote -function eager_submit_internal!(@nospecialize(payload)) +function eager_submit_internal!(payload::AnyPayload) ctx = Dagger.Sch.eager_context() state = Dagger.Sch.EAGER_STATE[] task = current_task() tid = 0 return eager_submit_internal!(ctx, state, task, tid, payload) end -function eager_submit_internal!(ctx, state, task, tid, payload; uid_to_tid=Dict{UInt64,Int}()) - @nospecialize payload - ntasks, uid, future, ref, f, args, options, reschedule = payload - - if uid isa Vector - thunk_ids = Sch.ThunkID[] - for i in 1:ntasks - tid = eager_submit_internal!(ctx, state, task, tid, - (1, uid[i], future[i], ref[i], - f[i], args[i], options[i], - false); uid_to_tid) - push!(thunk_ids, tid) - uid_to_tid[uid[i]] = tid.id +eager_submit_internal!(ctx, state, task, tid, payload::Tuple{<:AnyPayload}) = + eager_submit_internal!(ctx, state, task, tid, payload[1]) +const UID_TO_TID_CACHE = TaskLocalValue{ReusableCache{Dict{UInt64,Int},Nothing}}(()->ReusableCache(Dict{UInt64,Int}, nothing, 1)) +@reuse_scope function eager_submit_internal!(ctx, state, task, tid, payload::AnyPayload; uid_to_tid=nothing) + maybe_take_or_alloc!(UID_TO_TID_CACHE[], uid_to_tid) do uid_to_tid + if payload isa PayloadMulti + thunk_ids = Sch.ThunkID[] + for i in 1:payload.ntasks + tid = payload_extract(payload, i) do p1 + eager_submit_internal!(ctx, state, task, tid, p1; + uid_to_tid) + end + push!(thunk_ids, tid) + uid_to_tid[payload.uid[i]] = tid.id + end + @lock state.lock begin + put!(state.chan, Sch.RescheduleSignal()) + end + return thunk_ids end - put!(state.chan, Sch.RescheduleSignal()) - return thunk_ids - end + payload::PayloadOne - id = next_id() + uid, future = payload.uid, payload.future + fargs, options, reschedule = payload.fargs, payload.options, payload.reschedule - timespan_start(ctx, :add_thunk, (;thunk_id=id), (;f, args, options, uid)) + id = next_id() - # Lookup DTask/ThunkID -> Thunk - old_args = copy(args) - args::Vector{Any} - syncdeps = if haskey(options, :syncdeps) - collect(options.syncdeps) - else - nothing - end::Union{Vector{Any},Nothing} - lock(Sch.EAGER_ID_MAP) do id_map - for (idx, (pos, arg)) in enumerate(args) - # FIXME: Switch to Union{Symbol,Int} to preserve positional information - pos::Union{Symbol,Nothing} - newarg = if arg isa DTask - arg_uid = arg.uid - arg_tid = if haskey(id_map, arg_uid) - id_map[arg_uid] - else - uid_to_tid[arg_uid] - end - state.thunk_dict[arg_tid] - elseif arg isa Sch.ThunkID - arg_tid = arg.id - state.thunk_dict[arg_tid] - elseif arg isa Chunk - # N.B. 
Different Chunks with the same DRef handle will hash to the same slot, - # so we just pick an equivalent Chunk as our upstream - if haskey(state.waiting_data, arg) - newarg = nothing - for other in keys(state.waiting_data) - if other isa Chunk && other.handle == arg.handle - newarg = other - break + @maybelog ctx timespan_start(ctx, :add_thunk, (;thunk_id=id), (;f=fargs[1], args=fargs[2:end], options, uid)) + + old_fargs = @reusable_vector :eager_submit_internal!_old_fargs Argument Argument(ArgPosition(), nothing) 32 + old_fargs_cleanup = @reuse_defer_cleanup empty!(old_fargs) + append!(old_fargs, Iterators.map(copy, fargs)) + + syncdeps_vec = @reusable_vector :eager_submit_interal!_syncdeps_vec Any nothing 32 + syncdeps_vec_cleanup = @reuse_defer_cleanup empty!(syncdeps_vec) + if options.syncdeps !== nothing + append!(syncdeps_vec, options.syncdeps) + end + + # Lookup DTask/ThunkID -> Thunk + # FIXME: Don't lock if no DTask args + lock(Sch.EAGER_ID_MAP) do id_map + for (idx, arg) in enumerate(fargs) + if valuetype(arg) <: DTask + arg_uid = (value(arg)::DTask).uid + arg_tid = if haskey(id_map, arg_uid) + id_map[arg_uid] + else + uid_to_tid[arg_uid] + end + @lock state.lock begin + @inbounds fargs[idx] = Argument(arg.pos, state.thunk_dict[arg_tid]) + end + elseif valuetype(arg) <: Sch.ThunkID + arg_tid = (value(arg)::Sch.ThunkID).id + @lock state.lock begin + @inbounds fargs[idx] = Argument(arg.pos, state.thunk_dict[arg_tid]) + end + elseif valuetype(arg) <: Chunk + # N.B. Different Chunks with the same DRef handle will hash to the same slot, + # so we just pick an equivalent Chunk as our upstream + chunk = value(arg)::Chunk + function find_equivalent_chunk(state, chunk::C) where {C<:Chunk} + @lock state.lock begin + if haskey(state.equiv_chunks, chunk.handle) + return state.equiv_chunks[chunk.handle]::C + else + state.equiv_chunks[chunk.handle] = chunk + return chunk + end end end - @assert newarg !== nothing - arg = newarg::Chunk + chunk = find_equivalent_chunk(state, chunk) + #=FIXME:UNIQUE=# + @inbounds fargs[idx] = Argument(arg.pos, WeakChunk(chunk)) end - WeakChunk(arg) - else - arg end - @inbounds args[idx] = pos => newarg - end - if syncdeps === nothing - return - end - for (idx, dep) in enumerate(syncdeps) - newdep = if dep isa DTask - tid = if haskey(id_map, dep.uid) - id_map[dep.uid] - else - uid_to_tid[dep.uid] + # TODO: Iteration protocol would be faster + for idx in 1:length(syncdeps_vec) + dep = syncdeps_vec[idx] + if dep isa DTask + tid = if haskey(id_map, dep.uid) + id_map[dep.uid] + else + uid_to_tid[dep.uid] + end + @lock state.lock begin + @inbounds syncdeps_vec[idx] = state.thunk_dict[tid] + end + elseif dep isa Sch.ThunkID + tid = dep.id + @lock state.lock begin + @inbounds syncdeps_vec[idx] = state.thunk_dict[tid] + end end - state.thunk_dict[tid] - elseif dep isa Sch.ThunkID - tid = dep.id - state.thunk_dict[tid] + end + end + if !isempty(syncdeps_vec) || any(arg->istask(value(arg)), fargs) + if options.syncdeps === nothing + options.syncdeps = Set{Any}() else - dep + empty!(options.syncdeps) + end + syncdeps = options.syncdeps + for dep in syncdeps_vec + push!(syncdeps, dep) + end + for arg in fargs + if istask(value(arg)) + push!(syncdeps, value(arg)) + end end - @inbounds syncdeps[idx] = newdep end - end - if syncdeps !== nothing - options = merge(options, (;syncdeps)) - end + syncdeps_vec_cleanup() - GC.@preserve old_args args begin - # Create the `Thunk` - thunk = Thunk(f, args...; id, options...) 
+ GC.@preserve old_fargs fargs begin + # Create the `Thunk` + thunk = take_or_alloc!(THUNK_SPEC_CACHE[]) do thunk_spec + thunk_spec.fargs = fargs + thunk_spec.id = id + thunk_spec.options = options + return Thunk(thunk_spec) + end - # Create a `DRef` to `thunk` so that the caller can preserve it - thunk_ref = poolset(thunk; size=64, device=MemPool.CPURAMDevice(), - destructor=UnrefThunkByUser(thunk)) - thunk_id = Sch.ThunkID(thunk.id, thunk_ref) + # Create a `DRef` to `thunk` so that the caller can preserve it + thunk_ref = poolset(thunk; size=64, device=MemPool.CPURAMDevice(), + destructor=UnrefThunk(uid, thunk, state)) + #=FIXME:UNIQUE=# + thunk_id = Sch.ThunkID(thunk.id, thunk_ref) - # Attach `thunk` within the scheduler - state.thunk_dict[thunk.id] = WeakThunk(thunk) - Sch.reschedule_syncdeps!(state, thunk) - @dagdebug thunk :submit "Added to scheduler" - if future !== nothing - # Ensure we attach a future before the thunk is scheduled - Sch._register_future!(ctx, state, task, tid, (future, thunk_id, false)) - @dagdebug thunk :submit "Registered future" - end - if ref !== nothing - # Preserve the `DTaskFinalizer` through `thunk` - thunk.eager_ref = ref - end - state.valid[thunk] = nothing + @lock state.lock begin + # Attach `thunk` within the scheduler + state.thunk_dict[thunk.id] = WeakThunk(thunk) + #=FIXME:REALLOC=# + Sch.reschedule_syncdeps!(state, thunk) + old_fargs_cleanup() # reschedule_syncdeps! preserves all referenced tasks/chunks + n_upstreams = haskey(state.waiting, thunk) ? length(state.waiting[thunk]) : 0 + @dagdebug thunk :submit "Added to scheduler with $n_upstreams upstreams" + if future !== nothing + # Ensure we attach a future before the thunk is scheduled + Sch._register_future!(ctx, state, task, tid, (future, thunk_id, false)) + @dagdebug thunk :submit "Registered future" + end + state.valid[thunk] = nothing - # Register Eager UID -> Sch TID - lock(Sch.EAGER_ID_MAP) do id_map - id_map[uid] = thunk.id - end + # Register Eager UID -> Sch TID + lock(Sch.EAGER_ID_MAP) do id_map + id_map[uid] = thunk.id + end - # Tell the scheduler that it has new tasks to schedule - if reschedule - put!(state.chan, Sch.RescheduleSignal()) - end + # Tell the scheduler that it has new tasks to schedule + if reschedule + put!(state.chan, Sch.RescheduleSignal()) + end + end - timespan_finish(ctx, :add_thunk, (;thunk_id=id), (;f, args, options, uid)) + @maybelog ctx timespan_finish(ctx, :add_thunk, (;thunk_id=id), (;f=fargs[1], args=fargs[2:end], options, uid)) - return thunk_id + return thunk_id + end end end -struct UnrefThunkByUser +struct UnrefThunk + uid::UInt thunk::Thunk + state end -function (unref::UnrefThunkByUser)() - Sch.errormonitor_tracked("unref thunk $(unref.thunk.id)", Threads.@spawn begin - # This thunk is no longer referenced by the user, mark it as ready to be - # cleaned up as eagerly as possible (or do so now) - thunk = unref.thunk - state = Sch.EAGER_STATE[] - if state === nothing - return +function (unref::UnrefThunk)() + name = unref.uid != UInt(0) ? 
"unref DTask $(unref.uid) => Thunk $(unref.thunk.id)" : "unref Thunk $(unref.thunk.id)" + Sch.errormonitor_tracked(name, Threads.@spawn begin + if unref.uid != UInt(0) + lock(Sch.EAGER_ID_MAP) do id_map + delete!(id_map, unref.uid) + end end + # The associated DTask is no longer referenced by the user, so mark the + # thunk as ready to be cleaned up as eagerly as possible (or do so now) + thunk = unref.thunk + state = unref.state @lock state.lock begin - if !Sch.delete_unused_task!(state, thunk) - # Register for deletion upon thunk completion - push!(state.thunks_to_delete, thunk) + thunk.eager_accessible = false + Sch.delete_unused_task!(state, thunk) + end + + if unref.uid != UInt(0) + # Cleanup EAGER_THUNK_STREAMS if this is a streaming DTask + lock(Dagger.EAGER_THUNK_STREAMS) do global_streams + if haskey(global_streams, unref.uid) + delete!(global_streams, unref.uid) + end end - # TODO: On success, walk down to children, as a fast-path end end) end - # Local -> Remote -function eager_submit!(ntasks, uid, future, finalizer_ref, f, args, options) +function eager_submit!(payload::AnyPayload) if Dagger.in_task() h = Dagger.sch_handle() - return exec!(eager_submit_internal!, h, ntasks, uid, future, finalizer_ref, f, args, options, true) + return exec!(eager_submit_internal!, h, payload) elseif myid() != 1 - return remotecall_fetch(1, (ntasks, uid, future, finalizer_ref, f, args, options, true)) do payload - @nospecialize payload + return remotecall_fetch(1, payload) do payload Sch.init_eager() - state = Dagger.Sch.EAGER_STATE[] - lock(state.lock) do - eager_submit_internal!(payload) - end + eager_submit_internal!(payload) end else Sch.init_eager() - state = Dagger.Sch.EAGER_STATE[] - return lock(state.lock) do - eager_submit_internal!((ntasks, uid, future, finalizer_ref, - f, args, options, - true)) - end + return eager_submit_internal!(payload) end end # Submission -> Local -function eager_process_elem_submission_to_local(id_map, x) - @nospecialize x - @assert !isa(x, Thunk) "Cannot use `Thunk`s in `@spawn`/`spawn`" - if x isa Dagger.DTask && haskey(id_map, x.uid) - return Sch.ThunkID(id_map[x.uid], x.thunk_ref) - else - return x +function eager_process_elem_submission_to_local!(id_map, arg::Argument) + T = valuetype(arg) + @assert !(T <: Thunk) "Cannot use `Thunk`s in `@spawn`/`spawn`" + if T <: DTask && haskey(id_map, (value(arg)::DTask).uid) + #=FIXME:UNIQUE=# + arg.value = Sch.ThunkID(id_map[value(arg).uid], value(arg).thunk_ref) end end -# TODO: This can probably operate in-place -function eager_process_args_submission_to_local(id_map, spec::Pair{DTaskSpec,DTask}) - return Base.mapany(first(spec).args) do pos_x - pos, x = pos_x - return pos => eager_process_elem_submission_to_local(id_map, x) +function eager_process_args_submission_to_local!(id_map, spec_pair::Pair{DTaskSpec,DTask}) + spec, task = spec_pair + for arg in spec.fargs + eager_process_elem_submission_to_local!(id_map, arg) end end -function eager_process_args_submission_to_local(id_map, specs::Vector{Pair{DTaskSpec,DTask}}) - return Base.mapany(specs) do spec - eager_process_args_submission_to_local(id_map, spec) +function eager_process_args_submission_to_local!(id_map, spec_pairs::Vector{Pair{DTaskSpec,DTask}}) + for spec_pair in spec_pairs + eager_process_args_submission_to_local!(id_map, spec_pair) end end -function eager_process_options_submission_to_local(id_map, options::NamedTuple) - @nospecialize options - if haskey(options, :syncdeps) +function eager_process_options_submission_to_local!(id_map, options::Options) 
+ if options.syncdeps !== nothing raw_syncdeps = options.syncdeps syncdeps = Set{Any}() for raw_dep in raw_syncdeps - push!(syncdeps, eager_process_elem_submission_to_local(id_map, raw_dep)) + if raw_dep isa DTask + push!(syncdeps, Sch.ThunkID(id_map[raw_dep.uid], raw_dep.thunk_ref)) + elseif raw_dep isa Sch.ThunkID + push!(syncdeps, raw_dep) + else + error("Invalid syncdep type: $(typeof(raw_dep))") + end end - return merge(options, (;syncdeps)) - else - return options + options.syncdeps = syncdeps end end function DTaskMetadata(spec::DTaskSpec) - f = spec.f isa StreamingFunction ? spec.f.f : spec.f - arg_types = ntuple(i->chunktype(spec.args[i][2]), length(spec.args)) + f = value(spec.fargs[1]) + f = f isa StreamingFunction ? f.f : f + arg_types = ntuple(i->chunktype(value(spec.fargs[i+1])), length(spec.fargs)-1) return_type = Base.promote_op(f, arg_types...) return DTaskMetadata(return_type) end function eager_spawn(spec::DTaskSpec) - # Generate new DTask + # Generate new unlaunched DTask uid = eager_next_id() future = ThunkFuture() metadata = DTaskMetadata(spec) - finalizer_ref = poolset(DTaskFinalizer(uid); device=MemPool.CPURAMDevice()) - - # Create unlaunched DTask - return DTask(uid, future, metadata, finalizer_ref) + return DTask(uid, future, metadata) end chunktype(t::DTask) = t.metadata.return_type @@ -244,16 +321,15 @@ function eager_launch!((spec, task)::Pair{DTaskSpec,DTask}) eager_assign_name!(spec, task) # Lookup DTask -> ThunkID - local args, options lock(Sch.EAGER_ID_MAP) do id_map - args = eager_process_args_submission_to_local(id_map, spec=>task) - options = eager_process_options_submission_to_local(id_map, spec.options) + eager_process_args_submission_to_local!(id_map, spec=>task) + eager_process_options_submission_to_local!(id_map, spec.options) end # Submit the task - thunk_id = eager_submit!(1, - task.uid, task.future, task.finalizer_ref, - spec.f, args, options) + #=FIXME:REALLOC=# + thunk_id = eager_submit!(PayloadOne(task.uid, task.future, + spec.fargs, spec.options, true)) task.thunk_ref = thunk_id.ref end function eager_launch!(specs::Vector{Pair{DTaskSpec,DTask}}) @@ -264,20 +340,23 @@ function eager_launch!(specs::Vector{Pair{DTaskSpec,DTask}}) eager_assign_name!(spec, task) end + #=FIXME:REALLOC_N=# uids = [task.uid for (_, task) in specs] futures = [task.future for (_, task) in specs] - finalizer_refs = [task.finalizer_ref for (_, task) in specs] # Get all functions, args/kwargs, and options - all_fs = Any[spec.f for (spec, _) in specs] - all_args = lock(Sch.EAGER_ID_MAP) do id_map + #=FIXME:REALLOC_N=# + all_fargs = lock(Sch.EAGER_ID_MAP) do id_map # Lookup DTask -> ThunkID - eager_process_args_submission_to_local(id_map, specs) + eager_process_args_submission_to_local!(id_map, specs) + [spec.fargs for (spec, _) in specs] end - all_options = Any[spec.options for (spec, _) in specs] + all_options = Options[spec.options for (spec, _) in specs] # Submit the tasks - thunk_ids = eager_submit!(ntasks, uids, futures, finalizer_refs, all_fs, all_args, all_options) + #=FIXME:REALLOC=# + thunk_ids = eager_submit!(PayloadMulti(ntasks, uids, futures, + all_fargs, all_options, true)) for i in 1:ntasks task = specs[i][2] task.thunk_ref = thunk_ids[i].ref @@ -286,8 +365,7 @@ end function eager_assign_name!(spec::DTaskSpec, task::DTask) # Assign a name, if specified - if haskey(spec.options, :name) + if spec.options.name !== nothing Dagger.logs_annotate!(task, spec.options.name) - spec.options = (;filter(x -> x[1] != :name, Base.pairs(spec.options))...) 
end end diff --git a/src/task-tls.jl b/src/task-tls.jl index 5c7d0375b..29d5db082 100644 --- a/src/task-tls.jl +++ b/src/task-tls.jl @@ -4,13 +4,20 @@ mutable struct DTaskTLS processor::Processor sch_uid::UInt sch_handle::Any # FIXME: SchedulerHandle - task_spec::Vector{Any} # FIXME: TaskSpec + task_spec::Any # FIXME: TaskSpec cancel_token::CancelToken + logging_enabled::Bool end const DTASK_TLS = TaskLocalValue{Union{DTaskTLS,Nothing}}(()->nothing) -Base.copy(tls::DTaskTLS) = DTaskTLS(tls.processor, tls.sch_uid, tls.sch_handle, tls.task_spec, tls.cancel_token) +Base.copy(tls::DTaskTLS) = + DTaskTLS(tls.processor, + tls.sch_uid, + tls.sch_handle, + tls.task_spec, + tls.cancel_token, + tls.logging_enabled) """ get_tls() -> DTaskTLS @@ -20,12 +27,17 @@ Gets all Dagger TLS variable as a `DTaskTLS`. get_tls() = DTASK_TLS[]::DTaskTLS """ - set_tls!(tls) + set_tls!(tls::NamedTuple) Sets all Dagger TLS variables from `tls`, which may be a `DTaskTLS` or a `NamedTuple`. """ function set_tls!(tls) - DTASK_TLS[] = DTaskTLS(tls.processor, tls.sch_uid, tls.sch_handle, tls.task_spec, tls.cancel_token) + DTASK_TLS[] = DTaskTLS(tls.processor, + tls.sch_uid, + tls.sch_handle, + tls.task_spec, + tls.cancel_token, + tls.logging_enabled) end """ @@ -79,3 +91,10 @@ Cancels the current [`DTask`](@ref). If `graceful=true`, then the task will be cancelled gracefully, otherwise it will be forced. """ task_cancel!(; graceful::Bool=true) = cancel!(get_tls().cancel_token; graceful) + +""" + task_logging_enabled() -> Bool + +Returns `true` if logging is enabled for the current [`DTask`](@ref), else `false`. +""" +task_logging_enabled() = get_tls().logging_enabled diff --git a/src/threadproc.jl b/src/threadproc.jl index b75c90ca3..6ac75db8e 100644 --- a/src/threadproc.jl +++ b/src/threadproc.jl @@ -16,7 +16,9 @@ function execute!(proc::ThreadProc, @nospecialize(f), @nospecialize(args...); @n result = Ref{Any}() task = Task() do set_tls!(tls) - TimespanLogging.prof_task_put!(tls.sch_handle.thunk_id.id) + if task_logging_enabled() + TimespanLogging.prof_task_put!(tls.sch_handle.thunk_id.id) + end result[] = @invokelatest f(args...; kwargs...) return end diff --git a/src/thunk.jl b/src/thunk.jl index 659ee9a2c..e4299aae1 100644 --- a/src/thunk.jl +++ b/src/thunk.jl @@ -3,14 +3,23 @@ export Thunk, delayed const ID_COUNTER = Threads.Atomic{Int}(1) next_id() = Threads.atomic_add!(ID_COUNTER, 1) -function filterany(f::Base.Callable, xs) - xs_filt = Any[] - for x in xs - if f(x) - push!(xs_filt, x) - end - end - return xs_filt +const EMPTY_ARGS = Argument[] +const EMPTY_SYNCDEPS = Set{Any}() +Base.@kwdef mutable struct ThunkSpec + fargs::Vector{Argument} = EMPTY_ARGS + id::Int = 0 + cache_ref::Any = nothing + affinity::Union{Pair{OSProc,Int}, Nothing} = nothing + options::Union{Options, Nothing} = nothing +end +function unset!(spec::ThunkSpec, _) + spec.fargs = EMPTY_ARGS + spec.id = 0 + spec.cache_ref = nothing + spec.affinity = nothing + compute_scope = DefaultScope() + result_scope = AnyScope() + spec.options = nothing end """ @@ -35,93 +44,103 @@ julia> collect(t) # computes the result and returns it to the current process ``` ## Arguments -- `f`: The function to be called upon execution of the `Thunk`. -- `args`: The arguments to be passed to the `Thunk`. +- `fargs`: The function and arguments to be called upon execution of the `Thunk`. - `kwargs`: The properties describing unique behavior of this `Thunk`. Details for each property are described in the next section. 
- `option=value`: The same as passing `kwargs` to `delayed`. -## Public Properties -- `meta::Bool=false`: If `true`, instead of fetching cached arguments from -`Chunk`s and passing the raw arguments to `f`, instead pass the `Chunk`. Useful -for doing manual fetching or manipulation of `Chunk` references. Non-`Chunk` -arguments are still passed as-is. -- `processor::Processor=OSProc()` - The processor associated with `f`. Useful if -`f` is a callable struct that exists on a given processor and should be -transferred appropriately. -- `scope::Dagger.AbstractScope=DefaultScope()` - The scope associated with `f`. -Useful if `f` is a function or callable struct that may only be transferred to, -and executed within, the specified scope. - ## Options -- `options`: A `Sch.ThunkOptions` struct providing the options for the `Thunk`. +- `options`: An `Options` struct providing the options for the `Thunk`. If omitted, options can also be specified by passing key-value pairs as `kwargs`. """ mutable struct Thunk - f::Any # usually a Function, but could be any callable - inputs::Vector{Pair{Union{Symbol,Nothing},Any}} # TODO: Use `ImmutableArray` in 1.8 - syncdeps::Set{Any} + inputs::Vector{Argument} # TODO: Use `ImmutableArray` in 1.8 id::Int - get_result::Bool # whether the worker should send the result or only the metadata - meta::Bool - persist::Bool # don't `free!` result after computing - cache::Bool # release the result giving the worker an opportunity to - # cache it cache_ref::Any - affinity::Union{Nothing, Pair{OSProc, Int}} - eager_ref::Union{DRef,Nothing} - options::Any # stores scheduler-specific options - propagates::Tuple # which options we'll propagate - compute_scope::AbstractScope - result_scope::AbstractScope - function Thunk(f, xs...; - syncdeps=nothing, - id::Int=next_id(), - get_result::Bool=false, - meta::Bool=false, - persist::Bool=false, - cache::Bool=false, - cache_ref=nothing, - affinity=nothing, - eager_ref=nothing, - processor=nothing, - scope=DefaultScope(), - compute_scope=scope, - result_scope=AnyScope(), - options=nothing, - propagates=(), - kwargs... - ) - - xs = Base.mapany(identity, xs) - syncdeps_set = Set{Any}(filterany(is_task_or_chunk, Base.mapany(last, xs))) - if syncdeps !== nothing - for dep in syncdeps - push!(syncdeps_set, dep) - end - end - @assert all(x->x isa Pair, xs) - if options !== nothing - @assert isempty(kwargs) - new(f, xs, syncdeps_set, id, get_result, meta, persist, cache, - cache_ref, affinity, eager_ref, options, propagates, compute_scope, result_scope) + affinity::Union{Pair{OSProc,Int}, Nothing} + options::Union{Options, Nothing} # stores task options + eager_accessible::Bool + sch_accessible::Bool + finished::Bool + function Thunk(spec::ThunkSpec) + return new(spec.fargs, spec.id, + spec.cache_ref, spec.affinity, + spec.options, + true, true, false) + end +end +function Thunk(f, xs...; + syncdeps=nothing, + id::Int=next_id(), + cache_ref=nothing, + affinity=nothing, + options=nothing, + propagates=(), + kwargs... 
+ ) + + spec = ThunkSpec() + if !(f isa Argument) + f = Argument(ArgPosition(true, 0, :NULL), f) + end + spec.fargs = Vector{Argument}(undef, length(xs)+1) + spec.fargs[1] = f + for idx in 1:length(xs) + x = xs[idx] + if x isa Argument + spec.fargs[idx+1] = x else - new(f, xs, syncdeps_set, id, get_result, meta, persist, cache, - cache_ref, affinity, eager_ref, Sch.ThunkOptions(;kwargs...), - propagates, compute_scope, result_scope) + @assert x isa Pair "Invalid Thunk argument: $x" + spec.fargs[idx+1] = Argument(something(x.first, idx), x.second) end end + if options === nothing + options = Options() + end + spec.options = options::Options + if options.syncdeps === nothing + options.syncdeps = Set{Any}() + end + syncdeps_set = options.syncdeps + for idx in 2:length(spec.fargs) + x = value(spec.fargs[idx]) + if is_task_or_chunk(x) + push!(syncdeps_set, x) + end + end + if syncdeps !== nothing + for dep in syncdeps + push!(syncdeps_set, dep) + end + end + spec.id = id + if kwargs !== nothing + options_merge!(options, (;kwargs...)) + end + if haskey(kwargs, :cache) + @warn "The cache argument is deprecated, as it is now always true" maxlog=1 + end + spec.cache_ref = cache_ref + spec.affinity = affinity + return Thunk(spec) end Serialization.serialize(io::AbstractSerializer, t::Thunk) = throw(ArgumentError("Cannot serialize a Thunk")) +function Base.getproperty(thunk::Thunk, field::Symbol) + if field == :f + return unwrap_weak_checked(value(first(thunk.inputs))) + else + return getfield(thunk, field) + end +end function affinity(t::Thunk) if t.affinity !== nothing return t.affinity end - if t.cache && t.cache_ref !== nothing + if t.cache_ref !== nothing aff_vec = affinity(t.cache_ref) else aff = Dict{OSProc,Int}() @@ -152,13 +171,32 @@ end is_task_or_chunk(x) = istask(x) -function args_kwargs_to_pairs(args, kwargs) - args_kwargs = Pair{Union{Symbol,Nothing},Any}[] - for arg in args - push!(args_kwargs, nothing => arg) +function args_kwargs_to_arguments(f, args, kwargs) + @nospecialize f args kwargs + args_kwargs = Argument[] + push!(args_kwargs, Argument(ArgPosition(true, 0, :NULL), f)) + for idx in 1:length(args) + arg = args[idx] + push!(args_kwargs, Argument(idx, arg)) + end + for (kw, value) in kwargs + push!(args_kwargs, Argument(kw, value)) end - for kwarg in kwargs - push!(args_kwargs, kwarg[1] => kwarg[2]) + return args_kwargs +end +function args_kwargs_to_arguments(f, args) + @nospecialize f args + args_kwargs = Argument[] + push!(args_kwargs, Argument(ArgPosition(true, 0, :NULL), f)) + pos_ctr = 1 + for idx in 1:length(args) + pos, arg = args[idx]::Pair + if pos === nothing + push!(args_kwargs, Argument(pos_ctr, arg)) + pos_ctr += 1 + else + push!(args_kwargs, Argument(pos, arg)) + end end return args_kwargs end @@ -172,13 +210,21 @@ Creates a [`Thunk`](@ref) object which can be executed later, which will call resulting `Thunk`. """ function _delayed(f, options::Options) - (args...; kwargs...) -> Thunk(f, args_kwargs_to_pairs(args, kwargs)...; options.options...) + (args...; kwargs...) -> Thunk(args_kwargs_to_arguments(f, args, kwargs)...; options) end function delayed(f, options::Options) @warn "`delayed` is deprecated. Use `Dagger.@spawn` or `Dagger.spawn` instead." maxlog=1 return _delayed(f, options) end -delayed(f; kwargs...) = delayed(f, Options(;kwargs...)) +function delayed(f; options=nothing, kwargs...) 
+ if options !== nothing + options = options::Options + else + options = Options() + end + options_merge!(options, kwargs; override=true) + return delayed(f, options) +end "A weak reference to a `Thunk`." struct WeakThunk @@ -195,22 +241,31 @@ function unwrap_weak_checked(t::WeakThunk) t end unwrap_weak_checked(t) = t +wrap_weak(t::Thunk) = WeakThunk(t) +wrap_weak(t::WeakThunk) = t +wrap_weak(t) = t +isweak(t::WeakThunk) = true +isweak(t::Thunk) = false +isweak(t) = true Base.show(io::IO, t::WeakThunk) = (print(io, "~"); Base.show(io, t.x.value)) Base.convert(::Type{WeakThunk}, t::Thunk) = WeakThunk(t) +chunktype(t::WeakThunk) = chunktype(unwrap_weak_checked(t)) "A summary of the data contained in a Thunk, which can be safely serialized." struct ThunkSummary id::Int - f - inputs::Vector{Pair{Union{Symbol,Nothing},Any}} + inputs::Vector{Argument} end inputs(t::ThunkSummary) = t.inputs Base.show(io::IO, t::ThunkSummary) = show_thunk(io, t) function Base.convert(::Type{ThunkSummary}, t::Thunk) - return ThunkSummary(t.id, - t.f, - map(pos_inp->istask(pos_inp[2]) ? pos_inp[1]=>convert(ThunkSummary, pos_inp[2]) : pos_inp, - t.inputs)) + args = map(copy, t.inputs) + for arg in args + if istask(value(arg)) + arg.value = convert(ThunkSummary, value(arg)) + end + end + return ThunkSummary(t.id, args) end function Base.convert(::Type{ThunkSummary}, t::WeakThunk) t = unwrap_weak(t) @@ -245,17 +300,19 @@ function Base.showerror(io::IO, ex::DTaskFailedException) function thunk_string(t) Tinputs = Any[] - for (_, input) in t.inputs - if istask(input) - push!(Tinputs, "DTask(id=$(input.id))") + for input in @view t.inputs[2:end] + x = value(input) + if istask(x) + push!(Tinputs, "DTask(id=$(x.id))") else - push!(Tinputs, input) + push!(Tinputs, x) end end + f = value(t.inputs[1]) t_sig = if length(Tinputs) <= 4 - "$(t.f)($(join(Tinputs, ", ")))" + "$(f)($(join(Tinputs, ", ")))" else - "$(t.f)($(length(Tinputs)) inputs...)" + "$(f)($(length(Tinputs)) inputs...)" end return "DTask(id=$(t.id), $t_sig)" end @@ -338,7 +395,7 @@ These options control a variety of properties of the resulting `DTask`: - `scope`: The execution "scope" of the task, which determines where the task will run. By default, the task will run on the first available compute resource. If you have multiple compute resources, you can specify a scope to run the task on a specific resource. For example, `Dagger.@spawn scope=Dagger.scope(worker=2) do_something(1, 3.0)` would run `do_something(1, 3.0)` on worker 2. - `meta`: If `true`, instead of the scheduler automatically fetching values from other tasks, the raw `Chunk` objects will be passed to `f`. Useful for doing manual fetching or manipulation of `Chunk` references. Non-`Chunk` arguments are still passed as-is. -Other options exist; see `Dagger.Sch.ThunkOptions` for the full list. +Other options exist; see `Dagger.Options` for the full list. This macro is a semi-thin wrapper around `Dagger.spawn` - it creates a call to `Dagger.spawn` on `f` with arguments `args` and keyword arguments `kwargs`, and @@ -466,27 +523,45 @@ Spawns a `DTask` that will call `f(args...; kwargs...)`. Also supports passing a function spawn(f, args...; kwargs...) 
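+    # Example (illustrative, not part of the implementation): an `Options` value
+    # may be passed in the first argument position, e.g.
+    #   t = Dagger.spawn(+, Dagger.Options(), 1, 2); fetch(t) == 3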
@nospecialize f args kwargs - # Get all options and determine which propagate beyond this task - options = get_options() - propagates = get(options, :propagates, ()) - propagates = Tuple(unique(Symbol[propagates..., keys(options)...])) + # Get all scoped options and determine which propagate beyond this task + scoped_options = get_options()::NamedTuple + if haskey(scoped_options, :propagates) + if scoped_options.propagates isa Tuple + propagates = Symbol[scoped_options.propagates...] + else + propagates = scoped_options.propagates::Vector{Symbol} + end + else + propagates = Symbol[] + end + append!(propagates, keys(scoped_options)::NTuple{N,Symbol} where N) + + # Merge all passed options if length(args) >= 1 && first(args) isa Options - spawn_options = first(args).options - options = merge(options, spawn_options) + # N.B. Make a defensive copy in case user aliases Options struct + task_options = copy(first(args)::Options) args = args[2:end] + else + task_options = Options() end + # N.B. Merges into task_options + options_merge!(task_options, scoped_options; override=false) # Process the args and kwargs into Pair form - args_kwargs = args_kwargs_to_pairs(args, kwargs) + args_kwargs = args_kwargs_to_arguments(f, args, kwargs) # Get task queue, and don't let it propagate - task_queue = get_options(:task_queue, DefaultTaskQueue()) - options = NamedTuple(filter(opt->opt[1] != :task_queue, Base.pairs(options))) - propagates = filter(prop->prop != :task_queue, propagates) - options = merge(options, (;propagates)) + task_queue = get(scoped_options, :task_queue, DefaultTaskQueue())::AbstractTaskQueue + filter!(prop -> prop != :task_queue, propagates) + if task_options.propagates !== nothing + append!(task_options.propagates, propagates) + else + task_options.propagates = propagates + end + unique!(task_options.propagates) # Construct task spec and handle - spec = DTaskSpec(f, args_kwargs, options) + spec = DTaskSpec(args_kwargs, task_options) task = eager_spawn(spec) # Enqueue the task into the task queue @@ -495,64 +570,32 @@ function spawn(f, args...; kwargs...) return task end -struct FetchAdaptor end -Adapt.adapt_storage(::FetchAdaptor, x::DTask) = fetch(x) -Adapt.adapt_structure(::FetchAdaptor, A::AbstractArray) = - map(x->Adapt.adapt(FetchAdaptor(), x), A) - -""" - Dagger.fetch_all(x) - -Recursively fetches all `DTask`s and `Chunk`s in `x`, returning an equivalent -object. Useful for converting arbitrary Dagger-enabled objects into a -non-Dagger form. -""" -fetch_all(x) = Adapt.adapt(FetchAdaptor(), x) - -persist!(t::Thunk) = (t.persist=true; t) -cache_result!(t::Thunk) = (t.cache=true; t) - -# @generated function compose{N}(f, g, t::NTuple{N}) -# if N <= 4 -# ( :(()->f(g())), -# :((a)->f(g(a))), -# :((a,b)->f(g(a,b))), -# :((a,b,c)->f(g(a,b,c))), -# :((a,b,c,d)->f(g(a,b,c,d))), )[N+1] -# else -# :((xs...) 
-> f(g(xs...))) -# end -# end - -# function Thunk(f::Function, t::Tuple{Thunk}) -# g = compose(f, t[1].f, t[1].inputs) -# Thunk(g, t[1].inputs) -# end - # this gives a ~30x speedup in hashing Base.hash(x::Thunk, h::UInt) = hash(x.id, hash(h, 0x7ad3bac49089a05f % UInt)) Base.isequal(x::Thunk, y::Thunk) = x.id==y.id function show_thunk(io::IO, t) lvl = get(io, :lazy_level, 0) - f = if t.f isa Chunk - Tf = t.f.chunktype + f = value(first(t.inputs)) + f = if f isa Chunk + Tf = f.chunktype if isdefined(Tf, :instance) Tf.instance else "instance of $Tf" end else - t.f + f end print(io, "Thunk[$(t.id)]($f, ") if lvl > 0 t_inputs = Any[] - for (pos, input) in inputs(t) - if pos === nothing + for arg in inputs(t)[2:end] + input = value(arg) + if ispositional(arg) push!(t_inputs, input) else - push!(t_inputs, pos => input) + push!(t_inputs, pos_kw(arg) => input) end end show(IOContext(io, :lazy_level => lvl-1), t_inputs) diff --git a/src/utils/chunks.jl b/src/utils/chunks.jl new file mode 100644 index 000000000..400b49332 --- /dev/null +++ b/src/utils/chunks.jl @@ -0,0 +1,186 @@ +### Mutation + +function _mutable_inner(@nospecialize(f), proc, scope) + result = f() + return Ref(Dagger.tochunk(result, proc, scope)) +end + +""" + mutable(f::Base.Callable; worker, processor, scope) -> Chunk + +Calls `f()` on the specified worker or processor, returning a `Chunk` +referencing the result with the specified scope `scope`. +""" +function mutable(@nospecialize(f); worker=nothing, processor=nothing, scope=nothing) + if processor === nothing + if worker === nothing + processor = OSProc() + else + processor = OSProc(worker) + end + else + @assert worker === nothing "mutable: Can't mix worker and processor" + end + if scope === nothing + scope = processor isa OSProc ? ProcessScope(processor) : ExactScope(processor) + end + return fetch(Dagger.@spawn scope=scope _mutable_inner(f, processor, scope))[] +end + +""" + @mutable [worker=1] [processor=OSProc()] [scope=ProcessorScope()] f() + +Helper macro for [`mutable()`](@ref). +""" +macro mutable(exs...) + opts = esc.(exs[1:end-1]) + ex = exs[end] + quote + let f = @noinline ()->$(esc(ex)) + $mutable(f; $(opts...)) + end + end +end + +""" +Maps a value to one of multiple distributed "mirror" values automatically when +used as a thunk argument. Construct using `@shard` or `shard`. +""" +struct Shard + chunks::Dict{Processor,Chunk} +end + +""" + shard(f; kwargs...) -> Chunk{Shard} + +Executes `f` on all workers in `workers`, wrapping the result in a +process-scoped `Chunk`, and constructs a `Chunk{Shard}` containing all of these +`Chunk`s on the current worker. + +Keyword arguments: +- `procs` -- The list of processors to create pieces on. May be any iterable container of `Processor`s. +- `workers` -- The list of workers to create pieces on. May be any iterable container of `Integer`s. +- `per_thread::Bool=false` -- If `true`, creates a piece per each thread, rather than a piece per each worker. 
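+
+A minimal usage sketch (illustrative only; the buffer size and the reduction
+shown are arbitrary choices, not part of the API):
+
+    buffers = Dagger.@shard per_thread=true zeros(64)
+    wait(Dagger.@spawn sum(buffers))  # each task sees only its processor-local piece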
+""" +function shard(@nospecialize(f); procs=nothing, workers=nothing, per_thread=false) + if procs === nothing + if workers !== nothing + procs = [OSProc(w) for w in workers] + else + procs = lock(Sch.eager_context()) do + copy(Sch.eager_context().procs) + end + end + if per_thread + _procs = ThreadProc[] + for p in procs + append!(_procs, filter(p->p isa ThreadProc, get_processors(p))) + end + procs = _procs + end + else + if workers !== nothing + throw(ArgumentError("Cannot combine `procs` and `workers`")) + elseif per_thread + throw(ArgumentError("Cannot combine `procs` and `per_thread=true`")) + end + end + isempty(procs) && throw(ArgumentError("Cannot create empty Shard")) + shard_running_dict = Dict{Processor,DTask}() + for proc in procs + scope = proc isa OSProc ? ProcessScope(proc) : ExactScope(proc) + thunk = Dagger.@spawn scope=scope _mutable_inner(f, proc, scope) + shard_running_dict[proc] = thunk + end + shard_dict = Dict{Processor,Chunk}() + for proc in procs + shard_dict[proc] = fetch(shard_running_dict[proc])[] + end + return Shard(shard_dict) +end + +"Creates a `Shard`. See [`Dagger.shard`](@ref) for details." +macro shard(exs...) + opts = esc.(exs[1:end-1]) + ex = exs[end] + quote + let f = @noinline ()->$(esc(ex)) + $shard(f; $(opts...)) + end + end +end + +function move(from_proc::Processor, to_proc::Processor, shard::Shard) + # Match either this proc or some ancestor + # N.B. This behavior may bypass the piece's scope restriction + proc = to_proc + if haskey(shard.chunks, proc) + return move(from_proc, to_proc, shard.chunks[proc]) + end + parent = Dagger.get_parent(proc) + while parent != proc + proc = parent + parent = Dagger.get_parent(proc) + if haskey(shard.chunks, proc) + return move(from_proc, to_proc, shard.chunks[proc]) + end + end + + throw(KeyError(to_proc)) +end +Base.iterate(s::Shard) = iterate(values(s.chunks)) +Base.iterate(s::Shard, state) = iterate(values(s.chunks), state) +Base.length(s::Shard) = length(s.chunks) + +### Core Stuff + +""" + tochunk(x, proc::Processor, scope::AbstractScope; device=nothing, rewrap=false, kwargs...) -> Chunk + +Create a chunk from data `x` which resides on `proc` and which has scope +`scope`. + +`device` specifies a `MemPool.StorageDevice` (which is itself wrapped in a +`Chunk`) which will be used to manage the reference contained in the `Chunk` +generated by this function. If `device` is `nothing` (the default), the data +will be inspected to determine if it's safe to serialize; if so, the default +MemPool storage device will be used; if not, then a `MemPool.CPURAMDevice` will +be used. + +If `rewrap==true` and `x isa Chunk`, then the `Chunk` will be rewrapped in a +new `Chunk`. + +All other kwargs are passed directly to `MemPool.poolset`. +""" +function tochunk(x::X, proc::P=OSProc(), scope::S=AnyScope(); device=nothing, rewrap=false, kwargs...) where {X,P,S} + if device === nothing + device = if Sch.walk_storage_safe(x) + MemPool.GLOBAL_DEVICE[] + else + MemPool.CPURAMDevice() + end + end + ref = poolset(x; device, kwargs...) + Chunk{X,typeof(ref),P,S}(X, domain(x), ref, proc, scope) +end +function tochunk(x::Chunk, proc=nothing, scope=nothing; rewrap=false, kwargs...) + if rewrap + return remotecall_fetch(x.handle.owner) do + tochunk(MemPool.poolget(x.handle), proc, scope; kwargs...) + end + else + return x + end +end +tochunk(x::Thunk, proc=nothing, scope=nothing; kwargs...) 
= x + +function savechunk(data, dir, f) + sz = open(joinpath(dir, f), "w") do io + serialize(io, MemPool.MMWrap(data)) + return position(io) + end + fr = FileRef(f, sz) + proc = OSProc() + scope = AnyScope() # FIXME: Scoped to this node + Chunk{typeof(data),typeof(fr),typeof(proc),typeof(scope)}(typeof(data), domain(data), fr, proc, scope, true) +end diff --git a/src/utils/dagdebug.jl b/src/utils/dagdebug.jl index 6a71e5c52..615030400 100644 --- a/src/utils/dagdebug.jl +++ b/src/utils/dagdebug.jl @@ -2,8 +2,8 @@ function istask end function task_id end const DAGDEBUG_CATEGORIES = Symbol[:global, :submit, :schedule, :scope, - :take, :execute, :move, :processor, :cancel, - :stream] + :take, :execute, :move, :processor, :finish, + :cancel, :stream] macro dagdebug(thunk, category, msg, args...) cat_sym = category.value @gensym id @@ -24,18 +24,14 @@ macro dagdebug(thunk, category, msg, args...) $id = -1 end if $id > 0 - if $(QuoteNode(cat_sym)) in $DAGDEBUG_CATEGORIES + if $(QuoteNode(cat_sym)) in $DAGDEBUG_CATEGORIES || :all in $DAGDEBUG_CATEGORIES $debug_ex_id end elseif $id == 0 - if $(QuoteNode(cat_sym)) in $DAGDEBUG_CATEGORIES + if $(QuoteNode(cat_sym)) in $DAGDEBUG_CATEGORIES || :all in $DAGDEBUG_CATEGORIES $debug_ex_noid end end - - # Always yield to reduce differing behavior for debug vs. non-debug - # TODO: Remove this eventually - yield() end end) end diff --git a/src/utils/fetch.jl b/src/utils/fetch.jl new file mode 100644 index 000000000..f15eb07d4 --- /dev/null +++ b/src/utils/fetch.jl @@ -0,0 +1,14 @@ +struct FetchAdaptor end +Adapt.adapt_storage(::FetchAdaptor, x::Chunk) = fetch(x) +Adapt.adapt_storage(::FetchAdaptor, x::DTask) = fetch(x) +Adapt.adapt_structure(::FetchAdaptor, A::AbstractArray) = + map(x->Adapt.adapt(FetchAdaptor(), x), A) + +""" + Dagger.fetch_all(x) + +Recursively fetches all `DTask`s and `Chunk`s in `x`, returning an equivalent +object. Useful for converting arbitrary Dagger-enabled objects into a +non-Dagger form. 
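+
+A minimal sketch (illustrative values):
+
+    t = Dagger.@spawn 1 + 2
+    Dagger.fetch_all([t, 3])  # == [3, 3], with no `DTask`s left inside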
+""" +fetch_all(x) = Adapt.adapt(FetchAdaptor(), x) diff --git a/src/utils/logging-events.jl b/src/utils/logging-events.jl index 68b80b487..07111e254 100644 --- a/src/utils/logging-events.jl +++ b/src/utils/logging-events.jl @@ -148,7 +148,7 @@ function (ta::TaskArguments)(ev::Event{:finish}) if ev.category == :move args = Pair{Union{Symbol,Int},Dagger.LoggedMutableObject}[] thunk_id = ev.id.thunk_id::Int - pos = ev.id.position::Union{Symbol,Int} + pos = Dagger.raw_position(ev.id.position::Dagger.ArgPosition)::Union{Symbol,Int} arg = ev.timeline.data if ismutable(arg) push!(args, pos => Dagger.objectid_or_chunkid(arg)) @@ -174,7 +174,7 @@ function (ta::TaskArgumentMoves)(ev::Event{:start}) data = ev.timeline.data if ismutable(data) thunk_id = ev.id.thunk_id::Int - position = ev.id.position::Union{Symbol,Int} + position = Dagger.raw_position(ev.id.position::Dagger.ArgPosition)::Union{Symbol,Int} d = get!(Dict{Union{Int,Symbol},Dagger.LoggedMutableObject}, ta.pre_move_args, thunk_id) d[position] = Dagger.objectid_or_chunkid(data) end @@ -186,7 +186,7 @@ function (ta::TaskArgumentMoves)(ev::Event{:finish}) post_data = ev.timeline.data if ismutable(post_data) thunk_id = ev.id.thunk_id::Int - position = ev.id.position::Union{Symbol,Int} + position = Dagger.raw_position(ev.id.position::Dagger.ArgPosition)::Union{Symbol,Int} if haskey(ta.pre_move_args, thunk_id) d = ta.pre_move_args[thunk_id] if haskey(d, position) @@ -246,8 +246,8 @@ function (::TaskDependencies)(ev::Event{:start}) end if ev.category == :add_thunk deps_tids = Int[] - get_deps!(Iterators.filter(Dagger.istask, Iterators.map(last, ev.timeline.args))) - get_deps!(get(Set, ev.timeline.options, :syncdeps)) + get_deps!(Iterators.filter(Dagger.istask, Iterators.map(Dagger.value, ev.timeline.args))) + get_deps!(@something(ev.timeline.options.syncdeps, Set())) return ev.id.thunk_id => deps_tids end return diff --git a/src/utils/logging.jl b/src/utils/logging.jl index b3ec0cc89..1b1f09abd 100644 --- a/src/utils/logging.jl +++ b/src/utils/logging.jl @@ -107,6 +107,15 @@ these logs. """ fetch_logs!() = TimespanLogging.get_logs!(Dagger.Sch.eager_context()) +# Convenience macros to reduce allocations when logging is disabled +macro maybelog(ctx, ex) + quote + if !($(esc(ctx)).log_sink isa $(TimespanLogging.NoOpLog)) + $(esc(ex)) + end + end +end + function logs_event_pairs(f, logs::Dict) running_events = Dict{Tuple,Int}() for w in keys(logs) @@ -133,7 +142,7 @@ end Associates an argument `arg` with `name` in the logs, which logs renderers may utilize for display purposes. 
""" -function logs_annotate!(ctx::Context, arg, name::Union{String,Symbol}) +function logs_annotate!(ctx#=::Context=#, arg, name::Union{String,Symbol}) ismutable(arg) || throw(ArgumentError("Argument must be mutable to be annotated")) Dagger.TimespanLogging.timespan_start(ctx, :data_annotation, (;objectid=objectid_or_chunkid(arg), name), nothing) Dagger.TimespanLogging.timespan_finish(ctx, :data_annotation, (;objectid=objectid_or_chunkid(arg), name), nothing) diff --git a/src/utils/reuse.jl b/src/utils/reuse.jl new file mode 100644 index 000000000..9515e8e1b --- /dev/null +++ b/src/utils/reuse.jl @@ -0,0 +1,635 @@ +struct ReuseCleanup + done::Base.RefValue{Bool} + f::Function +end +const REUSE_SCOPE_DEFERRED = ScopedValue{Union{Vector{ReuseCleanup},Nothing}}(nothing) +macro reuse_scope(ex) + @assert @capture(ex, function f_(args__) body_ end) + esc(quote + function $f($(args...)) + @with $REUSE_SCOPE_DEFERRED=>Vector{$ReuseCleanup}() begin + try + $body + finally + deferred = $REUSE_SCOPE_DEFERRED[] + @assert deferred !== nothing + for cleanup in deferred + cleanup.done[] || cleanup() + end + end + end + end + end) +end +macro reuse_defer_cleanup(ex) + @gensym cleanup + quote + let $cleanup = $ReuseCleanup(Base.RefValue(false), ()->$(esc(ex))) + push!($REUSE_SCOPE_DEFERRED[], $cleanup) + $cleanup + end + end +end +function (cleanup::ReuseCleanup)() + cleanup.done[] = true + cleanup.f() +end + +struct ReusableCache{T,Tnull} + cache::Vector{T} + used::Vector{Bool} + null::Tnull + sized::Bool + function ReusableCache(T, null, N::Integer; sized::Bool=false) + @assert !Base.datatype_pointerfree(T) "ReusableCache is only useful for non-pointerfree types (got $T)" + #cache = [T() for _ in 1:N] + cache = Vector{T}(undef, N) + used = zeros(Bool, N) + return new{T,typeof(null)}(cache, used, null, sized) + end +end +function maybetake!(cache::ReusableCache{T}, len=nothing) where T + for idx in 1:length(cache.used) + cache.used[idx] && continue + if cache.sized && isassigned(cache.cache, idx) && length(cache.cache[idx]) != len + @dagdebug nothing :reuse "Skipping length $(length(cache.cache[idx])) (want length $len) @ $idx" + continue + end + cache.used[idx] = true + if !isassigned(cache.cache, idx) + if cache.sized + @dagdebug nothing :reuse "Allocating length $len @ $idx" + cache.cache[idx] = alloc!(T, len) + else + cache.cache[idx] = alloc!(T) + end + # Initialize newly allocated object with null value + unset!(cache.cache[idx], cache.null) + end + return (idx, cache.cache[idx]) + end + return nothing +end +function putback!(cache::ReusableCache{T}, idx::Integer) where T + # Check bounds and throw error for invalid indices + if !(1 <= idx <= length(cache.used)) + throw(BoundsError(cache, idx)) + end + + # Reset the cached object to null values before marking as available + if isassigned(cache.cache, idx) + unset!(cache.cache[idx], cache.null) + end + cache.used[idx] = false +end +function take_or_alloc!(f::Function, cache::ReusableCache{T}, len=nothing; no_alloc::Bool=false) where T + idx_value = maybetake!(cache, len) + if idx_value !== nothing + idx, value = idx_value + try + return f(value) + finally + unset!(value, cache.null) + putback!(cache, idx) + end + else + if no_alloc + error("No more entries available in cache for type $T") + end + return f(T()) + end +end +function maybe_take_or_alloc!(f::Function, cache::ReusableCache{T}, value::Union{T,Nothing}, len=nothing; no_alloc::Bool=false) where T + if value !== nothing + return f(value) + else + return take_or_alloc!(f, cache, len; 
no_alloc=no_alloc) + end +end + +alloc!(::Type{V}, n::Integer) where V<:Vector = V(undef, n) +alloc!(::Type{D}) where {D<:Dict} = D() +alloc!(::Type{S}) where {S<:Set} = S() +alloc!(::Type{T}) where T = T() + +unset!(v::Vector, null) = fill!(v, null) +# FIXME: Inefficient to use these +unset!(d::Dict, _) = empty!(d) +unset!(s::Set, _) = empty!(s) + +macro take_or_alloc!(cache, T, var, ex) + @gensym idx_value idx + quote + $idx_value = $maybetake!($(esc(cache))) + if $idx_value !== nothing + $idx, $(esc(var)) = $idx_value + try + $(esc(ex)) + finally + $unset!($(esc(var)), $(esc(cache)).null) + $putback!($(esc(cache)), $idx) + end + else + #=let=# $(esc(var)) = $(esc(T))() + $(esc(ex)) + #end + end + end +end +macro take_or_alloc!(cache, T, len, var, ex) + @gensym idx_value idx + quote + $idx_value = $maybetake!($(esc(cache)), $(esc(len))) + if $idx_value !== nothing + $idx, $(esc(var)) = $idx_value + try + $(esc(ex)) + finally + $unset!($(esc(var)), $(esc(cache)).null) + $putback!($(esc(cache)), $idx) + end + else + #=let=# $(esc(var)) = $(esc(T))() + $(esc(ex)) + #end + end + end +end +# TODO: const causes issues with Revise +macro reusable(name, T, null, N, var, ex) + cache_name = Symbol("__$(name)_reuse_cache") + if !hasproperty(__module__, cache_name) + __module__.eval(:(#=const=# $cache_name = $TaskLocalValue{$ReusableCache{$T}}(()->$ReusableCache($T, $null, $N)))) + end + quote + @take_or_alloc! $(esc(cache_name))[] $T $(esc(var)) $(esc(ex)) + end +end +macro reusable(name, T, null, N, len, var, ex) + cache_name = Symbol("__$(name)_reuse_cache") + if !hasproperty(__module__, cache_name) + __module__.eval(:(#=const=# $cache_name = $TaskLocalValue{$ReusableCache{$T}}(()->$ReusableCache($T, $null, $N; sized=true)))) + end + quote + @take_or_alloc! 
$(esc(cache_name))[] $T $(esc(len)) $(esc(var)) $(esc(ex)) + end +end + +# FIXME: Provide ReusableObject{T} interface +# FIXME: Allow objects to be GC'd (if lost via throw/unexpected control flow) (provide optional warning mode on finalization) + +#= FIXME: UniquingCache +struct UniquingCache{K,V} + cache::Dict{WeakRef,WeakRef} + function UniquingCache(K, V) + return new(Dict{K,V}()) + end +end +=# + +mutable struct ReusableNode{T} + value::T + next::Union{ReusableNode{T},Nothing} +end +mutable struct ReusableLinkedList{T} <: AbstractVector{T} + head::Union{ReusableNode{T},Nothing} + tail::Union{ReusableNode{T},Nothing} + free_nodes::ReusableNode{T} + null::T + maxlen::Int + function ReusableLinkedList{T}(null, N) where T + free_root = ReusableNode{T}(null, nothing) + for _ in 1:N + free_node = ReusableNode{T}(null, nothing) + free_node.next = free_root + free_root = free_node + end + return new{T}(nothing, nothing, free_root, null, N) + end +end +Base.eltype(list::ReusableLinkedList{T}) where T = T +function Base.getindex(list::ReusableLinkedList{T}, idx::Integer) where T + checkbounds(list, idx) + node = list.head + for _ in 1:(idx-1) + node === nothing && throw(BoundsError(list, idx)) + node = node.next + end + node === nothing && throw(BoundsError(list, idx)) + return node.value +end +function Base.setindex!(list::ReusableLinkedList{T}, value::T, idx::Integer) where T + checkbounds(list, idx) + node = list.head + for _ in 1:(idx-1) + node === nothing && throw(BoundsError(list, idx)) + node = node.next + end + node === nothing && throw(BoundsError(list, idx)) + node.value = value + return value +end +function Base.push!(list::ReusableLinkedList{T}, value) where T + value_conv = convert(T, value) + node = list.free_nodes + if node.next === nothing + # FIXME: Optionally allocate extras + throw(ArgumentError("No more entries available in cache for type $T")) + end + list.free_nodes = node.next + node.value = value_conv + node.next = nothing + if list.head === nothing + list.head = list.tail = node + else + list.tail.next = node + list.tail = node + end + return list +end +function Base.pushfirst!(list::ReusableLinkedList{T}, value) where T + value_conv = convert(T, value) + node = list.free_nodes + if node.next === nothing + # FIXME: Optionally allocate extras + throw(ArgumentError("No more entries available in cache for type $T")) + end + list.free_nodes = node.next + node.value = value_conv + node.next = list.head + list.head = node + if list.tail === nothing + list.tail = node + end + return list +end +function Base.pop!(list::ReusableLinkedList{T}) where T + if list.head === nothing + throw(ArgumentError("list must be non-empty")) + end + prev = node = list.head + while node.next !== nothing + prev = node + node = node.next + end + if prev !== node + list.tail = prev + else + list.head = list.tail = nothing + end + prev.next = nothing + node.next = list.free_nodes + list.free_nodes = node + value = node.value + node.value = list.null + return value +end +function Base.popfirst!(list::ReusableLinkedList{T}) where T + if list.head === nothing + throw(ArgumentError("list must be non-empty")) + end + node = list.head + list.head = node.next + if list.head === nothing + list.tail = nothing + end + node.next = list.free_nodes + list.free_nodes = node + value = node.value + node.value = list.null + return value +end +Base.size(list::ReusableLinkedList{T}) where T = (length(list),) +function Base.length(list::ReusableLinkedList{T}) where T + node = list.head + if node === nothing + return 0 + 
end + len = 1 + while node.next !== nothing + len += 1 + node = node.next + end + return len +end +function Base.iterate(list::ReusableLinkedList{T}) where T + node = list.head + if node === nothing + return nothing + end + return (node.value, node) +end +function Base.iterate(list::ReusableLinkedList{T}, state::Union{Nothing,ReusableNode{T}}) where T + if state === nothing + return nothing + end + node = state.next + if node === nothing + return nothing + end + return (node.value, node) +end +function Base.in(list::ReusableLinkedList{T}, value::T) where T + node = list.head + while node !== nothing + if node.value == value + return true + end + end + return false +end +function Base.findfirst(f::Function, list::ReusableLinkedList) + node = list.head + idx = 1 + while node !== nothing + if f(node.value) + return idx + end + node = node.next + idx += 1 + end + return nothing +end +Base.sizehint!(list::ReusableLinkedList, len::Integer) = nothing +function Base.empty!(list::ReusableLinkedList{T}) where T + if list.tail !== nothing + fill!(list, list.null) + list.tail.next = list.free_nodes + list.free_nodes = list.head + list.head = list.tail = nothing + end + return list +end +function Base.fill!(list::ReusableLinkedList{T}, value::T) where T + node = list.head + while node !== nothing + node.value = value + node = node.next + end + return list +end +function Base.resize!(list::ReusableLinkedList, N::Integer) + while length(list) < N + push!(list, list.null) + end + while length(list) > N + pop!(list) + end + return list +end +function Base.deleteat!(list::ReusableLinkedList, idx::Integer) + checkbounds(list, idx) + if idx == 1 + deleted = list.head + list.head = list.head.next + deleted.next = list.free_nodes + list.free_nodes = deleted + deleted.value = list.null + return list + end + node = list.head + for _ in 1:(idx-2) + if node === nothing + throw(BoundsError(idx)) + end + node = node.next + end + if idx == length(list) + list.tail = node + end + deleted = node.next + node.next = deleted.next + deleted.next = list.free_nodes + list.free_nodes = deleted + deleted.value = list.null + return list +end +function Base.map!(f, list_out::ReusableLinkedList{T}, list_in::ReusableLinkedList{V}; N=length(list_in)) where {T,V} + node_out = list_out.head + node_in = list_in.head + ctr = 0 + while node_in !== nothing + node_out.value = f(node_in.value) + node_in = node_in.next + node_out = node_out.next + ctr += 1 + if ctr >= N + break + end + end + return list_out +end +function Base.copyto!(list_out::ReusableLinkedList{T}, list_in::ReusableLinkedList{T}) where T + Base.map!(identity, list_out, list_in) +end + +struct ReusableSet{T} <: AbstractSet{T} + list::ReusableLinkedList{T} +end +function ReusableSet(T, null, N) + return ReusableSet{T}(ReusableLinkedList{T}(null, N)) +end +function Base.push!(set::ReusableSet{T}, value::T) where T + if !(value in set) + push!(set.list, value) + end + return set +end +function Base.pop!(set::ReusableSet{T}, value) where T + value_conv = convert(T, value) + idx = findfirst(==(value_conv), set) + if idx === nothing + throw(KeyError(value_conv)) + end + deleteat!(set, idx) + return value +end +Base.length(set::ReusableSet) = length(set.list) +function Base.iterate(set::ReusableSet) + return iterate(set.list) +end +function Base.iterate(set::ReusableSet, state) + return iterate(set.list, state) +end +function Base.empty!(set::ReusableSet{T}) where T + empty!(set.list) + return set +end + +struct ReusableDict{K,V} <: AbstractDict{K,V} + keys::ReusableLinkedList{K} 
+ values::ReusableLinkedList{V} +end +function ReusableDict{K,V}(null_key, null_value, N::Integer) where {K,V} + keys = ReusableLinkedList{K}(null_key, N) + values = ReusableLinkedList{V}(null_value, N) + return ReusableDict{K,V}(keys, values) +end +function Base.getindex(dict::ReusableDict{K,V}, key) where {K,V} + key_conv = convert(K, key) + idx = findfirst(==(key_conv), dict.keys) + if idx === nothing + throw(KeyError(key_conv)) + end + return dict.values[idx] +end +function Base.setindex!(dict::ReusableDict{K,V}, value, key) where {K,V} + key_conv = convert(K, key) + value_conv = convert(V, value) + idx = findfirst(==(key_conv), dict.keys) + if idx === nothing + push!(dict.keys, key_conv) + push!(dict.values, value_conv) + else + dict.values[idx] = value_conv + end + return value +end +function Base.delete!(dict::ReusableDict{K,V}, key) where {K,V} + key_conv = convert(K, key) + idx = findfirst(==(key_conv), dict.keys) + if idx === nothing + throw(KeyError(key_conv)) + end + deleteat!(dict.keys, idx) + deleteat!(dict.values, idx) + return dict +end +function Base.haskey(dict::ReusableDict{K,V}, key) where {K,V} + key_conv = convert(K, key) + return key_conv in dict.keys +end +function Base.iterate(dict::ReusableDict) + key = dict.keys.head + if key === nothing + return nothing + end + value = dict.values.head + return (key.value => value.value, (key, value)) +end +Base.length(dict::ReusableDict) = length(dict.keys) +function Base.iterate(dict::ReusableDict, state) + if state === nothing + return nothing + end + key, value = state + key = key.next + if key === nothing + return nothing + end + value = value.next + return (key.value => value.value, (key, value)) +end +Base.keys(dict::ReusableDict) = dict.keys +Base.values(dict::ReusableDict) = dict.values +function Base.empty!(dict::ReusableDict{K,V}) where {K,V} + empty!(dict.keys) + empty!(dict.values) + return dict +end + +macro reusable_vector(name, T, null, N) + vec_name = Symbol("__$(name)_TLV_ReusableVector") + if !hasproperty(__module__, vec_name) + __module__.eval(:(#=const=# $vec_name = $TaskLocalValue{$Vector{$T}}(()->$Vector{$T}()))) + end + return :($(esc(vec_name))[]) +end +macro reusable_dict(name, K, V, null_key, null_value, N) + dict_name = Symbol("__$(name)_TLV_ReusableDict") + if !hasproperty(__module__, dict_name) + __module__.eval(:(#=const=# $dict_name = $TaskLocalValue{$Dict{$K,$V}}(()->$Dict{$K,$V}()))) + end + return :($(esc(dict_name))[]) +end + +mutable struct ReusableTaskCache + tasks::Vector{Task} + chans::Vector{Channel{Any}} + ready::Vector{Threads.Atomic{Bool}} + setup_f::Function + N::Int + init::Bool + function ReusableTaskCache(N::Integer) + tasks = Vector{Task}(undef, N) + chans = Vector{Channel{Any}}(undef, N) + ready = [Threads.Atomic{Bool}(true) for _ in 1:N] + for idx in 1:N + chans[idx] = Channel{Any}(1) + chan, r = chans[idx], ready[idx] + tasks[idx] = @task reusable_task_loop(chan, r) + end + cache = new(tasks, chans, ready, t->nothing, N, false) + finalizer(cache) do cache + # Ask tasks to shut down + for idx in 1:N + Threads.atomic_xchg!(cache.ready[idx], false) + close(cache.chans[idx]) + end + end + return cache + end +end +function reusable_task_cache_init!(setup_f::Function, cache::ReusableTaskCache) + cache.init && return + cache.setup_f = setup_f + for idx in 1:cache.N + task = cache.tasks[idx] + setup_f(task) + schedule(task) + Sch.errormonitor_tracked("reusable_task_$idx", task) + end + cache.init = true + return +end +function reusable_task_loop(chan::Channel{Any}, 
ready::Threads.Atomic{Bool}) + r = rand(1:128) + while true + f = try + take!(chan) + catch + if !isopen(chan) + return + else + rethrow() + end + end + try + @invokelatest f() + catch err + @error "[$r] Error in reusable task" exception=(err, catch_backtrace()) + end + Threads.atomic_xchg!(ready, true) + end +end +function (cache::ReusableTaskCache)(f, name::String) + idx = findfirst(getindex, cache.ready) + if idx !== nothing + @assert Threads.atomic_xchg!(cache.ready[idx], false) + put!(cache.chans[idx], f) + Sch.errormonitor_tracked_set!(name, cache.tasks[idx]) + return cache.tasks[idx] + else + t = @task try + @invokelatest f() + catch err + @error "[$r] Error in non-reusable task" exception=(err, catch_backtrace()) + end + cache.setup_f(t) + schedule(t) + Sch.errormonitor_tracked(name, t) + return t + end + return +end + +macro reusable_tasks(name, N, setup_ex, task_name, task_ex) + cache_name = Symbol("__$(name)_TLV_ReusableTaskCache") + if !hasproperty(__module__, cache_name) + __module__.eval(:(#=const=# $cache_name = $TaskLocalValue{$ReusableTaskCache}(()->$ReusableTaskCache($N)))) + end + return esc(quote + $reusable_task_cache_init!($setup_ex, $cache_name[]) + $cache_name[]($task_ex, $task_name) + end) +end diff --git a/src/utils/scopes.jl b/src/utils/scopes.jl index cef87a3f0..2b310e4ce 100644 --- a/src/utils/scopes.jl +++ b/src/utils/scopes.jl @@ -1,14 +1,34 @@ # Scope-Processor helpers +""" + get_compute_scope() -> AbstractScope + +Returns the currently set compute scope, first checking the `compute_scope` +option, then checking the `scope` option, and finally defaulting to +`DefaultScope()`. +""" +function get_compute_scope() + opts = get_options() + if hasproperty(opts, :compute_scope) + return opts.compute_scope + elseif hasproperty(opts, :scope) + return opts.scope + else + return DefaultScope() + end +end + """ compatible_processors(scope::AbstractScope, ctx::Context=Sch.eager_context()) -> Set{Processor} Returns the set of all processors (across all Distributed workers) that are compatible with the given scope. """ -function compatible_processors(scope::AbstractScope=get_options(:scope, DefaultScope()), ctx::Context=Sch.eager_context()) +compatible_processors(scope::AbstractScope=get_compute_scope(), ctx::Context=Sch.eager_context()) = + compatible_processors(scope, procs(ctx)) +function compatible_processors(scope::AbstractScope, procs::Vector{<:Processor}) compat_procs = Set{Processor}() - for gproc in procs(ctx) + for gproc in procs # Fast-path in case entire process is incompatible gproc_scope = ProcessScope(gproc) if !isa(constrain(scope, gproc_scope), InvalidScope) @@ -31,7 +51,7 @@ specified, according to `scope`. If `all=true`, instead returns the number of processors known to Dagger, whether or not they've been disabled by the user. Most users will want to use `num_processors()`. 
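+
+For example (illustrative):
+
+    Dagger.num_processors()                        # processors in the current compute scope (the default)
+    Dagger.num_processors(Dagger.scope(worker=1))  # only processors on worker 1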
""" -function num_processors(scope::AbstractScope=get_options(:scope, DefaultScope()); +function num_processors(scope::AbstractScope=get_compute_scope(); all::Bool=false) if all return length(all_processors()) diff --git a/src/utils/signature.jl b/src/utils/signature.jl new file mode 100644 index 000000000..78a788400 --- /dev/null +++ b/src/utils/signature.jl @@ -0,0 +1,31 @@ +struct Signature + sig::Vector{Any}#DataType} + hash::UInt + sig_nokw::SubArray{Any,1,Vector{Any},Tuple{UnitRange{Int}},true} + hash_nokw::UInt + function Signature(sig::Vector{Any})#DataType}) + # Hash full signature + h = hash(Signature) + for T in sig + h = hash(T, h) + end + + # Hash non-kwarg signature + @assert isdefined(Core, :kwcall) "FIXME: No kwcall! Use kwfunc" + idx = findfirst(T->T===typeof(Core.kwcall), sig) + if idx !== nothing + # Skip NT kwargs + sig_nokw = @view sig[idx+2:end] + else + sig_nokw = @view sig[1:end] + end + h_nokw = hash(Signature, UInt(1)) + for T in sig_nokw + h_nokw = hash(T, h_nokw) + end + + return new(sig, h, sig_nokw, h_nokw) + end +end +Base.hash(sig::Signature, h::UInt) = hash(sig.hash, h) +Base.isequal(sig1::Signature, sig2::Signature) = sig1.hash == sig2.hash diff --git a/src/utils/viz.jl b/src/utils/viz.jl index 295fe8a36..14bdc3d02 100644 --- a/src/utils/viz.jl +++ b/src/utils/viz.jl @@ -44,22 +44,24 @@ Requires the following events enabled in `enable_logging!`: `taskdeps`, `tasknam Options: - `disconnected`: If `true`, render disconnected vertices (tasks or arguments without upstream/downstream dependencies) +- `show_data`: If `true`, show the data dependencies in the graph - `color_by`: How to color tasks; if `:fn`, then color by unique function name, if `:proc`, then color by unique processor - `times`: If `true`, annotate each task with its start and finish times - `times_digits`: Number of digits to display in the time annotations - `colors`: A list of colors to use for coloring tasks - `name_to_color`: A function that maps task names to colors """ -function show_logs(io::IO, logs::Dict, ::Val{:graphviz}; disconnected=false, +function show_logs(io::IO, logs::Dict, ::Val{:graphviz}; + disconnected=false, show_data::Bool=true, color_by=:fn, times::Bool=true, times_digits::Integer=3, colors=default_colors, name_to_color=name_to_color) - dot = logs_to_dot(logs; disconnected, times, times_digits, + dot = logs_to_dot(logs; disconnected, show_data, times, times_digits, color_by, colors, name_to_color) println(io, dot) end -function logs_to_dot(logs::Dict; disconnected=false, color_by=:fn, - times::Bool=true, times_digits::Integer=3, +function logs_to_dot(logs::Dict; disconnected=false, show_data::Bool=true, + color_by=:fn, times::Bool=true, times_digits::Integer=3, colors=default_colors, name_to_color=name_to_color) # Lookup all relevant task/argument dependencies and values in logs g = SimpleDiGraph() @@ -325,36 +327,40 @@ function logs_to_dot(logs::Dict; disconnected=false, color_by=:fn, end # Add object vertices - for objid in all_objids - objid_v = objid_to_vertex[objid] - if !disconnected && !(objid_v in con_vs) - continue - end - if objid in dtasks_to_patch || haskey(uid_to_tid, objid) - # DTask, skip it - continue - end - # Object - if haskey(objid_to_name, objid) - label = sanitize_label(objid_to_name[objid]) - label *= "\\nData: $(repr(objid))" - else - label = "Data: $(repr(objid))" + if show_data + for objid in all_objids + objid_v = objid_to_vertex[objid] + if !disconnected && !(objid_v in con_vs) + continue + end + if objid in dtasks_to_patch || 
haskey(uid_to_tid, objid) + # DTask, skip it + continue + end + # Object + if haskey(objid_to_name, objid) + label = sanitize_label(objid_to_name[objid]) + label *= "\\nData: $(repr(objid))" + else + label = "Data: $(repr(objid))" + end + str *= "a$objid_v [label=\"$label\", shape=oval]\n" end - str *= "a$objid_v [label=\"$label\", shape=oval]\n" end # Add task argument move edges - seen_moves = Set{Tuple{UInt,UInt}}() - for (tid, moves) in task_arg_moves - for (pos, (pre_objid, post_objid)) in moves - pre_objid == post_objid && continue - (pre_objid, post_objid) in seen_moves && continue - push!(seen_moves, (pre_objid, post_objid)) - pre_objid_v = objid_to_vertex[pre_objid] - post_objid_v = objid_to_vertex[post_objid] - move_str = "a$pre_objid_v -> a$post_objid_v [label=\"move\"]\n" - str *= move_str + if show_data + seen_moves = Set{Tuple{UInt,UInt}}() + for (tid, moves) in task_arg_moves + for (pos, (pre_objid, post_objid)) in moves + pre_objid == post_objid && continue + (pre_objid, post_objid) in seen_moves && continue + push!(seen_moves, (pre_objid, post_objid)) + pre_objid_v = objid_to_vertex[pre_objid] + post_objid_v = objid_to_vertex[post_objid] + move_str = "a$pre_objid_v -> a$post_objid_v [label=\"move\"]\n" + str *= move_str + end end end @@ -371,31 +377,33 @@ function logs_to_dot(logs::Dict; disconnected=false, color_by=:fn, str *= "v$(src(edge)) $edge_sep v$(dst(edge)) [label=\"syncdep\"]\n" end - # Add task argument edges - for (tid, args) in task_args - haskey(tid_to_vertex, tid) || continue - tid_v = tid_to_vertex[tid] - tid_v in con_vs || continue - for (pos, arg) in args - arg_v = objid_to_vertex[arg] - if !disconnected && !(arg_v in con_vs) - continue + if show_data + # Add task argument edges + for (tid, args) in task_args + haskey(tid_to_vertex, tid) || continue + tid_v = tid_to_vertex[tid] + tid_v in con_vs || continue + for (pos, arg) in args + arg_v = objid_to_vertex[arg] + if !disconnected && !(arg_v in con_vs) + continue + end + arg_str = sanitize_label(pos isa Int ? "arg $pos" : "kwarg $pos") + str *= "a$arg_v $edge_sep v$tid_v [label=\"$arg_str\"]\n" end - arg_str = sanitize_label(pos isa Int ? 
"arg $pos" : "kwarg $pos") - str *= "a$arg_v $edge_sep v$tid_v [label=\"$arg_str\"]\n" end - end - # Add task result edges - for (tid, result) in task_result - haskey(tid_to_vertex, tid) || continue - tid_v = tid_to_vertex[tid] - tid_v in con_vs || continue - result_v = objid_to_vertex[result] - if !disconnected && !(result_v in con_vs) - continue + # Add task result edges + for (tid, result) in task_result + haskey(tid_to_vertex, tid) || continue + tid_v = tid_to_vertex[tid] + tid_v in con_vs || continue + result_v = objid_to_vertex[result] + if !disconnected && !(result_v in con_vs) + continue + end + str *= "v$tid_v $edge_sep a$result_v [label=\"result\"]\n" end - str *= "v$tid_v $edge_sep a$result_v [label=\"result\"]\n" end # Generate the final graph diff --git a/test/checkpoint.jl b/test/checkpoint.jl index dee2ffd59..df75dae7a 100644 --- a/test/checkpoint.jl +++ b/test/checkpoint.jl @@ -1,3 +1,5 @@ +import Dagger: Options + @testset "Scheduler Checkpointing" begin ctx = Context([1,workers()...]) d = Ref(false) @@ -86,7 +88,7 @@ end @testset "Thunk Checkpointing" begin ctx = Context([1,workers()...]) d = Ref(false) - opts = Dagger.Sch.ThunkOptions(; + opts = Options(; single=1, checkpoint=(thunk, result)->begin @assert thunk.f != Base.:* @@ -102,7 +104,7 @@ end @test d[] @test f[] == 1 - opts = Dagger.Sch.ThunkOptions(; + opts = Options(; single=1, restore=(thunk)->begin @assert thunk.f != Base.:* @@ -122,7 +124,7 @@ end @testset "Checkpoint Failure" begin d = Ref(false) e = Ref(false) - opts = Dagger.Sch.ThunkOptions(; + opts = Options(; single=1, checkpoint=(thunk, result)->begin e[] = true @@ -142,7 +144,7 @@ end @testset "Restore Failure" begin d = Ref(false) e = Ref(false) - opts = Dagger.Sch.ThunkOptions(; + opts = Options(; single=1, restore=(thunk)->begin e[] = true @@ -161,7 +163,7 @@ end @test e[] # restore executed end @testset "Restore Failure (quiet)" begin - opts = Dagger.Sch.ThunkOptions(; + opts = Options(; single=1, restore=(thunk)->begin nothing diff --git a/test/datadeps.jl b/test/datadeps.jl index 66e41de18..6e7d25b0c 100644 --- a/test/datadeps.jl +++ b/test/datadeps.jl @@ -71,7 +71,6 @@ end @testset "Aliasing" begin f! = v1 -> begin - @show typeof(v1) v1 v1 .= 0 return end @@ -488,7 +487,7 @@ function test_datadeps(;args_chunks::Bool, [Dagger.@spawn Dagger.task_processor() for i in 1:10] end) unique!(exec_procs) - scope = Dagger.get_options(:scope) + scope = Dagger.get_compute_scope() all_procs = vcat([collect(Dagger.get_processors(OSProc(w))) for w in procs()]...) 
scope_procs = filter(proc->!isa(Dagger.constrain(scope, ExactScope(proc)), Dagger.InvalidScope), all_procs) for proc in exec_procs diff --git a/test/logging.jl b/test/logging.jl index 75686e4a2..9e321c9c8 100644 --- a/test/logging.jl +++ b/test/logging.jl @@ -121,7 +121,9 @@ import Colors, GraphViz, DataFrames, Plots, JSON3 end end end - @test length(keys(logs)) > 1 + if nprocs() > 1 + @test length(keys(logs)) > 1 + end l1 = logs[1] core = l1[:core] @@ -132,14 +134,6 @@ import Colors, GraphViz, DataFrames, Plots, JSON3 @test any(e->haskey(e, :fire), esat) @test any(e->haskey(e, :take), esat) @test any(e->haskey(e, :finish), esat) - if Threads.nthreads() == 1 - # Note: May one day be true as scheduler evolves - @test !any(e->haskey(e, :compute), esat) - @test !any(e->haskey(e, :move), esat) - psat = l1[:psat] - # Note: May become false - @test all(e->length(e) == 0, psat) - end had_psat_proc = 0 for wo in filter(w->w != 1, keys(logs)) @@ -157,7 +151,9 @@ import Colors, GraphViz, DataFrames, Plots, JSON3 @test any(e->haskey(e, :move), esat) end end - @test had_psat_proc > 0 + if nprocs() > 1 + @test had_psat_proc > 0 + end logs = TimespanLogging.get_logs!(ml) for w in keys(logs) diff --git a/test/memory-spaces.jl b/test/memory-spaces.jl index df4f69905..7e27f78e9 100644 --- a/test/memory-spaces.jl +++ b/test/memory-spaces.jl @@ -3,41 +3,59 @@ # OSProc x = 123 @test Dagger.memory_space(x) == Dagger.CPURAMMemorySpace(1) - @test remotecall_fetch(Dagger.memory_space, 2, x) == Dagger.CPURAMMemorySpace(2) + if nprocs() > 1 + @test remotecall_fetch(Dagger.memory_space, 2, x) == Dagger.CPURAMMemorySpace(2) + end # ThreadProc x = Dagger.tochunk(123) @test Dagger.memory_space(x) == Dagger.CPURAMMemorySpace(1) - @test remotecall_fetch(Dagger.memory_space, 2, x) == Dagger.CPURAMMemorySpace(1) + if nprocs() > 1 + @test remotecall_fetch(Dagger.memory_space, 2, x) == Dagger.CPURAMMemorySpace(1) + end - x = remotecall_fetch(Dagger.tochunk, 2, 123) - @test Dagger.memory_space(x) == Dagger.CPURAMMemorySpace(2) - @test remotecall_fetch(Dagger.memory_space, 2, x) == Dagger.CPURAMMemorySpace(2) + if nprocs() > 1 + x = remotecall_fetch(Dagger.tochunk, 2, 123) + @test Dagger.memory_space(x) == Dagger.CPURAMMemorySpace(2) + @test remotecall_fetch(Dagger.memory_space, 2, x) == Dagger.CPURAMMemorySpace(2) + end x = Dagger.@spawn scope=Dagger.scope(worker=1) identity(123) @test Dagger.memory_space(x) == Dagger.CPURAMMemorySpace(1) - @test remotecall_fetch(Dagger.memory_space, 2, x) == Dagger.CPURAMMemorySpace(1) + if nprocs() > 1 + @test remotecall_fetch(Dagger.memory_space, 2, x) == Dagger.CPURAMMemorySpace(1) + end - x = Dagger.@spawn scope=Dagger.scope(worker=2) identity(123) - @test Dagger.memory_space(x) == Dagger.CPURAMMemorySpace(2) - @test remotecall_fetch(Dagger.memory_space, 2, x) == Dagger.CPURAMMemorySpace(2) + if nprocs() > 1 + x = Dagger.@spawn scope=Dagger.scope(worker=2) identity(123) + @test Dagger.memory_space(x) == Dagger.CPURAMMemorySpace(2) + @test remotecall_fetch(Dagger.memory_space, 2, x) == Dagger.CPURAMMemorySpace(2) + end end @testset "Processor Queries" begin w1_t1_proc = Dagger.ThreadProc(1,1) w1_t2_proc = Dagger.ThreadProc(1,2) - w2_t1_proc = Dagger.ThreadProc(2,1) - w2_t2_proc = Dagger.ThreadProc(2,2) + if nprocs() > 1 + w2_t1_proc = Dagger.ThreadProc(2,1) + w2_t2_proc = Dagger.ThreadProc(2,2) + end @test Dagger.memory_spaces(w1_t1_proc) == Set([Dagger.CPURAMMemorySpace(1)]) @test Dagger.memory_spaces(w1_t2_proc) == Set([Dagger.CPURAMMemorySpace(1)]) - @test 
Dagger.memory_spaces(w2_t1_proc) == Set([Dagger.CPURAMMemorySpace(2)]) - @test Dagger.memory_spaces(w2_t2_proc) == Set([Dagger.CPURAMMemorySpace(2)]) + if nprocs() > 1 + @test Dagger.memory_spaces(w2_t1_proc) == Set([Dagger.CPURAMMemorySpace(2)]) + @test Dagger.memory_spaces(w2_t2_proc) == Set([Dagger.CPURAMMemorySpace(2)]) + end @test only(Dagger.memory_spaces(w1_t1_proc)) == only(Dagger.memory_spaces(w1_t2_proc)) - @test only(Dagger.memory_spaces(w2_t1_proc)) != only(Dagger.memory_spaces(w1_t1_proc)) + if nprocs() > 1 + @test only(Dagger.memory_spaces(w2_t1_proc)) != only(Dagger.memory_spaces(w1_t1_proc)) + end @test_throws ArgumentError Dagger.memory_spaces(FakeProc()) w1_mem = Dagger.CPURAMMemorySpace(1) - w2_mem = Dagger.CPURAMMemorySpace(2) @test Set(Dagger.processors(w1_mem)) == filter(proc->proc isa Dagger.ThreadProc, Dagger.get_processors(OSProc(1))) - @test Set(Dagger.processors(w2_mem)) == filter(proc->proc isa Dagger.ThreadProc, Dagger.get_processors(OSProc(2))) + if nprocs() > 1 + w2_mem = Dagger.CPURAMMemorySpace(2) + @test Set(Dagger.processors(w2_mem)) == filter(proc->proc isa Dagger.ThreadProc, Dagger.get_processors(OSProc(2))) + end end end diff --git a/test/mutation.jl b/test/mutation.jl index fa2f62bcf..50e2a3485 100644 --- a/test/mutation.jl +++ b/test/mutation.jl @@ -1,3 +1,5 @@ +import Dagger.Sch: SchedulingException + @everywhere begin struct DynamicHistogram bins::Vector{Float64} @@ -48,7 +50,7 @@ end x = Dagger.@mutable worker=w Ref{Int}() @test fetch(Dagger.@spawn mutable_update!(x)) == w wo_scope = Dagger.ProcessScope(wo) - @test_throws_unwrap (Dagger.DTaskFailedException, Dagger.Sch.SchedulingException) fetch(Dagger.@spawn scope=wo_scope mutable_update!(x)) + @test_throws_unwrap (Dagger.DTaskFailedException, SchedulingException) fetch(Dagger.@spawn scope=wo_scope mutable_update!(x)) end end # @testset "@mutable" diff --git a/test/options.jl b/test/options.jl index 91349ab9f..bce6fa188 100644 --- a/test/options.jl +++ b/test/options.jl @@ -28,9 +28,8 @@ end for (option, default, value, value2) in [ # Special handling (:scope, AnyScope(), ProcessScope(first_wid), ProcessScope(last_wid)), - # ThunkOptions field + # Options field (:single, 0, first_wid, last_wid), - # Thunk field (:meta, false, true, false) ] # Test local and remote default values diff --git a/test/processors.jl b/test/processors.jl index 6e56876dd..fa2f30551 100644 --- a/test/processors.jl +++ b/test/processors.jl @@ -1,6 +1,5 @@ using Distributed -import Dagger: Context, Processor, OSProc, ThreadProc, get_parent, get_processors -import Dagger.Sch: ThunkOptions +import Dagger: Context, Options, Processor, OSProc, ThreadProc, get_parent, get_processors @everywhere begin @@ -23,7 +22,7 @@ end unknown_func = () -> nothing tp = ThreadProc(1, 1) op = get_parent(tp) - opts = ThunkOptions() + opts = Options() us = UnknownStruct() for proc in (op, tp) @test Dagger.iscompatible_func(proc, opts, unknown_func) @@ -36,11 +35,11 @@ end @test Dagger.default_enabled(OptOutProc()) == false end @testset "Processor exhaustion" begin - opts = ThunkOptions(proclist=[OptOutProc]) + opts = Options(proclist=[OptOutProc]) @test_throws_unwrap (Dagger.DTaskFailedException, Dagger.Sch.SchedulingException) reason="No processors available, try widening scope" collect(delayed(sum; options=opts)([1,2,3])) - opts = ThunkOptions(proclist=(proc)->false) + opts = Options(proclist=(proc)->false) @test_throws_unwrap (Dagger.DTaskFailedException, Dagger.Sch.SchedulingException) reason="No processors available, try widening scope" 
collect(delayed(sum; options=opts)([1,2,3])) - opts = ThunkOptions(proclist=nothing) + opts = Options(proclist=nothing) @test collect(delayed(sum; options=opts)([1,2,3])) == 6 end @testset "Roundtrip move()" begin diff --git a/test/reuse.jl b/test/reuse.jl new file mode 100644 index 000000000..5a80df7c9 --- /dev/null +++ b/test/reuse.jl @@ -0,0 +1,1122 @@ +using Test +import Dagger: ReusableLinkedList, ReusableDict, ReusableCache +import Dagger: take_or_alloc!, maybe_take_or_alloc!, maybetake!, putback! + +@testset "ReusableCache Tests" begin + @testset "Construction and Basic Properties" begin + # Test construction with different types and sizes + cache_int = ReusableCache(Vector{Int}, 0, 3) # Use 0 as null value for Int vectors + cache_string = ReusableCache(Vector{String}, "", 5) + cache_dict = ReusableCache(Dict{Int,String}, Dict{Int,String}(), 2) # Use empty dict as null + + # Test sized cache construction + cache_sized = ReusableCache(Vector{Float64}, 0.0, 4; sized=true) + + @test length(cache_int.cache) == 3 + @test length(cache_int.used) == 3 + @test all(cache_int.used .== false) + @test cache_int.null == 0 # Should be 0, not nothing + @test cache_int.sized == false + + @test length(cache_string.cache) == 5 + @test cache_string.null == "" + @test cache_string.sized == false + + @test cache_sized.sized == true + @test cache_sized.null == 0.0 + end + + @testset "maybetake! Basic Operations" begin + cache = ReusableCache(Vector{Int}, 0, 3) # Use 0 as null value + + # Test taking from fresh cache + result1 = maybetake!(cache) + @test result1 !== nothing + idx1, vec1 = result1 + @test idx1 == 1 + @test isa(vec1, Vector{Int}) + @test cache.used[idx1] == true + + # Test taking second entry + result2 = maybetake!(cache) + @test result2 !== nothing + idx2, vec2 = result2 + @test idx2 == 2 + @test cache.used[idx2] == true + + # Test taking third entry + result3 = maybetake!(cache) + @test result3 !== nothing + idx3, vec3 = result3 + @test idx3 == 3 + @test cache.used[idx3] == true + + # Test that cache is exhausted + result4 = maybetake!(cache) + @test result4 === nothing + + # Test that all entries are marked as used + @test all(cache.used) + end + + @testset "putback! 
Operations" begin + cache = ReusableCache(Vector{String}, "", 2) # Empty string is correct for String vectors + + # Take both entries + result1 = maybetake!(cache) + result2 = maybetake!(cache) + @test result1 !== nothing && result2 !== nothing + idx1, _ = result1 + idx2, _ = result2 + + # Verify cache is exhausted + @test maybetake!(cache) === nothing + + # Put back first entry + putback!(cache, idx1) + @test cache.used[idx1] == false + @test cache.used[idx2] == true + + # Should be able to take one entry again + result3 = maybetake!(cache) + @test result3 !== nothing + idx3, _ = result3 + @test idx3 == idx1 # Should reuse the same slot + + # Put back both entries + putback!(cache, idx2) + putback!(cache, idx3) + @test all(cache.used .== false) + + # Should be able to take both again + @test maybetake!(cache) !== nothing + @test maybetake!(cache) !== nothing + @test maybetake!(cache) === nothing + end + + @testset "Sized Cache Operations" begin + cache = ReusableCache(Vector{Float64}, 0.0, 4; sized=true) + + # Test taking with specific length + result1 = maybetake!(cache, 5) + @test result1 !== nothing + idx1, vec1 = result1 + @test length(vec1) == 5 + putback!(cache, idx1) + + # Test taking with different length - should allocate new slot + result2 = maybetake!(cache, 3) + @test result2 !== nothing + idx2, vec2 = result2 + @test length(vec2) == 3 + @test idx2 != idx1 # Should use different slot + putback!(cache, idx2) + + # Test taking with original length again - should reuse the same slot + result3 = maybetake!(cache, 5) + @test result3 !== nothing + idx3, vec3 = result3 + @test idx3 == idx1 # Should reuse first slot (same length) + @test length(vec3) == 5 + putback!(cache, idx3) + + # Test taking with second length again - should reuse that slot + result4 = maybetake!(cache, 3) + @test result4 !== nothing + idx4, vec4 = result4 + @test idx4 == idx2 # Should reuse second slot (same length) + @test length(vec4) == 3 + putback!(cache, idx4) + + # Allocate slots for different sizes - each will permanently occupy a slot + lengths = [10, 15] # Only allocate 2 more since we already used 2 slots + results = [] + for len in lengths + result = maybetake!(cache, len) + @test result !== nothing + idx, vec = result + @test length(vec) == len + push!(results, (idx, vec, len)) + end + + # Cache should now be full (4 slots all allocated to specific lengths: 5, 3, 10, 15) + @test maybetake!(cache, 20) === nothing # No more slots available + + # Put back the recent allocations + for (idx, _, _) in results + putback!(cache, idx) + end + + # Test that requesting existing lengths still works (reuses allocated slots) + result_reuse_10 = maybetake!(cache, 10) + @test result_reuse_10 !== nothing + idx_reuse_10, vec_reuse_10 = result_reuse_10 + @test length(vec_reuse_10) == 10 + # Should reuse one of the slots we allocated for length 10 + @test idx_reuse_10 in [r[1] for r in results if r[3] == 10] + putback!(cache, idx_reuse_10) + + result_reuse_5 = maybetake!(cache, 5) + @test result_reuse_5 !== nothing + idx_reuse_5, vec_reuse_5 = result_reuse_5 + @test length(vec_reuse_5) == 5 + @test idx_reuse_5 == idx1 # Should reuse the original slot for length 5 + putback!(cache, idx_reuse_5) + + # Verify that we still can't allocate new lengths (cache full with permanent allocations) + @test maybetake!(cache, 25) === nothing + @test maybetake!(cache, 30) === nothing + end + + @testset "take_or_alloc! 
Operations" begin + cache = ReusableCache(Vector{Int}, 0, 2) # Use 0 as null value for Int vectors + call_count = 0 + + # Function to test with + test_function = function(vec) + call_count += 1 + push!(vec, call_count) + return length(vec) + end + + # Test basic take_or_alloc! + result1 = take_or_alloc!(test_function, cache) + @test result1 == 1 + @test call_count == 1 + + # Cache entry should be returned and cleared + @test all(cache.used .== false) + + # Test second call - vector should be filled with null values + result2 = take_or_alloc!(test_function, cache) + # The vector will be filled with null (0) values from unset!, so after push it will have those + new value + @test call_count == 2 + + # Test with sized cache + sized_cache = ReusableCache(Vector{Float64}, 0.0, 2; sized=true) + sized_function = function(vec) + fill!(vec, 3.14) + return sum(vec) + end + + result3 = take_or_alloc!(sized_function, sized_cache, 5) + @test result3 ≈ 5 * 3.14 + + # Test cache exhaustion with no_alloc=true + cache_small = ReusableCache(Vector{String}, "", 1) + + # Fill the cache + idx, _ = maybetake!(cache_small) + + # This should fail with no_alloc=true + @test_throws ErrorException take_or_alloc!(identity, cache_small; no_alloc=true) + + # Put back and try again - should work + putback!(cache_small, idx) + result4 = take_or_alloc!(x -> length(x), cache_small; no_alloc=true) + @test result4 == 0 # Vector should be empty (newly allocated) + end + + @testset "maybe_take_or_alloc! Operations" begin + cache = ReusableCache(Dict{String,Int}, Dict{String,Int}(), 2) # Use empty dict as null + + test_function = function(dict) + dict["test"] = 42 + return length(dict) + end + + # Test with provided value (should use it directly) + provided_dict = Dict("existing" => 1) + result1 = maybe_take_or_alloc!(test_function, cache, provided_dict) + @test result1 == 2 # Should have both "existing" and "test" + @test provided_dict["test"] == 42 # Original dict should be modified + + # Test with nothing (should fall back to cache) + result2 = maybe_take_or_alloc!(test_function, cache, nothing) + @test result2 == 1 # New dict should only have "test" + + # Test with sized cache and provided value + sized_cache = ReusableCache(Vector{Int}, 0, 2; sized=true) # Use 0 as null value + provided_vec = [1, 2, 3, 4, 5] # Length 5 + + sized_function = function(vec) + vec[1] = 100 + return vec[1] + end + + result3 = maybe_take_or_alloc!(sized_function, sized_cache, provided_vec, 5) + @test result3 == 100 + @test provided_vec[1] == 100 # Original vector modified + + # Test fallback to cache with length + result4 = maybe_take_or_alloc!(sized_function, sized_cache, nothing, 3) + @test result4 == 100 # Should work with new vector of length 3 + end + + @testset "Different Element Types" begin + # Test with various data types + + # String vectors + string_cache = ReusableCache(Vector{String}, "", 2) + string_result = take_or_alloc!(string_cache) do vec + # Vector is initialized with null values ("") + push!(vec, "hello", "world") + join(vec, " ") + end + @test string_result == "hello world" + + # Test reuse - vector should be filled with null values + string_result2 = take_or_alloc!(string_cache) do vec + # Vector should be filled with "" (null value), we'll clear it first for clean output + empty!(vec) + push!(vec, "foo", "bar") + join(vec, " ") + end + @test string_result2 == "foo bar" + + # Dictionaries with complex types + dict_cache = ReusableCache(Dict{Symbol, Vector{Int}}, Dict{Symbol, Vector{Int}}(), 2) + dict_result = 
take_or_alloc!(dict_cache) do dict + dict[:numbers] = [1, 2, 3] + dict[:more] = [4, 5] + sum(length(v) for v in values(dict)) + end + @test dict_result == 5 + + # Sets + set_cache = ReusableCache(Set{String}, Set{String}(), 3) + set_result = take_or_alloc!(set_cache) do set + push!(set, "a", "b", "c", "a") # Duplicate should be ignored + length(set) + end + @test set_result == 3 + + # Custom struct (simple case) + mutable struct TestStruct + x::Int + y::String + end + TestStruct() = TestStruct(0, "") + function Dagger.unset!(t::TestStruct, tnull::TestStruct) + t.x = tnull.x + t.y = tnull.y + end + function Dagger.alloc!(::Type{TestStruct}) + return TestStruct() + end + + struct_cache = ReusableCache(TestStruct, TestStruct(0, ""), 2) + struct_result = take_or_alloc!(struct_cache) do obj + obj.x = 42 + obj.y = "test" + obj.x + length(obj.y) + end + @test struct_result == 46 + end + + @testset "Cache State Integrity" begin + cache = ReusableCache(Vector{Int}, 0, 3) # Use 0 as null value for Int vectors + + # Test that used array stays consistent + @test length(cache.used) == length(cache.cache) + + # Take all entries and verify used array + indices = [] + for i in 1:3 + result = maybetake!(cache) + @test result !== nothing + idx, _ = result + push!(indices, idx) + @test cache.used[idx] == true + end + + # Verify all unique indices + @test length(unique(indices)) == 3 + @test Set(indices) == Set(1:3) + + # Put back in random order + for idx in shuffle(indices) + putback!(cache, idx) + @test cache.used[idx] == false + end + + # Verify all entries are available again + @test all(cache.used .== false) + for i in 1:3 + @test maybetake!(cache) !== nothing + end + @test maybetake!(cache) === nothing + end + + @testset "Cache Never Enters Broken State" begin + cache = ReusableCache(Vector{Float64}, NaN, 2) # NaN is appropriate for Float64 vectors + + # Test multiple putback of same valid index (this can happen in error handling) + result = maybetake!(cache) + @test result !== nothing + idx, vec = result + @test isa(vec, Vector{Float64}) + + # Put back once + putback!(cache, idx) + @test cache.used[idx] == false + + # Put back again - should not break anything (idempotent) + @test_nowarn putback!(cache, idx) + @test cache.used[idx] == false + + # Should still be able to take + result2 = maybetake!(cache) + @test result2 !== nothing + idx2, vec2 = result2 + @test isa(vec2, Vector{Float64}) + putback!(cache, idx2) + + # Test that cache works normally after edge case operations + result3 = maybetake!(cache) + @test result3 !== nothing + idx3, vec3 = result3 + @test isa(vec3, Vector{Float64}) + + # Modify vector and put back + resize!(vec3, 3) + fill!(vec3, 1.5) + putback!(cache, idx3) + + # Take again and verify it was properly reset + result4 = maybetake!(cache) + @test result4 !== nothing + idx4, vec4 = result4 + @test idx4 == idx3 # Should reuse same slot + @test length(vec4) == 3 + @test all(isnan, vec4) # Should be filled with NaN (null value) + putback!(cache, idx4) + end + + @testset "Cache Does Not Generate Invalid Values" begin + # Test with many operations to ensure consistency + cache = ReusableCache(Vector{Int}, -999, 5) # Use -999 as null value for Int vectors + + # Test that freshly allocated vectors are empty (for non-sized cache) + result = maybetake!(cache) + @test result !== nothing + idx, vec = result + @test isa(vec, Vector{Int}) + @test isempty(vec) # Should start empty for non-sized cache + putback!(cache, idx) + + # Test basic put/take cycle with proper filling + result = 
maybetake!(cache) + @test result !== nothing + idx, vec = result + + # User fills vector properly + resize!(vec, 5) + fill!(vec, 42) + @test all(x -> x == 42, vec) + + # Put back - should be filled with null values + putback!(cache, idx) + + # Take again - should have null values + result2 = maybetake!(cache) + @test result2 !== nothing + idx2, vec2 = result2 + @test idx2 == idx # Should reuse same slot + @test length(vec2) == 5 + @test all(x -> x == -999, vec2) # Should be filled with null values + + putback!(cache, idx2) + + # Test what happens with empty vectors + result3 = maybetake!(cache) + @test result3 !== nothing + idx3, vec3 = result3 + + # User empties the vector + empty!(vec3) + @test isempty(vec3) + + # Put back empty vector + putback!(cache, idx3) + + # Take again - should still be empty (unset! on empty vector does nothing) + result4 = maybetake!(cache) + @test result4 !== nothing + idx4, vec4 = result4 + @test idx4 == idx3 # Should reuse same slot + @test isempty(vec4) # Should still be empty + + putback!(cache, idx4) + + # Test cache state integrity after many operations + for iteration in 1:50 + # Take random number of entries + taken = [] + n_take = rand(1:3) # Take 1-3 entries + + for _ in 1:n_take + result = maybetake!(cache) + if result !== nothing + idx, vec = result + @test isa(vec, Vector{Int}) + + # Use the vector in a controlled way + resize!(vec, 3) + fill!(vec, iteration) # Properly fill all elements + @test all(x -> x == iteration, vec) + + push!(taken, idx) + end + end + + # Put back all taken entries + for idx in taken + putback!(cache, idx) + end + end + + # Final verification - all vectors should be properly reset + final_taken = [] + for _ in 1:5 + result = maybetake!(cache) + if result !== nothing + idx, vec = result + @test isa(vec, Vector{Int}) + # Vector should contain only null values if not empty + if !isempty(vec) + @test all(x -> x == -999, vec) + end + push!(final_taken, idx) + end + end + + @test maybetake!(cache) === nothing + @test length(final_taken) == 5 + @test Set(final_taken) == Set(1:5) + end + + @testset "Sized Cache Length Consistency" begin + cache = ReusableCache(Vector{Bool}, false, 4; sized=true) + + # Take with different lengths and verify + lengths_and_indices = [] + for len in [3, 7, 3, 5] # Note: duplicate length 3 + result = maybetake!(cache, len) + @test result !== nothing + idx, vec = result + @test length(vec) == len + push!(lengths_and_indices, (len, idx)) + end + + # Cache should be full + @test maybetake!(cache, 10) === nothing + + # Put back all entries + for (_, idx) in lengths_and_indices + putback!(cache, idx) + end + + # Request length 3 again - should reuse existing entry + result1 = maybetake!(cache, 3) + @test result1 !== nothing + idx1, vec1 = result1 + @test length(vec1) == 3 + + # The index should be one of the indices that had length 3 + length_3_indices = [idx for (len, idx) in lengths_and_indices if len == 3] + @test idx1 in length_3_indices + + putback!(cache, idx1) + + # Request different length - should use different slot or allocate new + result2 = maybetake!(cache, 7) + @test result2 !== nothing + idx2, vec2 = result2 + @test length(vec2) == 7 + end +end + +@testset "ReusableLinkedList Tests" begin + + @testset "Construction and Basic Properties" begin + # Test construction with various types + list_int = ReusableLinkedList{Int}(0, 5) + list_str = ReusableLinkedList{String}("", 3) + list_float = ReusableLinkedList{Float64}(0.0, 10) + + # Test initial state + @test length(list_int) == 0 + @test 
length(list_str) == 0 + @test length(list_float) == 0 + + # Test empty list doesn't contain any values + @test !in(0, list_int) + @test !in("test", list_str) + @test !in(1.0, list_float) + + # Test iteration over empty list + count = 0 + for item in list_int + count += 1 + end + @test count == 0 + end + + @testset "Invalid Operations on Empty Lists" begin + list = ReusableLinkedList{Int}(0, 5) + + # Test invalid index access + @test_throws BoundsError list[1] + @test_throws BoundsError list[0] + @test_throws BoundsError list[-1] + @test_throws BoundsError list[10] + + # Test invalid setindex on empty list + @test_throws BoundsError list[1] = 42 + + # Test pop operations on empty list + @test_throws ArgumentError pop!(list) + @test_throws ArgumentError popfirst!(list) + + # Test deleteat! on empty list + @test_throws BoundsError deleteat!(list, 1) + + # Test findfirst on empty list (should return nothing) + @test findfirst(x -> x == 5, list) === nothing + + # Test empty! on already empty list (should not error) + @test_nowarn empty!(list) + @test length(list) == 0 + + # Test resize! to 0 on empty list + @test_nowarn resize!(list, 0) + @test length(list) == 0 + end + + @testset "Push and Pop Operations" begin + list = ReusableLinkedList{Int}(0, 5) + + # Test push! operations + push!(list, 1) + @test length(list) == 1 + @test list[1] == 1 + + push!(list, 2) + push!(list, 3) + @test length(list) == 3 + @test list[1] == 1 + @test list[2] == 2 + @test list[3] == 3 + + # Test pushfirst! operations + pushfirst!(list, 0) + @test length(list) == 4 + @test list[1] == 0 + @test list[2] == 1 + @test list[3] == 2 + @test list[4] == 3 + + # Test capacity limit + push!(list, 4) # Should reach capacity + @test length(list) == 5 + @test_throws ArgumentError push!(list, 5) # Should exceed capacity + @test_throws ArgumentError pushfirst!(list, -1) # Should exceed capacity + + # Test pop! operations + val = pop!(list) + @test val == 4 + @test length(list) == 4 + + val = popfirst!(list) + @test val == 0 + @test length(list) == 3 + @test list[1] == 1 + @test list[2] == 2 + @test list[3] == 3 + end + + @testset "Indexing Operations" begin + list = ReusableLinkedList{String}("", 4) + push!(list, "a") + push!(list, "b") + push!(list, "c") + + # Test getindex + @test list[1] == "a" + @test list[2] == "b" + @test list[3] == "c" + + # Test invalid getindex + @test_throws BoundsError list[0] + @test_throws BoundsError list[4] + @test_throws BoundsError list[-1] + + # Test setindex! + list[2] = "modified" + @test list[2] == "modified" + @test list[1] == "a" # Other elements unchanged + @test list[3] == "c" + + # Test invalid setindex! 
+ @test_throws BoundsError list[0] = "invalid" + @test_throws BoundsError list[4] = "invalid" + @test_throws BoundsError list[-1] = "invalid" + end + + @testset "Search and Membership" begin + list = ReusableLinkedList{Int}(0, 6) + for i in [10, 20, 30, 20, 40] + push!(list, i) + end + + # Test in operation + @test in(10, list) + @test in(20, list) + @test in(30, list) + @test in(40, list) + @test !in(50, list) + @test !in(0, list) # null element not in list + + # Test findfirst + @test findfirst(x -> x == 10, list) == 1 + @test findfirst(x -> x == 20, list) == 2 # First occurrence + @test findfirst(x -> x == 40, list) == 5 + @test findfirst(x -> x == 50, list) === nothing + @test findfirst(x -> x > 25, list) == 3 # First element > 25 + end + + @testset "Iteration" begin + list = ReusableLinkedList{Float64}(0.0, 4) + values = [1.1, 2.2, 3.3, 4.4] + for v in values + push!(list, v) + end + + # Test basic iteration + collected = Float64[] + for item in list + push!(collected, item) + end + @test collected == values + + # Test iteration with enumerate + for (i, item) in enumerate(list) + @test item == values[i] + end + + # Test collect + @test collect(list) == values + end + + @testset "Bulk Operations" begin + list = ReusableLinkedList{Int}(0, 10) + + # Test fill! + resize!(list, 5) + fill!(list, 42) + @test length(list) == 5 + for i in 1:5 + @test list[i] == 42 + end + + # Test empty! and refill + empty!(list) + @test length(list) == 0 + + # Test resize! with various sizes + resize!(list, 3) + @test length(list) == 3 + + # Fill with test data + for i in 1:3 + list[i] = i * 10 + end + + # Test resize! to larger size + resize!(list, 6) + @test length(list) == 6 + @test list[1] == 10 + @test list[2] == 20 + @test list[3] == 30 + # New elements should be initialized to null value + + # Test resize! to smaller size + resize!(list, 2) + @test length(list) == 2 + @test list[1] == 10 + @test list[2] == 20 + + # Test resize! beyond capacity + @test_throws ArgumentError resize!(list, 11) + end + + @testset "Deletion Operations" begin + list = ReusableLinkedList{Int}(0, 8) + for i in 1:5 + push!(list, i * 10) + end + + # Test deleteat! from middle + deleteat!(list, 3) # Remove 30 + @test length(list) == 4 + @test list[1] == 10 + @test list[2] == 20 + @test list[3] == 40 # 40 moved to position 3 + @test list[4] == 50 + + # Test deleteat! from beginning + deleteat!(list, 1) + @test length(list) == 3 + @test list[1] == 20 + @test list[2] == 40 + @test list[3] == 50 + + # Test deleteat! from end + deleteat!(list, 3) + @test length(list) == 2 + @test list[1] == 20 + @test list[2] == 40 + + # Test invalid deleteat! + @test_throws BoundsError deleteat!(list, 0) + @test_throws BoundsError deleteat!(list, 3) + @test_throws BoundsError deleteat!(list, -1) + end + + @testset "Map and Copy Operations" begin + list = ReusableLinkedList{Int}(0, 6) + for i in 1:4 + push!(list, i) + end + + # Test map! + map!(x -> x * 2, list, list) + @test list[1] == 2 + @test list[2] == 4 + @test list[3] == 6 + @test list[4] == 8 + + # Test copyto! with array + src = [100, 200, 300] + copyto!(list, src) + @test length(list) == 4 + @test list[1] == 100 + @test list[2] == 200 + @test list[3] == 300 + @test list[4] == 8 + + # Test copyto! with another list + list2 = ReusableLinkedList{Int}(0, 6) + push!(list2, 999) + push!(list2, 888) + + copyto!(list, list2) + @test length(list) == 4 + @test list[1] == 999 + @test list[2] == 888 + @test list[3] == 300 + @test list[4] == 8 + + # Test copyto! 
exceeding capacity + large_src = collect(1:10) + list_small = ReusableLinkedList{Int}(0, 5) + @test_throws BoundsError copyto!(list_small, large_src) + end +end + +@testset "ReusableDict Tests" begin + + @testset "Construction and Basic Properties" begin + # Test construction with various types + dict_int = ReusableDict{String, Int}("", 0, 5) + dict_str = ReusableDict{Int, String}(0, "", 3) + dict_float = ReusableDict{String, Float64}("", 0.0, 10) + + # Test initial state + @test length(dict_int) == 0 + @test length(dict_str) == 0 + @test length(dict_float) == 0 + + # Test empty dict doesn't contain any keys + @test !haskey(dict_int, "test") + @test !haskey(dict_str, 1) + @test !haskey(dict_float, "pi") + + # Test iteration over empty dict + count = 0 + for (k, v) in dict_int + count += 1 + end + @test count == 0 + + # Test keys and values on empty dict + @test length(keys(dict_int)) == 0 + @test length(values(dict_int)) == 0 + end + + @testset "Invalid Operations on Empty Dicts" begin + dict = ReusableDict{String, Int}("", 0, 5) + + # Test invalid key access + @test_throws KeyError dict["nonexistent"] + + # Test delete! on non-existent key + @test_throws KeyError delete!(dict, "nonexistent") + + # Test empty! on already empty dict (should not error) + @test_nowarn empty!(dict) + @test length(dict) == 0 + end + + @testset "Basic Dictionary Operations" begin + dict = ReusableDict{String, Int}("", 0, 4) + + # Test setindex! and getindex + dict["a"] = 1 + @test dict["a"] == 1 + @test length(dict) == 1 + @test haskey(dict, "a") + + dict["b"] = 2 + dict["c"] = 3 + @test dict["b"] == 2 + @test dict["c"] == 3 + @test length(dict) == 3 + + # Test updating existing key + dict["a"] = 10 + @test dict["a"] == 10 + @test length(dict) == 3 # Length shouldn't change + + # Test capacity limit + dict["d"] = 4 + @test length(dict) == 4 + @test_throws ArgumentError dict["e"] = 5 # Should exceed capacity + + # Test getindex with non-existent key + @test_throws KeyError dict["nonexistent"] + end + + @testset "Key Membership and Search" begin + dict = ReusableDict{Int, String}(0, "", 5) + dict[10] = "ten" + dict[20] = "twenty" + dict[30] = "thirty" + + # Test haskey + @test haskey(dict, 10) + @test haskey(dict, 20) + @test haskey(dict, 30) + @test !haskey(dict, 40) + @test !haskey(dict, 0) # null key not in dict + + # Test that null values don't interfere + dict[40] = "" # Setting to null value should still work + @test haskey(dict, 40) + @test dict[40] == "" + end + + @testset "Keys and Values" begin + dict = ReusableDict{String, Float64}("", 0.0, 6) + test_data = Dict("pi" => 3.14, "e" => 2.71, "phi" => 1.61) + + for (k, v) in test_data + dict[k] = v + end + + # Test keys() + dict_keys = collect(keys(dict)) + @test length(dict_keys) == 3 + for k in ["pi", "e", "phi"] + @test k in dict_keys + end + + # Test values() + dict_values = collect(values(dict)) + @test length(dict_values) == 3 + for v in [3.14, 2.71, 1.61] + @test v in dict_values + end + + # Test keys and values are consistent + for k in keys(dict) + @test haskey(test_data, k) + @test dict[k] == test_data[k] + end + end + + @testset "Iteration" begin + dict = ReusableDict{Int, String}(0, "", 4) + test_pairs = [(1, "one"), (2, "two"), (3, "three")] + + for (k, v) in test_pairs + dict[k] = v + end + + # Test iteration over key-value pairs + collected_pairs = [] + for (k, v) in dict + push!(collected_pairs, (k, v)) + end + @test length(collected_pairs) == 3 + + for pair in test_pairs + @test pair in collected_pairs + end + + # Test that iteration order 
is consistent with keys/values + keys_iter = collect(keys(dict)) + values_iter = collect(values(dict)) + pairs_iter = collect(dict) + + for i in 1:length(pairs_iter) + k, v = pairs_iter[i] + @test k == keys_iter[i] + @test v == values_iter[i] + @test dict[k] == v + end + end + + @testset "Deletion Operations" begin + dict = ReusableDict{String, Int}("", 0, 6) + + # Add test data + for (k, v) in [("a", 1), ("b", 2), ("c", 3), ("d", 4)] + dict[k] = v + end + @test length(dict) == 4 + + # Test delete! existing key + delete!(dict, "b") + @test length(dict) == 3 + @test !haskey(dict, "b") + @test haskey(dict, "a") + @test haskey(dict, "c") + @test haskey(dict, "d") + + # Test delete! non-existent key + @test_throws KeyError delete!(dict, "nonexistent") + @test_throws KeyError delete!(dict, "b") # Already deleted + + # Test that we can add new keys after deletion + dict["e"] = 5 + @test length(dict) == 4 + @test dict["e"] == 5 + + # Test deleting all remaining keys + for k in ["a", "c", "d", "e"] + delete!(dict, k) + end + @test length(dict) == 0 + end + + @testset "Empty! Operation" begin + dict = ReusableDict{Int, String}(0, "", 5) + + # Add some data + for i in 1:4 + dict[i] = "value$i" + end + @test length(dict) == 4 + + # Test empty! + empty!(dict) + @test length(dict) == 0 + + # Test that all keys are gone + for i in 1:4 + @test !haskey(dict, i) + end + + # Test that we can add new data after empty! + dict[100] = "new value" + @test length(dict) == 1 + @test dict[100] == "new value" + + # Test empty! on already empty dict + empty!(dict) + @test_nowarn empty!(dict) # Should not error + @test length(dict) == 0 + end + + @testset "Capacity and Edge Cases" begin + dict = ReusableDict{Int, Int}(0, 0, 3) + + # Fill to capacity + dict[1] = 10 + dict[2] = 20 + dict[3] = 30 + @test length(dict) == 3 + + # Test exceeding capacity + @test_throws ArgumentError dict[4] = 40 + + # Test that updating existing keys doesn't exceed capacity + @test_nowarn dict[1] = 100 + @test dict[1] == 100 + @test length(dict) == 3 + + # Test with null values and keys + dict_nulls = ReusableDict{Int, String}(-1, "NULL", 3) + dict_nulls[0] = "zero" # Non-null key, non-null value + dict_nulls[1] = "NULL" # Non-null key, null value + @test length(dict_nulls) == 2 + @test dict_nulls[0] == "zero" + @test dict_nulls[1] == "NULL" + @test haskey(dict_nulls, 0) + @test haskey(dict_nulls, 1) + @test !haskey(dict_nulls, -1) # null key should not be present + end + + @testset "Type Consistency" begin + # Test with different key-value type combinations + dict_si = ReusableDict{String, Int}("", 0, 3) + dict_is = ReusableDict{Int, String}(0, "", 3) + dict_ff = ReusableDict{Float64, Float64}(0.0, 0.0, 3) + + # Test type enforcement + dict_si["test"] = 42 + @test dict_si["test"] == 42 + + dict_is[42] = "test" + @test dict_is[42] == "test" + + dict_ff[3.14] = 2.71 + @test dict_ff[3.14] == 2.71 + + # Ensure operations maintain type consistency + @test typeof(collect(keys(dict_si))[1]) == String + @test typeof(collect(values(dict_si))[1]) == Int + @test typeof(collect(keys(dict_is))[1]) == Int + @test typeof(collect(values(dict_is))[1]) == String + end +end + +@testset "Cross-Structure Interaction Tests" begin + @testset "Using Structures Together" begin + # Test using both structures in combination + list = ReusableLinkedList{String}("", 5) + dict = ReusableDict{String, Int}("", 0, 5) + + # Add data to list + for word in ["apple", "banana", "cherry"] + push!(list, word) + end + + # Create dictionary with list contents as keys + for (i, 
word) in enumerate(list) + dict[word] = i + end + + @test length(dict) == 3 + @test dict["apple"] == 1 + @test dict["banana"] == 2 + @test dict["cherry"] == 3 + + # Test that modifications to one don't affect the other + list[1] = "apricot" + @test dict["apple"] == 1 # Dict unchanged + @test !haskey(dict, "apricot") # New value not in dict + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 39860cf2d..264eb4603 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -35,9 +35,10 @@ tests = [ ("Array - Stencils", "array/stencil.jl"), ("Array - FFT", "array/fft.jl"), ("GPU", "gpu.jl"), - ("Caching", "cache.jl"), + #("Caching", "cache.jl"), ("Disk Caching", "diskcaching.jl"), ("File IO", "file-io.jl"), + ("Reusable Data Structures", "reuse.jl"), ("External Languages - Python", "extlang/python.jl"), ("Preferences", "preferences.jl"), #("Fault Tolerance", "fault-tolerance.jl"), @@ -68,9 +69,15 @@ if PROGRAM_FILE != "" && realpath(PROGRAM_FILE) == @__FILE__ @eval begin @add_arg_table! s begin "--test" - nargs = '*' - default = all_test_names + nargs = 1 + action = :append_arg + arg_type = String help = "Enables the specified test to run in the testsuite" + "--no-test" + nargs = 1 + action = :append_arg + arg_type = String + help = "Disables the specified test from running in the testsuite" "-s", "--simulate" action = :store_true help = "Don't actually run the tests" @@ -89,16 +96,39 @@ if PROGRAM_FILE != "" && realpath(PROGRAM_FILE) == @__FILE__ parsed_args = parse_args(s) to_test = String[] - for test in parsed_args["test"] + if isempty(parsed_args["test"]) + to_test = copy(all_test_names) + else + for _test in parsed_args["test"] + test = only(_test) + if isdir(joinpath(@__DIR__, test)) + for (_, other_test) in tests + if startswith(other_test, test) + push!(to_test, other_test) + end + end + elseif test in all_test_names + push!(to_test, test) + else + println(stderr, "Unknown test: $test") + println(stderr, "Available tests:") + for ((test_title, _), test_name) in zip(tests, all_test_names) + println(stderr, " $test_name: $test_title") + end + exit(1) + end + end + end + for _test in parsed_args["no-test"] + test = only(_test) if isdir(joinpath(@__DIR__, test)) for (_, other_test) in tests if startswith(other_test, test) - push!(to_test, other_test) - continue + filter!(x -> x != other_test, to_test) end end elseif test in all_test_names - push!(to_test, test) + filter!(x -> x != test, to_test) else println(stderr, "Unknown test: $test") println(stderr, "Available tests:") diff --git a/test/scheduler.jl b/test/scheduler.jl index 9f00485a8..9bd2d159d 100644 --- a/test/scheduler.jl +++ b/test/scheduler.jl @@ -1,5 +1,5 @@ -import Dagger: Chunk -import Dagger.Sch: SchedulerOptions, ThunkOptions, SchedulerHaltedException, ComputeState, ThunkID, sch_handle +import Dagger: Chunk, Options +import Dagger.Sch: SchedulerOptions, SchedulerHaltedException, ComputeState, ThunkID, sch_handle @everywhere begin using Dagger @@ -142,18 +142,16 @@ end end end @testset "single worker" begin - options = ThunkOptions(;single=1) + options = Options(;single=1) a = delayed(checkwid; options=options)(1) @test collect(Context([1,workers()...]), a) == 1 end - @static if VERSION >= v"1.3.0-DEV.573" - @testset "proclist" begin - options = ThunkOptions(;proclist=[Dagger.ThreadProc]) - a = delayed(checktid; options=options)(1) + @testset "proclist" begin + options = Options(;proclist=[Dagger.ThreadProc]) + a = delayed(checktid; options=options)(1) - @test collect(Context(), a) == 1 
- end + @test collect(Context(), a) == 1 end @everywhere Dagger.add_processor_callback!(()->FakeProc(), :fakeproc) @testset "proclist FakeProc" begin @@ -163,8 +161,10 @@ end @test Dagger.default_enabled(Dagger.ThreadProc(1,1)) == true @test Dagger.default_enabled(FakeProc()) == false - as = [delayed(identity; proclist=[Dagger.ThreadProc])(i) for i in 1:5] - b = delayed(fakesum; proclist=[FakeProc], compute_scope=Dagger.AnyScope())(as...) + opts = Options(;proclist=[Dagger.ThreadProc]) + as = [delayed(identity; options=opts)(i) for i in 1:5] + opts = Options(;proclist=[FakeProc], compute_scope=Dagger.AnyScope()) + b = delayed(fakesum; options=opts)(as...) @test collect(Context(), b) == FakeVal(57) end @@ -178,11 +178,6 @@ end collect(b) end =# - @testset "allow errors" begin - opts = ThunkOptions(;allow_errors=true) - a = delayed(error; options=opts)("Test") - @test_throws_unwrap (Dagger.DTaskFailedException, ErrorException) collect(a) - end end @testset "Modify workers in running job" begin @@ -208,141 +203,146 @@ end # until we end up on a non-blocked worker h = Dagger.Sch.sch_handle() wkrs = Dagger.Sch.exec!(_list_workers, h) - id = Dagger.Sch.add_thunk!(testfun, h, nothing=>i) - return fetch(h, id) + t = Dagger.Sch.add_thunk!(testfun, h, nothing=>i) + return fetch(t) end return myid() end end - @testset "Add new workers" begin - ps = [] - try - ps1 = addprocs(2, exeflags="--project") - append!(ps, ps1) + if nprocs() > 1 # Skip if we've disabled workers + @test_skip "Add new workers" + #= + @testset "Add new workers" begin + ps = [] + try + ps1 = addprocs(2, exeflags="--project") + append!(ps, ps1) - @everywhere vcat(ps1, myid()) $setup + @everywhere vcat(ps1, myid()) $setup - ctx = Context(ps1) - ts = delayed(vcat)((delayed(testfun)(i) for i in 1:10)...) + ctx = Context(ps1) + ts = delayed(vcat)((delayed(testfun)(i) for i in 1:10)...) - job = @async collect(ctx, ts) + job = @async collect(ctx, ts) - while !istaskstarted(job) - sleep(0.001) - end + while !istaskstarted(job) + sleep(0.001) + end - # Will not be added, so they should never appear in output - ps2 = addprocs(2, exeflags="--project") - append!(ps, ps2) + # Will not be added, so they should never appear in output + ps2 = addprocs(2, exeflags="--project") + append!(ps, ps2) - ps3 = addprocs(2, exeflags="--project") - append!(ps, ps3) - @everywhere ps3 $setup - addprocs!(ctx, ps3) - @test length(procs(ctx)) == 4 + ps3 = addprocs(2, exeflags="--project") + append!(ps, ps3) + @everywhere ps3 $setup + addprocs!(ctx, ps3) + @test length(procs(ctx)) == 4 - @everywhere ps3 blocked=false + @everywhere ps3 blocked=false - ps_used = fetch(job) - @test ps_used isa Vector + ps_used = fetch(job) + @test ps_used isa Vector - @test any(p -> p in ps_used, ps1) - @test any(p -> p in ps_used, ps3) - @test !any(p -> p in ps2, ps_used) - finally - wait(rmprocs(ps)) + @test any(p -> p in ps_used, ps1) + @test any(p -> p in ps_used, ps3) + @test !any(p -> p in ps2, ps_used) + finally + wait(rmprocs(ps)) + end end - end + =# + + @test_skip "Remove workers" + #=@testset "Remove workers" begin + ps = [] + try + ps1 = addprocs(4, exeflags="--project") + append!(ps, ps1) + + @everywhere vcat(ps1, myid()) $setup + + # Use single to force scheduler to make use of all workers since we assert it below + ts = delayed(vcat)((delayed(testfun; single=ps1[mod1(i, end)])(i) for i in 1:10)...) + + # Use FilterLog as a callback function. 
+ nprocs_removed = Ref(0) + first_rescheduled_thunk=Ref(false) + rmproctrigger = Dagger.FilterLog(Dagger.NoOpLog()) do event + if typeof(event) == Dagger.Event{:finish} && event.category === :cleanup_proc + nprocs_removed[] += 1 + end + if typeof(event) == Dagger.Event{:start} && event.category === :add_thunk + first_rescheduled_thunk[] = true + end + return false + end - @test_skip "Remove workers" - #=@testset "Remove workers" begin - ps = [] - try - ps1 = addprocs(4, exeflags="--project") - append!(ps, ps1) + ctx = Context(ps1; log_sink=rmproctrigger) + job = @async collect(ctx, ts) - @everywhere vcat(ps1, myid()) $setup + # Must wait for this or else we won't get callback for rmprocs! + # Timeout so we don't stall forever if something breaks + starttime = time() + while !first_rescheduled_thunk[] && (time() - starttime < 10.0) + sleep(0.1) + end + @test first_rescheduled_thunk[] - # Use single to force scheduler to make use of all workers since we assert it below - ts = delayed(vcat)((delayed(testfun; single=ps1[mod1(i, end)])(i) for i in 1:10)...) + rmprocs!(ctx, ps1[3:end]) + @test length(procs(ctx)) == 2 - # Use FilterLog as a callback function. - nprocs_removed = Ref(0) - first_rescheduled_thunk=Ref(false) - rmproctrigger = Dagger.FilterLog(Dagger.NoOpLog()) do event - if typeof(event) == Dagger.Event{:finish} && event.category === :cleanup_proc - nprocs_removed[] += 1 - end - if typeof(event) == Dagger.Event{:start} && event.category === :add_thunk - first_rescheduled_thunk[] = true + # Timeout so we don't stall forever if something breaks + starttime = time() + while (nprocs_removed[] < 2) && (time() - starttime < 10.0) + sleep(0.01) end - return false - end + # this will fail if we timeout. Verify that we get the logevent for :cleanup_proc + @test nprocs_removed[] >= 2 - ctx = Context(ps1; log_sink=rmproctrigger) - job = @async collect(ctx, ts) + @everywhere ps1 blocked=false - # Must wait for this or else we won't get callback for rmprocs! - # Timeout so we don't stall forever if something breaks - starttime = time() - while !first_rescheduled_thunk[] && (time() - starttime < 10.0) - sleep(0.1) - end - @test first_rescheduled_thunk[] + res = fetch(job) + @test res isa Vector - rmprocs!(ctx, ps1[3:end]) - @test length(procs(ctx)) == 2 - - # Timeout so we don't stall forever if something breaks - starttime = time() - while (nprocs_removed[] < 2) && (time() - starttime < 10.0) - sleep(0.01) + @test res[1:4] |> unique |> sort == ps1 + @test all(pid -> pid in ps1[1:2], res[5:end]) + finally + # Prints "From worker X: IOError:" :/ + wait(rmprocs(ps)) end - # this will fail if we timeout. Verify that we get the logevent for :cleanup_proc - @test nprocs_removed[] >= 2 - - @everywhere ps1 blocked=false + end=# - res = fetch(job) - @test res isa Vector - - @test res[1:4] |> unique |> sort == ps1 - @test all(pid -> pid in ps1[1:2], res[5:end]) - finally - # Prints "From worker X: IOError:" :/ - wait(rmprocs(ps)) - end - end=# + @testset "Remove all workers throws" begin + ps = [] + try + ps1 = addprocs(2, exeflags="--project") + append!(ps, ps1) - @testset "Remove all workers throws" begin - ps = [] - try - ps1 = addprocs(2, exeflags="--project") - append!(ps, ps1) + @everywhere vcat(ps1, myid()) $setup - @everywhere vcat(ps1, myid()) $setup + ts = delayed(vcat)((delayed(testfun)(i) for i in 1:16)...) - ts = delayed(vcat)((delayed(testfun)(i) for i in 1:16)...) 
+ ctx = Context(ps1) + job = @async collect(ctx, ts) - ctx = Context(ps1) - job = @async collect(ctx, ts) - - while !istaskstarted(job) - sleep(0.001) - end + while !istaskstarted(job) + sleep(0.001) + end - rmprocs!(ctx, ps1) - @test length(procs(ctx)) == 0 + rmprocs!(ctx, ps1) + @test length(procs(ctx)) == 0 - @everywhere ps1 blocked=false - if VERSION >= v"1.3.0-alpha.110" - @test_throws TaskFailedException fetch(job) - else - @test_throws Exception fetch(job) + @everywhere ps1 blocked=false + if VERSION >= v"1.3.0-alpha.110" + @test_throws TaskFailedException fetch(job) + else + @test_throws Exception fetch(job) + end + finally + wait(rmprocs(ps)) end - finally - wait(rmprocs(ps)) end end end @@ -350,21 +350,44 @@ end @testset "Scheduler algorithms" begin @testset "Signature Calculation" begin - @test Dagger.Sch.signature(+, [nothing=>1, nothing=>2]) isa Vector{DataType} - @test Dagger.Sch.signature(+, [nothing=>1, nothing=>2]) == [typeof(+), Int, Int] + @test Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(2, 2)]) isa Dagger.Sch.Signature + @test Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(2, 2)]).sig == [typeof(+), Int, Int] + @test Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(2, 2)]).sig_nokw == [typeof(+), Int, Int] + @test Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(2, 2)]).hash == + Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(2, 2)]).hash + @test Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(2, 2)]).hash_nokw == + Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(2, 2)]).hash_nokw + @test Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(2, 2)]).hash != + Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(2, 2)]).hash_nokw if isdefined(Core, :kwcall) - @test Dagger.Sch.signature(+, [nothing=>1, :a=>2]) == [typeof(Core.kwcall), @NamedTuple{a::Int64}, typeof(+), Int] + @test Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(:a, 2)]).sig == [typeof(Core.kwcall), @NamedTuple{a::Int64}, typeof(+), Int] + @test Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(:a, 2)]).sig_nokw == [typeof(+), Int] + @test Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(:a, 2)]).hash == + Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(:a, 2)]).hash + @test Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(:a, 2)]).hash_nokw == + Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(:a, 2)]).hash_nokw + @test Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(:a, 2)]).hash != + Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(:a, 2)]).hash_nokw else kw_f = Core.kwfunc(+) - @test Dagger.Sch.signature(+, [nothing=>1, :a=>2]) == [typeof(kw_f), @NamedTuple{a::Int64}, typeof(+), Int] + @test Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(:a, 2)]).sig == [typeof(kw_f), @NamedTuple{a::Int64}, typeof(+), Int] + @test Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(:a, 2)]).sig_nokw == [typeof(+), Int] + @test Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(:a, 2)]).hash == + Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(:a, 2)]).hash + @test Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(:a, 2)]).hash_nokw == + Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(:a, 2)]).hash_nokw + @test Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(:a, 
2)]).hash != + Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(:a, 2)]).hash_nokw end - @test Dagger.Sch.signature(+, []) == [typeof(+)] - @test Dagger.Sch.signature(+, [nothing=>1]) == [typeof(+), Int] + @test Dagger.Sch.signature(+, []).sig == [typeof(+)] + @test Dagger.Sch.signature(+, []).sig_nokw == [typeof(+)] + @test Dagger.Sch.signature(+, [Dagger.Argument(1, 1)]).sig == [typeof(+), Int] + @test Dagger.Sch.signature(+, [Dagger.Argument(1, 1)]).sig_nokw == [typeof(+), Int] c = Dagger.tochunk(1.0) - @test Dagger.Sch.signature(*, [nothing=>c, nothing=>3]) == [typeof(*), Float64, Int] + @test Dagger.Sch.signature(*, [Dagger.Argument(1, c), Dagger.Argument(2, 3)]).sig == [typeof(*), Float64, Int] t = Dagger.@spawn 1+2 - @test Dagger.Sch.signature(/, [nothing=>t, nothing=>c, nothing=>3]) == [typeof(/), Int, Float64, Int] + @test Dagger.Sch.signature(/, [Dagger.Argument(1, t), Dagger.Argument(2, c), Dagger.Argument(3, 3)]).sig == [typeof(/), Int, Float64, Int] end @testset "Cost Estimation" begin @@ -377,13 +400,18 @@ end end state = Dagger.Sch.EAGER_STATE[] - tproc1 = Dagger.ThreadProc(1, 1) - tproc2 = Dagger.ThreadProc(first(workers()), 1) - procs = [tproc1, tproc2] + tproc1_1 = Dagger.ThreadProc(1, 1) + tproc2_1 = Dagger.ThreadProc(first(workers()), 1) + procs = [tproc1_1, tproc2_1] + + # Ensure that this worker has been used at least once + fetch(Dagger.@spawn scope=Dagger.ExactScope(tproc2_1) 1+1) - pres1 = state.worker_time_pressure[1][tproc1] - pres2 = state.worker_time_pressure[first(workers())][tproc2] + #pres1_1 = state.worker_time_pressure[1][tproc1_1] + #pres2_1 = state.worker_time_pressure[first(workers())][tproc2_1] tx_rate = state.transfer_rate[] + tx_xfer_cost = 1e6 + sig_unknown_cost = 1e9 for (args, tx_size) in [ ([1, 2], 0), @@ -407,20 +435,22 @@ end @test est_tx_size == tx_size t = delayed(mynothing)(args...) 
- inputs = Dagger.Sch.collect_task_inputs(state, t) - sorted_procs, costs = Dagger.Sch.estimate_task_costs(state, procs, t, inputs) + Dagger.Sch.collect_task_inputs!(state, t) + sorted_procs, costs = Dagger.Sch.estimate_task_costs(state, procs, t) - @test tproc1 in sorted_procs - @test tproc2 in sorted_procs + @test tproc1_1 in sorted_procs + @test tproc2_1 in sorted_procs if length(cargs) > 0 - @test sorted_procs[1] == tproc1 - @test sorted_procs[2] == tproc2 + @test sorted_procs[1] == tproc1_1 + @test sorted_procs[2] == tproc2_1 end - @test haskey(costs, tproc1) - @test haskey(costs, tproc2) - @test costs[tproc1] ≈ pres1 # All chunks are local - @test costs[tproc2] ≈ (tx_size/tx_rate) + pres2 # All chunks are remote + @test haskey(costs, tproc1_1) + @test haskey(costs, tproc2_1) + @test costs[tproc1_1] ≈ #=pres1_1 +=# sig_unknown_cost # All chunks are local, and this signature is unknown + if nprocs() > 1 + @test costs[tproc2_1] ≈ (tx_size/tx_rate) + tx_xfer_cost + #=pres2_1 +=# sig_unknown_cost # All chunks are remote, and this signature is unknown + end end end end @@ -464,9 +494,12 @@ end @test haskey(ids, d_id) @test length(ids[d_id]) == 0 # no one waiting on our result - @test length(ids[a_id]) == 0 # b and c finished, our result is unneeded - @test length(ids[b_id]) == 1 # d is still executing - @test length(ids[c_id]) == 1 # d is still executing + + @test haskey(ids, a_id) + @test length(ids[a_id]) == 0 # b and c finished, our result was unneeded + + @test length(ids[b_id]) == 1 # d was still executing + @test length(ids[c_id]) == 1 # d was still executing @test pop!(ids[b_id]) == d_id @test pop!(ids[c_id]) == d_id end @@ -536,12 +569,32 @@ end end @testset "Cancellation" begin + # Ready task cancellation + start_time = time_ns() t = Dagger.@spawn scope=Dagger.scope(worker=1, thread=1) sleep(100) + Dagger.cancel!(t) + @test timedwait(()->istaskdone(t), 10) == :ok + if istaskdone(t) + @test_throws_unwrap (Dagger.DTaskFailedException, InterruptException) fetch(t) + @test (time_ns() - start_time) * 1e-9 < 100 + end + + # Running task cancellation start_time = time_ns() + t = Dagger.@spawn scope=Dagger.scope(worker=1, thread=1) sleep(100) + sleep(0.1) # Give the scheduler a chance to schedule the task Dagger.cancel!(t) - @test_throws_unwrap (Dagger.DTaskFailedException, InterruptException) fetch(t) + @test timedwait(()->istaskdone(t), 10) == :ok + if istaskdone(t) + @test_throws_unwrap (Dagger.DTaskFailedException, InterruptException) fetch(t) + @test (time_ns() - start_time) * 1e-9 < 100 + end + + # Normal task execution + start_time = time_ns() t = Dagger.@spawn scope=Dagger.scope(worker=1, thread=1) yield() - fetch(t) - finish_time = time_ns() - @test (finish_time - start_time) * 1e-9 < 100 + @test timedwait(()->istaskdone(t), 10) == :ok + if istaskdone(t) + @test (time_ns() - start_time) * 1e-9 < 100 + end end diff --git a/test/task-affinity.jl b/test/task-affinity.jl index cd57e273b..f1e26295a 100644 --- a/test/task-affinity.jl +++ b/test/task-affinity.jl @@ -6,9 +6,13 @@ @assert Dagger.Sch.unwrap_nested_exception(err) isa Dagger.Sch.SchedulingException return Dagger.InvalidScope end - get_compute_scope(x::DTask) = Dagger.Sch._find_thunk(x).compute_scope + function get_compute_scope(x::DTask) + thunk = Dagger.Sch._find_thunk(x) + return @something(thunk.options.compute_scope, thunk.options.scope, Dagger.DefaultScope()) + end - get_result_scope(x::DTask) = Dagger.Sch._find_thunk(x).result_scope + get_result_scope(x::DTask) = + 
@something(Dagger.Sch._find_thunk(x).options.result_scope, Dagger.AnyScope()) get_final_result_scope(x::DTask) = @something(fetch_or_invalidscope(x), fetch(x; raw=true).scope) @@ -18,18 +22,18 @@ return res end thunk = Dagger.Sch._find_thunk(x) - compute_scope = thunk.compute_scope - result_scope = thunk.result_scope - f_scope = thunk.f isa Dagger.Chunk ? thunk.f.scope : Dagger.AnyScope() + compute_scope = @something(thunk.options.compute_scope, thunk.options.scope, Dagger.DefaultScope()) + result_scope = @something(thunk.options.result_scope, Dagger.AnyScope()) inputs_scopes = Dagger.AbstractScope[] for input in thunk.inputs + input = Dagger.unwrap_weak_checked(Dagger.value(input)) if input isa Dagger.Chunk push!(inputs_scopes, input.scope) else push!(inputs_scopes, Dagger.AnyScope()) end end - return Dagger.constrain(compute_scope, result_scope, f_scope, inputs_scopes...) + return Dagger.constrain(compute_scope, result_scope, inputs_scopes...) end availprocs = collect(Dagger.all_processors()) diff --git a/test/thunk.jl b/test/thunk.jl index 8f4477df6..1ff9517ac 100644 --- a/test/thunk.jl +++ b/test/thunk.jl @@ -49,8 +49,10 @@ end end @testset "@spawn" begin - @test_throws_unwrap ConcurrencyViolationError remotecall_fetch(last(workers())) do - Dagger.Sch.init_eager() + if nprocs() > 1 + @test_throws_unwrap ConcurrencyViolationError remotecall_fetch(last(workers())) do + Dagger.Sch.init_eager() + end end @test Dagger.Sch.EAGER_CONTEXT[] === nothing @testset "per-call" begin @@ -285,52 +287,56 @@ end @test_throws_unwrap (Dagger.DTaskFailedException, ErrorException) fetch(d) end end - @testset "remote spawn" begin - a = fetch(Distributed.@spawnat 2 Dagger.@spawn 1+2) - @test Dagger.Sch.EAGER_INIT[] - @test fetch(Distributed.@spawnat 2 !(Dagger.Sch.EAGER_INIT[])) - @test a isa Dagger.DTask - @test fetch(a) == 3 - - # Mild stress-test - @test dynamic_fib(10) == 55 - - # Errors on remote are correctly scrubbed (#430) - t2 = remotecall_fetch(2) do - t1 = Dagger.@spawn 1+"fail" - Dagger.@spawn t1+1 + if 2 in workers() + @testset "remote spawn" begin + a = fetch(Distributed.@spawnat 2 Dagger.@spawn 1+2) + @test Dagger.Sch.EAGER_INIT[] + @test fetch(Distributed.@spawnat 2 !(Dagger.Sch.EAGER_INIT[])) + @test a isa Dagger.DTask + @test fetch(a) == 3 + + # Mild stress-test + @test dynamic_fib(10) == 55 + + # Errors on remote are correctly scrubbed (#430) + t2 = remotecall_fetch(2) do + t1 = Dagger.@spawn 1+"fail" + Dagger.@spawn t1+1 + end + @test_throws_unwrap (Dagger.DTaskFailedException, MethodError) fetch(t2) end - @test_throws_unwrap (Dagger.DTaskFailedException, MethodError) fetch(t2) end - @testset "undefined function" begin - # Issues #254, #255 + if nprocs() > 1 + @testset "undefined function" begin + # Issues #254, #255 - # only defined on head node - @eval evil_f(x) = x + # only defined on head node + @eval evil_f(x) = x - eager_thunks = map(1:10) do i - single = isodd(i) ? 1 : first(workers()) - Dagger.@spawn single=single evil_f(i) - end + eager_thunks = map(1:10) do i + single = isodd(i) ? 
1 : first(workers()) + Dagger.@spawn single=single evil_f(i) + end - errored(t) = try - fetch(t) - false - catch - true + errored(t) = try + fetch(t) + false + catch + true + end + @test any(t->errored(t), eager_thunks) + @test any(t->!errored(t), eager_thunks) end - @test any(t->errored(t), eager_thunks) - @test any(t->!errored(t), eager_thunks) end @testset "function chunks" begin @testset "lazy API" begin a = delayed(+)(1,2) @test !(a.f isa Chunk) - @test a.compute_scope == Dagger.DefaultScope() + @test a.options.scope == nothing a = delayed(+; scope=NodeScope())(1,2) @test !(a.f isa Chunk) - @test a.compute_scope isa NodeScope + @test a.options.scope isa NodeScope @testset "Scope Restrictions" begin pls = ProcessLockedStruct(Ptr{Int}(42)) @@ -340,7 +346,9 @@ end @test_skip !all(x->x==43, collect(ctx, delayed(vcat)([delayed(pls)(1) for i in 1:10]...))) # Positive tests (no serialization) @test all(x->x==43, collect(ctx, delayed(vcat)([delayed(pls; scope=ProcessScope())(1) for i in 1:10]...))) - @test all(x->x==1, collect(ctx, delayed(vcat)([delayed(pls; scope=ProcessScope(first(workers())))(1) for i in 1:10]...))) + if nprocs() > 1 + @test all(x->x==1, collect(ctx, delayed(vcat)([delayed(pls; scope=ProcessScope(first(workers())))(1) for i in 1:10]...))) + end end @testset "Processor Data Movement" begin @everywhere Dagger.add_processor_callback!(()->MulProc(), :mulproc) @@ -353,7 +361,7 @@ end _a = Dagger.@spawn scope=NodeScope() 1+2 a = Dagger.Sch._find_thunk(_a) @test !(a.f isa Chunk) - @test a.compute_scope isa NodeScope + @test a.options.scope isa NodeScope end end @testset "parent fetch child, one thread" begin diff --git a/test/util.jl b/test/util.jl index 1131a9ebe..7d15c5810 100644 --- a/test/util.jl +++ b/test/util.jl @@ -14,7 +14,7 @@ end replace_obj!(ex::Symbol, obj) = Expr(:(.), obj, QuoteNode(ex)) replace_obj!(ex, obj) = ex function _test_throws_unwrap(terr, ex; to_match=[]) - @gensym oerr rerr + @gensym oerr rerr bt match_expr = Expr(:block) for m in to_match if m.head == :(=) @@ -35,18 +35,33 @@ function _test_throws_unwrap(terr, ex; to_match=[]) end end quote + $bt = nothing $oerr, $rerr = try nothing, $(esc(ex)) catch err + $bt = catch_backtrace() (err, Dagger.Sch.unwrap_nested_exception(err)) end if $terr isa Tuple @test $oerr isa $terr[1] @test $rerr isa $terr[2] + if $rerr isa $terr[2] + $match_expr + else + println("Full error:") + Base.showerror(stdout, $oerr) + Base.show_backtrace(stdout, $bt) + end else @test $rerr isa $terr + if $rerr isa $terr + $match_expr + else + println("Full error:") + Base.showerror(stdout, $oerr) + Base.show_backtrace(stdout, $bt) + end end - $match_expr end end function _test_throws_unwrap(terr, args...)