1- mutable struct StreamStore
1+ mutable struct StreamStore{T}
22 waiters:: Vector{Int}
33 buffers:: Dict{Int,Vector{Any}}
44 open:: Bool
55 lock:: Threads.Condition
6- StreamStore () = new (zeros (Int, 0 ), Dict {Int,Vector{Any}} (), true , Threads. Condition ())
6+ StreamStore {T} () where T =
7+ new {T} (zeros (Int, 0 ), Dict {Int,Vector{T}} (),
8+ true , Threads. Condition ())
79end
810tid () = Dagger. Sch. sch_handle (). thunk_id. id
911function uid ()
@@ -16,7 +18,7 @@ function uid()
1618 end
1719 end
1820end
19- function Base. put! (store:: StreamStore , @nospecialize (value))
21+ function Base. put! (store:: StreamStore{T} , @nospecialize (value:: T )) where T
2022 @lock store. lock begin
2123 while length (store. waiters) == 0 && isopen (store)
2224 @dagdebug nothing :stream_put " [$(uid ()) ] no waiters, not putting"
8991mutable struct Stream{T} <: AbstractChannel{T}
9092 ref:: Chunk
9193 function Stream {T} () where T
92- store = tochunk (StreamStore ())
94+ store = tochunk (StreamStore {T} ())
9395 return new {T} (store)
9496 end
9597end
@@ -157,13 +159,16 @@ end
157159function initialize_streaming! (self_streams, spec, task)
158160 if ! isa (spec. f, StreamingFunction)
159161 # Adapt called function for streaming and generate output Streams
160- # FIXME : Infer type
161- stream = Stream ()
162+ T_old = Base. uniontypes (task. metadata. return_type)
163+ T_old = map (t-> (t != = Union{} && t <: FinishedStreaming ) ? only (t. parameters) : t, T_old)
164+ # We treat non-dominating error paths as unreachable
165+ T_old = filter (t-> t != = Union{}, T_old)
166+ T = task. metadata. return_type = ! isempty (T_old) ? Union{T_old... } : Any
167+ stream = Stream {T} ()
162168 self_streams[task. uid] = stream
163169
164170 spec. f = StreamingFunction (spec. f, stream)
165- # FIXME : Generalize to other processors
166- spec. options = merge (spec. options, (;occupancy= Dict (ThreadProc=> 0 )))
171+ spec. options = merge (spec. options, (;occupancy= Dict (Any=> 0 )))
167172
168173 # Register Stream globally
169174 remotecall_wait (1 , task. uid, stream) do uid, stream
0 commit comments