Skip to content

Commit 3c5c389

Browse files
committed
streaming: Add DAG teardown option
1 parent cbe64f8 commit 3c5c389

File tree

4 files changed

+85
-25
lines changed

4 files changed

+85
-25
lines changed

docs/src/index.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,4 +427,5 @@ wait(t)
427427
The above example demonstrates a streaming region that generates random numbers
428428
continuously and writes each random number to a file. The streaming region is
429429
terminated when a random number less than 0.01 is generated, which is done by
430-
calling `Dagger.finish_stream()` (this exits the current streaming task).
430+
calling `Dagger.finish_stream()` (this terminates the current task, and will
431+
also terminate all streaming tasks launched by `spawn_streaming`).

docs/src/streaming.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,8 @@ end
7979
```
8080

8181
If you want to stop the streaming DAG and tear it all down, you can call
82-
`Dagger.cancel!.(all_vals)` and `Dagger.cancel!.(all_vals_written)` to
83-
terminate each streaming task. In the future, a more convenient way to tear
84-
down a full DAG will be added; for now, each task must be cancelled individually.
82+
`Dagger.cancel!(all_vals[1])` (or with any other task in the streaming DAG) to
83+
terminate all streaming tasks.
8584

8685
Alternatively, tasks can stop themselves from the inside with
8786
`finish_stream`, optionally returning a value that can be `fetch`'d. Let's

src/stream.jl

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,12 +426,37 @@ function initialize_streaming!(self_streams, spec, task)
426426
end
427427
end
428428

429-
function spawn_streaming(f::Base.Callable)
429+
"""
430+
Starts a streaming region, within which all tasks run continuously and
431+
concurrently. Any `DTask` argument that is itself a streaming task will be
432+
treated as a streaming input/output. The streaming region will automatically
433+
handle the buffering and synchronization of these tasks' values.
434+
435+
# Keyword Arguments
436+
- `teardown::Bool=true`: If `true`, the streaming region will automatically
437+
cancel all tasks if any task fails or is cancelled. Otherwise, a failing task
438+
will not cancel the other tasks, which will continue running.
439+
"""
440+
function spawn_streaming(f::Base.Callable; teardown::Bool=true)
430441
queue = StreamingTaskQueue()
431442
result = with_options(f; task_queue=queue)
432443
if length(queue.tasks) > 0
433444
finalize_streaming!(queue.tasks, queue.self_streams)
434445
enqueue!(queue.tasks)
446+
447+
if teardown
448+
# Start teardown monitor
449+
dtasks = map(last, queue.tasks)::Vector{DTask}
450+
Sch.errormonitor_tracked("streaming teardown", Threads.@spawn begin
451+
# Wait for any task to finish
452+
waitany(dtasks)
453+
454+
# Cancel all tasks
455+
for task in dtasks
456+
cancel!(task; graceful=false)
457+
end
458+
end)
459+
end
435460
end
436461
return result
437462
end

test/streaming.jl

Lines changed: 55 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ for idx in 1:5
8080
@testset "Single Task Control Flow ($scope_str)" begin
8181
@test !test_finishes("Single task running forever"; max_evals=1_000_000, ignore_timeout=true) do
8282
local x
83-
Dagger.spawn_streaming() do
83+
Dagger.spawn_streaming(;teardown=false) do
8484
x = Dagger.@spawn scope=rand(scopes) () -> begin
8585
y = rand()
8686
sleep(1)
@@ -92,15 +92,15 @@ for idx in 1:5
9292

9393
@test test_finishes("Single task without result") do
9494
local x
95-
Dagger.spawn_streaming() do
95+
Dagger.spawn_streaming(;teardown=false) do
9696
x = Dagger.@spawn scope=rand(scopes) rand()
9797
end
9898
@test fetch(x) === nothing
9999
end
100100

101101
@test test_finishes("Single task with result"; max_evals=1_000_000) do
102102
local x
103-
Dagger.spawn_streaming() do
103+
Dagger.spawn_streaming(;teardown=false) do
104104
x = Dagger.@spawn scope=rand(scopes) () -> begin
105105
x = rand()
106106
if x < 0.1
@@ -116,7 +116,7 @@ for idx in 1:5
116116
@testset "Non-Streaming Inputs ($scope_str)" begin
117117
@test test_finishes("() -> A") do
118118
local A
119-
Dagger.spawn_streaming() do
119+
Dagger.spawn_streaming(;teardown=false) do
120120
A = Dagger.@spawn scope=rand(scopes) accumulator()
121121
end
122122
@test fetch(A) === nothing
@@ -127,7 +127,7 @@ for idx in 1:5
127127
end
128128
@test test_finishes("42 -> A") do
129129
local A
130-
Dagger.spawn_streaming() do
130+
Dagger.spawn_streaming(;teardown=false) do
131131
A = Dagger.@spawn scope=rand(scopes) accumulator(42)
132132
end
133133
@test fetch(A) === nothing
@@ -138,7 +138,7 @@ for idx in 1:5
138138
end
139139
@test test_finishes("(42, 43) -> A") do
140140
local A
141-
Dagger.spawn_streaming() do
141+
Dagger.spawn_streaming(;teardown=false) do
142142
A = Dagger.@spawn scope=rand(scopes) accumulator(42, 43)
143143
end
144144
@test fetch(A) === nothing
@@ -152,7 +152,7 @@ for idx in 1:5
152152
@testset "Non-Streaming Outputs ($scope_str)" begin
153153
@test test_finishes("x -> A") do
154154
local x, A
155-
Dagger.spawn_streaming() do
155+
Dagger.spawn_streaming(;teardown=false) do
156156
x = Dagger.@spawn scope=rand(scopes) rand()
157157
end
158158
Dagger._without_options() do
@@ -168,7 +168,7 @@ for idx in 1:5
168168

169169
@test test_finishes("x -> (A, B)") do
170170
local x, A, B
171-
Dagger.spawn_streaming() do
171+
Dagger.spawn_streaming(;teardown=false) do
172172
x = Dagger.@spawn scope=rand(scopes) rand()
173173
end
174174
Dagger._without_options() do
@@ -188,10 +188,45 @@ for idx in 1:5
188188
end
189189
end
190190

191+
@testset "Teardown" begin
192+
@test test_finishes("teardown=true"; max_evals=1_000_000, ignore_timeout=true) do
193+
local x, y
194+
Dagger.spawn_streaming(;teardown=true) do
195+
x = Dagger.@spawn scope=rand(scopes) () -> begin
196+
sleep(0.1)
197+
return rand()
198+
end
199+
y = Dagger.with_options(;stream_max_evals=10) do
200+
Dagger.@spawn scope=rand(scopes) identity(x)
201+
end
202+
end
203+
@test fetch(y) === nothing
204+
sleep(1) # Wait for teardown
205+
@test istaskdone(x)
206+
fetch(x)
207+
end
208+
@test !test_finishes("teardown=false"; max_evals=1_000_000, ignore_timeout=true) do
209+
local x, y
210+
Dagger.spawn_streaming(;teardown=false) do
211+
x = Dagger.@spawn scope=rand(scopes) () -> begin
212+
sleep(0.1)
213+
return rand()
214+
end
215+
y = Dagger.with_options(;stream_max_evals=10) do
216+
Dagger.@spawn scope=rand(scopes) identity(x)
217+
end
218+
end
219+
@test fetch(y) === nothing
220+
sleep(1) # Wait to ensure `x` task is still running
221+
@test !istaskdone(x)
222+
@test_throws_unwrap InterruptException fetch(x)
223+
end
224+
end
225+
191226
@testset "Multiple Tasks ($scope_str)" begin
192227
@test test_finishes("x -> A") do
193228
local x, A
194-
Dagger.spawn_streaming() do
229+
Dagger.spawn_streaming(;teardown=false) do
195230
x = Dagger.@spawn scope=rand(scopes) rand()
196231
A = Dagger.@spawn scope=rand(scopes) accumulator(x)
197232
end
@@ -205,7 +240,7 @@ for idx in 1:5
205240

206241
@test test_finishes("(x, A)") do
207242
local x, A
208-
Dagger.spawn_streaming() do
243+
Dagger.spawn_streaming(;teardown=false) do
209244
x = Dagger.@spawn scope=rand(scopes) rand()
210245
A = Dagger.@spawn scope=rand(scopes) accumulator(1.0)
211246
end
@@ -219,7 +254,7 @@ for idx in 1:5
219254

220255
@test test_finishes("x -> y -> A") do
221256
local x, y, A
222-
Dagger.spawn_streaming() do
257+
Dagger.spawn_streaming(;teardown=false) do
223258
x = Dagger.@spawn scope=rand(scopes) rand()
224259
y = Dagger.@spawn scope=rand(scopes) x+1
225260
A = Dagger.@spawn scope=rand(scopes) accumulator(y)
@@ -235,7 +270,7 @@ for idx in 1:5
235270

236271
@test test_finishes("x -> (y, A)") do
237272
local x, y, A
238-
Dagger.spawn_streaming() do
273+
Dagger.spawn_streaming(;teardown=false) do
239274
x = Dagger.@spawn scope=rand(scopes) rand()
240275
y = Dagger.@spawn scope=rand(scopes) x+1
241276
A = Dagger.@spawn scope=rand(scopes) accumulator(x)
@@ -251,7 +286,7 @@ for idx in 1:5
251286

252287
@test test_finishes("(x, y) -> A") do
253288
local x, y, A
254-
Dagger.spawn_streaming() do
289+
Dagger.spawn_streaming(;teardown=false) do
255290
x = Dagger.@spawn scope=rand(scopes) rand()
256291
y = Dagger.@spawn scope=rand(scopes) rand()
257292
A = Dagger.@spawn scope=rand(scopes) accumulator(x, y)
@@ -267,7 +302,7 @@ for idx in 1:5
267302

268303
@test test_finishes("(x, y) -> z -> A") do
269304
local x, y, z, A
270-
Dagger.spawn_streaming() do
305+
Dagger.spawn_streaming(;teardown=false) do
271306
x = Dagger.@spawn scope=rand(scopes) rand()
272307
y = Dagger.@spawn scope=rand(scopes) rand()
273308
z = Dagger.@spawn scope=rand(scopes) x + y
@@ -285,7 +320,7 @@ for idx in 1:5
285320

286321
@test test_finishes("x -> (y, z) -> A") do
287322
local x, y, z, A
288-
Dagger.spawn_streaming() do
323+
Dagger.spawn_streaming(;teardown=false) do
289324
x = Dagger.@spawn scope=rand(scopes) rand()
290325
y = Dagger.@spawn scope=rand(scopes) x + 1
291326
z = Dagger.@spawn scope=rand(scopes) x + 2
@@ -303,7 +338,7 @@ for idx in 1:5
303338

304339
@test test_finishes("(x, y) -> z -> (A, B)") do
305340
local x, y, z, A, B
306-
Dagger.spawn_streaming() do
341+
Dagger.spawn_streaming(;teardown=false) do
307342
x = Dagger.@spawn scope=rand(scopes) rand()
308343
y = Dagger.@spawn scope=rand(scopes) rand()
309344
z = Dagger.@spawn scope=rand(scopes) x + y
@@ -328,7 +363,7 @@ for idx in 1:5
328363
for T in (Float64, Int32, BigFloat)
329364
@test test_finishes("Stream eltype $T") do
330365
local x, A
331-
Dagger.spawn_streaming() do
366+
Dagger.spawn_streaming(;teardown=false) do
332367
x = Dagger.@spawn scope=rand(scopes) rand(T)
333368
A = Dagger.@spawn scope=rand(scopes) accumulator(x)
334369
end
@@ -344,13 +379,13 @@ for idx in 1:5
344379

345380
@testset "Max Evals ($scope_str)" begin
346381
@test test_finishes("max_evals=0"; max_evals=0) do
347-
@test_throws ArgumentError Dagger.spawn_streaming() do
382+
@test_throws ArgumentError Dagger.spawn_streaming(;teardown=false) do
348383
A = Dagger.@spawn scope=rand(scopes) accumulator()
349384
end
350385
end
351386
@test test_finishes("max_evals=1"; max_evals=1) do
352387
local A
353-
Dagger.spawn_streaming() do
388+
Dagger.spawn_streaming(;teardown=false) do
354389
A = Dagger.@spawn scope=rand(scopes) accumulator()
355390
end
356391
@test fetch(A) === nothing
@@ -360,7 +395,7 @@ for idx in 1:5
360395
end
361396
@test test_finishes("max_evals=100"; max_evals=100) do
362397
local A
363-
Dagger.spawn_streaming() do
398+
Dagger.spawn_streaming(;teardown=false) do
364399
A = Dagger.@spawn scope=rand(scopes) accumulator()
365400
end
366401
@test fetch(A) === nothing

0 commit comments

Comments
 (0)