@@ -30,7 +30,6 @@ function Base.put!(store::StreamStore{T}, @nospecialize(value::T)) where T
3030 end
3131 @dagdebug nothing :stream_put " [$(uid ()) ] adding $value "
3232 for buffer in values (store. buffers)
33- # elem = StreamElement(value)
3433 push! (buffer, value)
3534 end
3635 notify (store. lock)
@@ -139,11 +138,31 @@ end
139138remove_waiters! (stream:: Stream , waiter:: Integer ) =
140139 remove_waiters! (stream:: Stream , Int[waiter])
141140
141+ struct NullStream end
142+ Base. put! (ns:: NullStream , x) = nothing
143+ Base. take! (ns:: NullStream ) = throw (ConcurrencyViolationError (" Cannot `take!` from a `NullStream`" ))
144+
145+ mutable struct StreamWrapper{S}
146+ stream:: S
147+ open:: Bool
148+ StreamWrapper (stream:: S ) where S = new {S} (stream, true )
149+ end
150+ Base. isopen (sw:: StreamWrapper ) = sw. open
151+ Base. close (sw:: StreamWrapper ) = (sw. open = false ;)
152+ function Base. put! (sw:: StreamWrapper , x)
153+ isopen (sw) || throw (InvalidStateException (" Stream is closed." , :closed ))
154+ put! (sw. stream, x)
155+ end
156+ function Base. take! (sw:: StreamWrapper )
157+ isopen (sw) || throw (InvalidStateException (" Stream is closed." , :closed ))
158+ take! (sw. stream)
159+ end
160+
142161struct StreamingTaskQueue <: AbstractTaskQueue
143162 tasks:: Vector{Pair{EagerTaskSpec,EagerThunk}}
144- self_streams:: Dict{UInt,Stream }
163+ self_streams:: Dict{UInt,Any }
145164 StreamingTaskQueue () = new (Pair{EagerTaskSpec,EagerThunk}[],
146- Dict {UInt,Stream } ())
165+ Dict {UInt,Any } ())
147166end
148167
149168function enqueue! (queue:: StreamingTaskQueue , spec:: Pair{EagerTaskSpec,EagerThunk} )
@@ -164,7 +183,20 @@ function initialize_streaming!(self_streams, spec, task)
164183 # We treat non-dominating error paths as unreachable
165184 T_old = filter (t-> t != = Union{}, T_old)
166185 T = task. metadata. return_type = ! isempty (T_old) ? Union{T_old... } : Any
167- stream = Stream {T} ()
186+ if haskey (spec. options, :stream )
187+ if spec. options. stream != = nothing
188+ # Use the user-provided stream
189+ @warn " Replace StreamWrapper with Stream" maxlog= 1
190+ stream = StreamWrapper (spec. options. stream)
191+ else
192+ # Use a non-readable, non-writing stream
193+ stream = StreamWrapper (NullStream ())
194+ end
195+ spec. options = NamedTuple (filter (opt -> opt[1 ] != :stream , Base. pairs (spec. options)))
196+ else
197+ # Create a built-in Stream object
198+ stream = Stream {T} ()
199+ end
168200 self_streams[task. uid] = stream
169201
170202 spec. f = StreamingFunction (spec. f, stream)
@@ -190,56 +222,33 @@ function spawn_streaming(f::Base.Callable)
190222end
191223
192224struct FinishedStreaming{T}
193- value:: T
225+ value:: Union{Some{T},Nothing}
194226end
195- finish_streaming (value= nothing ) = FinishedStreaming (value)
227+ finish_streaming (value) = FinishedStreaming {Any} (Some {T} (value))
228+ finish_streaming () = FinishedStreaming {Union{}} (nothing )
196229
197- struct StreamingFunction{F, T }
230+ struct StreamingFunction{F, S }
198231 f:: F
199- stream:: Stream{T}
232+ stream:: S
200233end
201234function (sf:: StreamingFunction )(args... ; kwargs... )
202235 @nospecialize sf args kwargs
203236 result = nothing
204- stream_args = Base. mapany (identity, args)
205- stream_kwargs = Base. mapany (identity, kwargs)
206237 thunk_id = tid ()
207- # FIXME : Fetch from worker 1
208- uid = lock (Sch. EAGER_ID_MAP) do id_map
209- for (uid, otid) in id_map
210- if thunk_id == otid
211- return uid
238+ @warn " Fetch from worker 1 more efficiently" maxlog= 1
239+ uid = remotecall_fetch (1 , thunk_id) do thunk_id
240+ lock (Sch. EAGER_ID_MAP) do id_map
241+ for (uid, otid) in id_map
242+ if thunk_id == otid
243+ return uid
244+ end
212245 end
213246 end
214247 end
215248 try
216- while true
217- # Get values from Stream args/kwargs
218- for (idx, arg) in enumerate (args)
219- if arg isa Stream
220- stream_args[idx] = take! (arg, uid)
221- end
222- end
223- for (idx, (pos, arg)) in enumerate (kwargs)
224- if arg isa Stream
225- stream_kwargs[idx] = pos => take! (arg, uid)
226- end
227- end
228-
229- # Run a single cycle of f
230- stream_result = sf. f (stream_args... ; stream_kwargs... )
231-
232- # Exit streaming on graceful request
233- if stream_result isa FinishedStreaming
234- return stream_result. value
235- end
236-
237- # Put the result into the output stream
238- put! (sf. stream, stream_result)
239-
240- # Allow other tasks to run
241- yield ()
242- end
249+ kwarg_names = map (name-> Val {name} (), map (first, (kwargs... ,)))
250+ kwarg_values = map (last, (kwargs... ,))
251+ return stream! (sf, uid, (args... ,), kwarg_names, kwarg_values)
243252 finally
244253 # Remove ourself as a waiter for upstream Streams
245254 streams = Set {Stream} ()
@@ -263,9 +272,55 @@ function (sf::StreamingFunction)(args...; kwargs...)
263272 close (sf. stream)
264273 end
265274end
275+ # N.B We specialize to minimize/eliminate allocations
276+ function stream! (sf:: StreamingFunction , uid,
277+ args:: Tuple , kwarg_names:: Tuple , kwarg_values:: Tuple )
278+ while true
279+ # @time begin
280+ # Get values from Stream args/kwargs
281+ stream_args = _stream_take_values! (args)
282+ stream_kwarg_values = _stream_take_values! (kwarg_values)
283+ stream_kwargs = _stream_namedtuple (kwarg_names, stream_kwarg_values)
284+
285+ # Run a single cycle of f
286+ stream_result = sf. f (stream_args... ; stream_kwargs... )
287+
288+ # Exit streaming on graceful request
289+ if stream_result isa FinishedStreaming
290+ @info " Terminating!"
291+ if stream_result. value != = nothing
292+ value = something (stream_result. value)
293+ put! (sf. stream, value)
294+ return value
295+ end
296+ return nothing
297+ end
266298
267- # FIXME : Ensure this gets cleaned up
268- const EAGER_THUNK_STREAMS = LockedObject (Dict {UInt,Stream} ())
299+ # Put the result into the output stream
300+ put! (sf. stream, stream_result)
301+ # end
302+ end
303+ end
304+ function _stream_take_values! (args)
305+ return ntuple (length (args)) do idx
306+ arg = args[idx]
307+ if arg isa Stream
308+ take! (arg, uid)
309+ elseif arg isa Union{AbstractChannel,RemoteChannel,StreamWrapper} # FIXME : Use trait query
310+ take! (arg)
311+ else
312+ arg
313+ end
314+ end
315+ end
316+ @inline @generated function _stream_namedtuple (kwarg_names:: Tuple ,
317+ stream_kwarg_values:: Tuple )
318+ name_ex = Expr (:tuple , map (name-> QuoteNode (name. parameters[1 ]), kwarg_names. parameters)... )
319+ NT = :(NamedTuple{$ name_ex,$ stream_kwarg_values})
320+ return :($ NT (stream_kwarg_values))
321+ end
322+
323+ const EAGER_THUNK_STREAMS = LockedObject (Dict {UInt,Any} ())
269324function task_to_stream (uid:: UInt )
270325 if myid () != 1
271326 return remotecall_fetch (task_to_stream, 1 , uid)
@@ -310,7 +365,9 @@ function finalize_streaming!(tasks::Vector{Pair{EagerTaskSpec,EagerThunk}}, self
310365 # Adjust waiter count of Streams with dependencies
311366 for (uid, waiters) in stream_waiter_changes
312367 stream = task_to_stream (uid)
313- add_waiters! (stream, waiters)
368+ if stream isa Stream # FIXME : Use trait query
369+ add_waiters! (stream, waiters)
370+ end
314371 end
315372end
316373
0 commit comments