@@ -40,9 +40,6 @@ function Base.put!(store::StreamStore{T,B}, value) where {T,B}
4040 end
4141 @dagdebug thunk_id :stream " adding $value ($(length (store. output_streams)) outputs)"
4242 for output_uid in keys (store. output_streams)
43- if ! haskey (store. output_buffers, output_uid)
44- initialize_output_stream! (store, output_uid)
45- end
4643 buffer = store. output_buffers[output_uid]
4744 while isfull (buffer)
4845 if ! isopen (store)
@@ -238,23 +235,32 @@ function initialize_input_stream!(our_store::StreamStore{OT,OB}, input_stream::S
238235 stream_pull_values! (input_fetcher, IT, our_store, input_stream, buffer)
239236 end
240237 catch err
241- if err isa InterruptException || (err isa InvalidStateException && ! isopen (buffer))
238+ unwrapped_err = Sch. unwrap_nested_exception (err)
239+ if unwrapped_err isa InterruptException || (unwrapped_err isa InvalidStateException && ! isopen (input_fetcher. chan))
242240 return
243241 else
244- rethrow (err )
242+ rethrow ()
245243 end
246244 finally
245+ # Close the buffer because there will be no more values put into
246+ # it. We don't close the entire store because there might be some
247+ # remaining elements in the buffer to process and send to downstream
248+ # tasks.
249+ close (buffer)
247250 @dagdebug STREAM_THUNK_ID[] :stream " input stream closed"
248251 end
249252 end )
250253 return StreamingValue (buffer)
251254end
252255initialize_input_stream! (our_store:: StreamStore , arg) = arg
253256function initialize_output_stream! (our_store:: StreamStore{T,B} , output_uid:: UInt ) where {T,B}
254- @assert islocked (our_store. lock)
255257 @dagdebug STREAM_THUNK_ID[] :stream " initializing output stream $output_uid "
256- buffer = initialize_stream_buffer (B, T, our_store. output_buffer_amount)
257- our_store. output_buffers[output_uid] = buffer
258+ local buffer
259+ @lock our_store. lock begin
260+ buffer = initialize_stream_buffer (B, T, our_store. output_buffer_amount)
261+ our_store. output_buffers[output_uid] = buffer
262+ end
263+
258264 our_uid = our_store. uid
259265 output_stream = our_store. output_streams[output_uid]
260266 output_fetcher = our_store. output_fetchers[output_uid]
@@ -279,6 +285,7 @@ function initialize_output_stream!(our_store::StreamStore{T,B}, output_uid::UInt
279285 rethrow (err)
280286 end
281287 finally
288+ close (output_fetcher. chan)
282289 @dagdebug thunk_id :stream " output stream closed"
283290 end
284291 end )
@@ -476,6 +483,17 @@ struct FinishStream{T,R}
476483 result:: R
477484end
478485
486+ """
487+ finish_stream(value=nothing; result=nothing)
488+
489+ Tell Dagger to stop executing the streaming function and all of its downstream
490+ [`DTask`](@ref)'s.
491+
492+ # Arguments
493+ - `value`: The final value to be returned by the streaming function. This will
494+ be passed to all downstream [`DTask`](@ref)'s.
495+ - `result`: The value that will be returned by `fetch()`'ing the [`DTask`](@ref).
496+ """
479497finish_stream (value:: T ; result:: R = nothing ) where {T,R} = FinishStream {T,R} (Some {T} (value), result)
480498
481499finish_stream (; result:: R = nothing ) where R = FinishStream {Union{},R} (nothing , result)
@@ -577,6 +595,16 @@ function stream!(sf::StreamingFunction, uid,
577595 f = move (thunk_processor (), sf. f)
578596 counter = 0
579597
598+ # Initialize output streams. We can't do this in add_waiters!() because the
599+ # output handlers depend on the DTaskTLS, so they have to be set up from
600+ # within the DTask.
601+ store = sf. stream. store
602+ for output_uid in keys (store. output_streams)
603+ if ! haskey (store. output_buffers, output_uid)
604+ initialize_output_stream! (store, output_uid)
605+ end
606+ end
607+
580608 while true
581609 # Yield to other (streaming) tasks
582610 yield ()
@@ -592,8 +620,21 @@ function stream!(sf::StreamingFunction, uid,
592620 end
593621
594622 # Get values from Stream args/kwargs
595- stream_args = _stream_take_values! (args)
596- stream_kwarg_values = _stream_take_values! (kwarg_values)
623+ local stream_args, stream_kwarg_values
624+ try
625+ stream_args = _stream_take_values! (args)
626+ stream_kwarg_values = _stream_take_values! (kwarg_values)
627+ catch ex
628+ if ex isa InvalidStateException
629+ # This means a buffer has been closed because an upstream task
630+ # finished.
631+ @dagdebug STREAM_THUNK_ID[] :stream " Upstream task finished, returning"
632+ return nothing
633+ else
634+ rethrow ()
635+ end
636+ end
637+
597638 stream_kwargs = _stream_namedtuple (kwarg_names, stream_kwarg_values)
598639
599640 if length (stream_args) > 0 || length (stream_kwarg_values) > 0
0 commit comments