Skip to content

Commit 78146e6

Browse files
committed
fixup! Add streaming API
1 parent b400189 commit 78146e6

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

src/stream.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
mutable struct StreamStore
1+
mutable struct StreamStore{T}
22
waiters::Vector{Int}
33
buffers::Dict{Int,Vector{Any}}
44
open::Bool
55
lock::Threads.Condition
6-
StreamStore() = new(zeros(Int, 0), Dict{Int,Vector{Any}}(), true, Threads.Condition())
6+
StreamStore{T}() where T =
7+
new{T}(zeros(Int, 0), Dict{Int,Vector{T}}(),
8+
true, Threads.Condition())
79
end
810
tid() = Dagger.Sch.sch_handle().thunk_id.id
911
function uid()
@@ -16,7 +18,7 @@ function uid()
1618
end
1719
end
1820
end
19-
function Base.put!(store::StreamStore, @nospecialize(value))
21+
function Base.put!(store::StreamStore{T}, @nospecialize(value::T)) where T
2022
@lock store.lock begin
2123
while length(store.waiters) == 0 && isopen(store)
2224
@dagdebug nothing :stream_put "[$(uid())] no waiters, not putting"
@@ -89,7 +91,7 @@ end
8991
mutable struct Stream{T} <: AbstractChannel{T}
9092
ref::Chunk
9193
function Stream{T}() where T
92-
store = tochunk(StreamStore())
94+
store = tochunk(StreamStore{T}())
9395
return new{T}(store)
9496
end
9597
end
@@ -157,13 +159,16 @@ end
157159
function initialize_streaming!(self_streams, spec, task)
158160
if !isa(spec.f, StreamingFunction)
159161
# Adapt called function for streaming and generate output Streams
160-
# FIXME: Infer type
161-
stream = Stream()
162+
T_old = Base.uniontypes(task.metadata.return_type)
163+
T_old = map(t->(t !== Union{} && t <: FinishedStreaming) ? only(t.parameters) : t, T_old)
164+
# We treat non-dominating error paths as unreachable
165+
T_old = filter(t->t !== Union{}, T_old)
166+
T = task.metadata.return_type = !isempty(T_old) ? Union{T_old...} : Any
167+
stream = Stream{T}()
162168
self_streams[task.uid] = stream
163169

164170
spec.f = StreamingFunction(spec.f, stream)
165-
# FIXME: Generalize to other processors
166-
spec.options = merge(spec.options, (;occupancy=Dict(ThreadProc=>0)))
171+
spec.options = merge(spec.options, (;occupancy=Dict(Any=>0)))
167172

168173
# Register Stream globally
169174
remotecall_wait(1, task.uid, stream) do uid, stream

0 commit comments

Comments
 (0)