diff --git a/docs/src/index.md b/docs/src/index.md index 87b4ea174..152b95cc5 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -427,4 +427,5 @@ wait(t) The above example demonstrates a streaming region that generates random numbers continuously and writes each random number to a file. The streaming region is terminated when a random number less than 0.01 is generated, which is done by -calling `Dagger.finish_stream()` (this exits the current streaming task). +calling `Dagger.finish_stream()` (this terminates the current task, and will +also terminate all streaming tasks launched by `spawn_streaming`). diff --git a/docs/src/streaming.md b/docs/src/streaming.md index 25060e1b2..41c111e82 100644 --- a/docs/src/streaming.md +++ b/docs/src/streaming.md @@ -79,9 +79,8 @@ end ``` If you want to stop the streaming DAG and tear it all down, you can call -`Dagger.cancel!.(all_vals)` and `Dagger.cancel!.(all_vals_written)` to -terminate each streaming task. In the future, a more convenient way to tear -down a full DAG will be added; for now, each task must be cancelled individually. +`Dagger.cancel!(all_vals[1])` (or with any other task in the streaming DAG) to +terminate all streaming tasks. Alternatively, tasks can stop themselves from the inside with `finish_stream`, optionally returning a value that can be `fetch`'d. Let's diff --git a/src/dtask.jl b/src/dtask.jl index 98f74005a..b597db5fa 100644 --- a/src/dtask.jl +++ b/src/dtask.jl @@ -85,6 +85,32 @@ function Base.fetch(t::DTask; raw=false) end return fetch(t.future; raw) end +function waitany(tasks::Vector{DTask}) + if isempty(tasks) + return + end + cond = Threads.Condition() + for task in tasks + Sch.errormonitor_tracked("waitany listener", Threads.@spawn begin + wait(task) + @lock cond notify(cond) + end) + end + @lock cond wait(cond) + return +end +function waitall(tasks::Vector{DTask}) + if isempty(tasks) + return + end + @sync for task in tasks + Threads.@spawn begin + wait(task) + @lock cond notify(cond) + end + end + return +end function Base.show(io::IO, t::DTask) status = if istaskstarted(t) isready(t) ? "finished" : "running" diff --git a/src/stream.jl b/src/stream.jl index 81becd5ac..07a3dae95 100644 --- a/src/stream.jl +++ b/src/stream.jl @@ -426,12 +426,37 @@ function initialize_streaming!(self_streams, spec, task) end end -function spawn_streaming(f::Base.Callable) +""" +Starts a streaming region, within which all tasks run continuously and +concurrently. Any `DTask` argument that is itself a streaming task will be +treated as a streaming input/output. The streaming region will automatically +handle the buffering and synchronization of these tasks' values. + +# Keyword Arguments +- `teardown::Bool=true`: If `true`, the streaming region will automatically + cancel all tasks if any task fails or is cancelled. Otherwise, a failing task + will not cancel the other tasks, which will continue running. +""" +function spawn_streaming(f::Base.Callable; teardown::Bool=true) queue = StreamingTaskQueue() result = with_options(f; task_queue=queue) if length(queue.tasks) > 0 finalize_streaming!(queue.tasks, queue.self_streams) enqueue!(queue.tasks) + + if teardown + # Start teardown monitor + dtasks = map(last, queue.tasks)::Vector{DTask} + Sch.errormonitor_tracked("streaming teardown", Threads.@spawn begin + # Wait for any task to finish + waitany(dtasks) + + # Cancel all tasks + for task in dtasks + cancel!(task; graceful=false) + end + end) + end end return result end diff --git a/src/utils/tasks.jl b/src/utils/tasks.jl index c2796cf21..ddd8da2ee 100644 --- a/src/utils/tasks.jl +++ b/src/utils/tasks.jl @@ -18,3 +18,115 @@ function set_task_tid!(task::Task, tid::Integer) end @assert Threads.threadid(task) == tid "jl_set_task_tid failed!" end + +if isdefined(Base, :waitany) +import Base: waitany, waitall +else +# Vendored from Base +# License is MIT +waitany(tasks; throw=true) = _wait_multiple(tasks, throw) +waitall(tasks; failfast=true, throw=true) = _wait_multiple(tasks, throw, true, failfast) +function _wait_multiple(waiting_tasks, throwexc=false, all=false, failfast=false) + tasks = Task[] + + for t in waiting_tasks + t isa Task || error("Expected an iterator of `Task` object") + push!(tasks, t) + end + + if (all && !failfast) || length(tasks) <= 1 + exception = false + # Force everything to finish synchronously for the case of waitall + # with failfast=false + for t in tasks + _wait(t) + exception |= istaskfailed(t) + end + if exception && throwexc + exceptions = [TaskFailedException(t) for t in tasks if istaskfailed(t)] + throw(CompositeException(exceptions)) + else + return tasks, Task[] + end + end + + exception = false + nremaining::Int = length(tasks) + done_mask = falses(nremaining) + for (i, t) in enumerate(tasks) + if istaskdone(t) + done_mask[i] = true + exception |= istaskfailed(t) + nremaining -= 1 + else + done_mask[i] = false + end + end + + if nremaining == 0 + return tasks, Task[] + elseif any(done_mask) && (!all || (failfast && exception)) + if throwexc && (!all || failfast) && exception + exceptions = [TaskFailedException(t) for t in tasks[done_mask] if istaskfailed(t)] + throw(CompositeException(exceptions)) + else + return tasks[done_mask], tasks[.~done_mask] + end + end + + chan = Channel{Int}(Inf) + sentinel = current_task() + waiter_tasks = fill(sentinel, length(tasks)) + + for (i, done) in enumerate(done_mask) + done && continue + t = tasks[i] + if istaskdone(t) + done_mask[i] = true + exception |= istaskfailed(t) + nremaining -= 1 + exception && failfast && break + else + waiter = @task put!(chan, i) + waiter.sticky = false + _wait2(t, waiter) + waiter_tasks[i] = waiter + end + end + + while nremaining > 0 + i = take!(chan) + t = tasks[i] + waiter_tasks[i] = sentinel + done_mask[i] = true + exception |= istaskfailed(t) + nremaining -= 1 + + # stop early if requested, unless there is something immediately + # ready to consume from the channel (using a race-y check) + if (!all || (failfast && exception)) && !isready(chan) + break + end + end + + close(chan) + + if nremaining == 0 + return tasks, Task[] + else + remaining_mask = .~done_mask + for i in findall(remaining_mask) + waiter = waiter_tasks[i] + donenotify = tasks[i].donenotify::ThreadSynchronizer + @lock donenotify Base.list_deletefirst!(donenotify.waitq, waiter) + end + done_tasks = tasks[done_mask] + if throwexc && exception + exceptions = [TaskFailedException(t) for t in done_tasks if istaskfailed(t)] + throw(CompositeException(exceptions)) + else + return done_tasks, tasks[remaining_mask] + end + end +end +end diff --git a/test/streaming.jl b/test/streaming.jl index 9eb01312c..c3bf0e406 100644 --- a/test/streaming.jl +++ b/test/streaming.jl @@ -80,7 +80,7 @@ for idx in 1:5 @testset "Single Task Control Flow ($scope_str)" begin @test !test_finishes("Single task running forever"; max_evals=1_000_000, ignore_timeout=true) do local x - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do x = Dagger.@spawn scope=rand(scopes) () -> begin y = rand() sleep(1) @@ -92,7 +92,7 @@ for idx in 1:5 @test test_finishes("Single task without result") do local x - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do x = Dagger.@spawn scope=rand(scopes) rand() end @test fetch(x) === nothing @@ -100,7 +100,7 @@ for idx in 1:5 @test test_finishes("Single task with result"; max_evals=1_000_000) do local x - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do x = Dagger.@spawn scope=rand(scopes) () -> begin x = rand() if x < 0.1 @@ -116,7 +116,7 @@ for idx in 1:5 @testset "Non-Streaming Inputs ($scope_str)" begin @test test_finishes("() -> A") do local A - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do A = Dagger.@spawn scope=rand(scopes) accumulator() end @test fetch(A) === nothing @@ -127,7 +127,7 @@ for idx in 1:5 end @test test_finishes("42 -> A") do local A - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do A = Dagger.@spawn scope=rand(scopes) accumulator(42) end @test fetch(A) === nothing @@ -138,7 +138,7 @@ for idx in 1:5 end @test test_finishes("(42, 43) -> A") do local A - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do A = Dagger.@spawn scope=rand(scopes) accumulator(42, 43) end @test fetch(A) === nothing @@ -152,7 +152,7 @@ for idx in 1:5 @testset "Non-Streaming Outputs ($scope_str)" begin @test test_finishes("x -> A") do local x, A - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do x = Dagger.@spawn scope=rand(scopes) rand() end Dagger._without_options() do @@ -168,7 +168,7 @@ for idx in 1:5 @test test_finishes("x -> (A, B)") do local x, A, B - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do x = Dagger.@spawn scope=rand(scopes) rand() end Dagger._without_options() do @@ -188,10 +188,45 @@ for idx in 1:5 end end + @testset "Teardown" begin + @test test_finishes("teardown=true"; max_evals=1_000_000, ignore_timeout=true) do + local x, y + Dagger.spawn_streaming(;teardown=true) do + x = Dagger.@spawn scope=rand(scopes) () -> begin + sleep(0.1) + return rand() + end + y = Dagger.with_options(;stream_max_evals=10) do + Dagger.@spawn scope=rand(scopes) identity(x) + end + end + @test fetch(y) === nothing + sleep(1) # Wait for teardown + @test istaskdone(x) + fetch(x) + end + @test !test_finishes("teardown=false"; max_evals=1_000_000, ignore_timeout=true) do + local x, y + Dagger.spawn_streaming(;teardown=false) do + x = Dagger.@spawn scope=rand(scopes) () -> begin + sleep(0.1) + return rand() + end + y = Dagger.with_options(;stream_max_evals=10) do + Dagger.@spawn scope=rand(scopes) identity(x) + end + end + @test fetch(y) === nothing + sleep(1) # Wait to ensure `x` task is still running + @test !istaskdone(x) + @test_throws_unwrap InterruptException fetch(x) + end + end + @testset "Multiple Tasks ($scope_str)" begin @test test_finishes("x -> A") do local x, A - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do x = Dagger.@spawn scope=rand(scopes) rand() A = Dagger.@spawn scope=rand(scopes) accumulator(x) end @@ -205,7 +240,7 @@ for idx in 1:5 @test test_finishes("(x, A)") do local x, A - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do x = Dagger.@spawn scope=rand(scopes) rand() A = Dagger.@spawn scope=rand(scopes) accumulator(1.0) end @@ -219,7 +254,7 @@ for idx in 1:5 @test test_finishes("x -> y -> A") do local x, y, A - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do x = Dagger.@spawn scope=rand(scopes) rand() y = Dagger.@spawn scope=rand(scopes) x+1 A = Dagger.@spawn scope=rand(scopes) accumulator(y) @@ -235,7 +270,7 @@ for idx in 1:5 @test test_finishes("x -> (y, A)") do local x, y, A - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do x = Dagger.@spawn scope=rand(scopes) rand() y = Dagger.@spawn scope=rand(scopes) x+1 A = Dagger.@spawn scope=rand(scopes) accumulator(x) @@ -251,7 +286,7 @@ for idx in 1:5 @test test_finishes("(x, y) -> A") do local x, y, A - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do x = Dagger.@spawn scope=rand(scopes) rand() y = Dagger.@spawn scope=rand(scopes) rand() A = Dagger.@spawn scope=rand(scopes) accumulator(x, y) @@ -267,7 +302,7 @@ for idx in 1:5 @test test_finishes("(x, y) -> z -> A") do local x, y, z, A - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do x = Dagger.@spawn scope=rand(scopes) rand() y = Dagger.@spawn scope=rand(scopes) rand() z = Dagger.@spawn scope=rand(scopes) x + y @@ -285,7 +320,7 @@ for idx in 1:5 @test test_finishes("x -> (y, z) -> A") do local x, y, z, A - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do x = Dagger.@spawn scope=rand(scopes) rand() y = Dagger.@spawn scope=rand(scopes) x + 1 z = Dagger.@spawn scope=rand(scopes) x + 2 @@ -303,7 +338,7 @@ for idx in 1:5 @test test_finishes("(x, y) -> z -> (A, B)") do local x, y, z, A, B - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do x = Dagger.@spawn scope=rand(scopes) rand() y = Dagger.@spawn scope=rand(scopes) rand() z = Dagger.@spawn scope=rand(scopes) x + y @@ -328,7 +363,7 @@ for idx in 1:5 for T in (Float64, Int32, BigFloat) @test test_finishes("Stream eltype $T") do local x, A - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do x = Dagger.@spawn scope=rand(scopes) rand(T) A = Dagger.@spawn scope=rand(scopes) accumulator(x) end @@ -344,13 +379,13 @@ for idx in 1:5 @testset "Max Evals ($scope_str)" begin @test test_finishes("max_evals=0"; max_evals=0) do - @test_throws ArgumentError Dagger.spawn_streaming() do + @test_throws ArgumentError Dagger.spawn_streaming(;teardown=false) do A = Dagger.@spawn scope=rand(scopes) accumulator() end end @test test_finishes("max_evals=1"; max_evals=1) do local A - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do A = Dagger.@spawn scope=rand(scopes) accumulator() end @test fetch(A) === nothing @@ -360,7 +395,7 @@ for idx in 1:5 end @test test_finishes("max_evals=100"; max_evals=100) do local A - Dagger.spawn_streaming() do + Dagger.spawn_streaming(;teardown=false) do A = Dagger.@spawn scope=rand(scopes) accumulator() end @test fetch(A) === nothing