Skip to content

Commit d43a262

Browse files
committed
fixup! fixup! Add streaming API
1 parent 78146e6 commit d43a262

File tree

2 files changed

+104
-47
lines changed

2 files changed

+104
-47
lines changed

src/eager_thunk.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,13 @@ function Base.fetch(t::EagerThunk; raw=false)
6868
throw(ConcurrencyViolationError("Cannot `fetch` an unlaunched `EagerThunk`"))
6969
end
7070
stream = task_to_stream(t.uid)
71-
if stream !== nothing
71+
if stream isa Stream
7272
add_waiters!(stream, [0])
7373
end
7474
try
7575
return fetch(t.future; raw)
7676
finally
77-
if stream !== nothing
77+
if stream isa Stream
7878
remove_waiters!(stream, [0])
7979
end
8080
end

src/stream.jl

Lines changed: 102 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ function Base.put!(store::StreamStore{T}, @nospecialize(value::T)) where T
3030
end
3131
@dagdebug nothing :stream_put "[$(uid())] adding $value"
3232
for buffer in values(store.buffers)
33-
#elem = StreamElement(value)
3433
push!(buffer, value)
3534
end
3635
notify(store.lock)
@@ -139,11 +138,31 @@ end
139138
"""
    remove_waiters!(stream::Stream, waiter::Integer)

Single-waiter convenience method: wraps `waiter` in a one-element `Int` vector
and forwards to the vector method of `remove_waiters!`.
"""
function remove_waiters!(stream::Stream, waiter::Integer)
    single_waiter = Int[waiter]
    return remove_waiters!(stream, single_waiter)
end
141140

141+
"""
    NullStream

Stand-in stream that holds no data: `put!` accepts and discards any value, and
`take!` always throws a `ConcurrencyViolationError`.
"""
struct NullStream end

# Writes are accepted and silently dropped.
function Base.put!(ns::NullStream, x)
    return nothing
end

# Reading is never valid, since nothing is ever stored.
function Base.take!(ns::NullStream)
    throw(ConcurrencyViolationError("Cannot `take!` from a `NullStream`"))
end
144+
145+
"""
    StreamWrapper{S}

Wraps an arbitrary stream-like object `stream::S` and tracks an `open` flag.
A wrapper starts open; once `close`d, `put!` and `take!` throw
`InvalidStateException` instead of forwarding to the wrapped object.
"""
mutable struct StreamWrapper{S}
    stream::S
    open::Bool
    function StreamWrapper(stream::S) where S
        return new{S}(stream, true)
    end
end

# Report whether the wrapper is still accepting `put!`/`take!`.
function Base.isopen(sw::StreamWrapper)
    return sw.open
end

# Mark the wrapper closed; the wrapped object itself is left untouched.
function Base.close(sw::StreamWrapper)
    return sw.open = false
end

# Forward `put!` to the wrapped object while the wrapper is open.
function Base.put!(sw::StreamWrapper, x)
    if !isopen(sw)
        throw(InvalidStateException("Stream is closed.", :closed))
    end
    return put!(sw.stream, x)
end

# Forward `take!` to the wrapped object while the wrapper is open.
function Base.take!(sw::StreamWrapper)
    if !isopen(sw)
        throw(InvalidStateException("Stream is closed.", :closed))
    end
    return take!(sw.stream)
end
160+
142161
"""
    StreamingTaskQueue()

Task queue that collects `EagerTaskSpec => EagerThunk` pairs together with a
`UInt`-keyed table of stream objects. The table's values are typed `Any`, so
entries are not restricted to `Stream`.
"""
struct StreamingTaskQueue <: AbstractTaskQueue
    tasks::Vector{Pair{EagerTaskSpec,EagerThunk}}
    self_streams::Dict{UInt,Any}
    function StreamingTaskQueue()
        queued = Pair{EagerTaskSpec,EagerThunk}[]
        streams = Dict{UInt,Any}()
        return new(queued, streams)
    end
end
148167

149168
function enqueue!(queue::StreamingTaskQueue, spec::Pair{EagerTaskSpec,EagerThunk})
@@ -164,7 +183,20 @@ function initialize_streaming!(self_streams, spec, task)
164183
# We treat non-dominating error paths as unreachable
165184
T_old = filter(t->t !== Union{}, T_old)
166185
T = task.metadata.return_type = !isempty(T_old) ? Union{T_old...} : Any
167-
stream = Stream{T}()
186+
if haskey(spec.options, :stream)
187+
if spec.options.stream !== nothing
188+
# Use the user-provided stream
189+
@warn "Replace StreamWrapper with Stream" maxlog=1
190+
stream = StreamWrapper(spec.options.stream)
191+
else
192+
# Use a non-readable, non-writing stream
193+
stream = StreamWrapper(NullStream())
194+
end
195+
spec.options = NamedTuple(filter(opt -> opt[1] != :stream, Base.pairs(spec.options)))
196+
else
197+
# Create a built-in Stream object
198+
stream = Stream{T}()
199+
end
168200
self_streams[task.uid] = stream
169201

170202
spec.f = StreamingFunction(spec.f, stream)
@@ -190,56 +222,33 @@ function spawn_streaming(f::Base.Callable)
190222
end
191223

192224
"""
    FinishedStreaming{T}

Sentinel returned by a streaming function to request a graceful end of
streaming.

`value` is `Some(x)` when the function finishes with a final value `x` (which
may itself be `nothing`), and `nothing` when it finishes without one.
"""
struct FinishedStreaming{T}
    value::Union{Some{T},Nothing}
end

"""
    finish_streaming(value)
    finish_streaming()

Build the `FinishedStreaming` sentinel, with or without a final `value`.
"""
# `Some{Any}` matches the `FinishedStreaming{Any}` field type; the previous
# `Some{T}` referenced a type variable that was never bound.
finish_streaming(value) = FinishedStreaming{Any}(Some{Any}(value))
finish_streaming() = FinishedStreaming{Union{}}(nothing)
196229

197-
"""
    StreamingFunction{F, S}

Pairs a user function `f` with the `stream` object its results are written to.
The stream type `S` is left unconstrained, so any stream-like object can be
stored.
"""
struct StreamingFunction{F, S}
    f::F        # function invoked on each streaming cycle
    stream::S   # destination for the values `f` produces
end
201234
function (sf::StreamingFunction)(args...; kwargs...)
202235
@nospecialize sf args kwargs
203236
result = nothing
204-
stream_args = Base.mapany(identity, args)
205-
stream_kwargs = Base.mapany(identity, kwargs)
206237
thunk_id = tid()
207-
# FIXME: Fetch from worker 1
208-
uid = lock(Sch.EAGER_ID_MAP) do id_map
209-
for (uid, otid) in id_map
210-
if thunk_id == otid
211-
return uid
238+
@warn "Fetch from worker 1 more efficiently" maxlog=1
239+
uid = remotecall_fetch(1, thunk_id) do thunk_id
240+
lock(Sch.EAGER_ID_MAP) do id_map
241+
for (uid, otid) in id_map
242+
if thunk_id == otid
243+
return uid
244+
end
212245
end
213246
end
214247
end
215248
try
216-
while true
217-
# Get values from Stream args/kwargs
218-
for (idx, arg) in enumerate(args)
219-
if arg isa Stream
220-
stream_args[idx] = take!(arg, uid)
221-
end
222-
end
223-
for (idx, (pos, arg)) in enumerate(kwargs)
224-
if arg isa Stream
225-
stream_kwargs[idx] = pos => take!(arg, uid)
226-
end
227-
end
228-
229-
# Run a single cycle of f
230-
stream_result = sf.f(stream_args...; stream_kwargs...)
231-
232-
# Exit streaming on graceful request
233-
if stream_result isa FinishedStreaming
234-
return stream_result.value
235-
end
236-
237-
# Put the result into the output stream
238-
put!(sf.stream, stream_result)
239-
240-
# Allow other tasks to run
241-
yield()
242-
end
249+
kwarg_names = map(name->Val{name}(), map(first, (kwargs...,)))
250+
kwarg_values = map(last, (kwargs...,))
251+
return stream!(sf, uid, (args...,), kwarg_names, kwarg_values)
243252
finally
244253
# Remove ourself as a waiter for upstream Streams
245254
streams = Set{Stream}()
@@ -263,9 +272,55 @@ function (sf::StreamingFunction)(args...; kwargs...)
263272
close(sf.stream)
264273
end
265274
end
275+
# N.B. We specialize to minimize/eliminate allocations
"""
    stream!(sf::StreamingFunction, uid, args::Tuple, kwarg_names::Tuple, kwarg_values::Tuple)

Run the streaming loop for `sf`. On each cycle, take one value from every
stream-like entry of `args` and `kwarg_values`, call `sf.f` with the results,
and `put!` the return value into `sf.stream`.

The loop ends when `sf.f` returns a `FinishedStreaming`. If that sentinel
carries a value, the value is put into `sf.stream` and returned; otherwise
`nothing` is returned.

# Arguments
- `uid`: identifier passed to `take!(::Stream, uid)` for `Stream` inputs.
- `args`: positional arguments as a tuple.
- `kwarg_names`: tuple of `Val{name}()` instances, one per keyword argument.
- `kwarg_values`: keyword argument values, in the same order as `kwarg_names`.
"""
function stream!(sf::StreamingFunction, uid,
                 args::Tuple, kwarg_names::Tuple, kwarg_values::Tuple)
    while true
        # Get values from Stream args/kwargs
        stream_args = _stream_take_values!(args, uid)
        stream_kwarg_values = _stream_take_values!(kwarg_values, uid)
        stream_kwargs = _stream_namedtuple(kwarg_names, stream_kwarg_values)

        # Run a single cycle of f
        stream_result = sf.f(stream_args...; stream_kwargs...)

        # Exit streaming on graceful request
        if stream_result isa FinishedStreaming
            @info "Terminating!"
            if stream_result.value !== nothing
                value = something(stream_result.value)
                put!(sf.stream, value)
                return value
            end
            return nothing
        end

        # Put the result into the output stream
        put!(sf.stream, stream_result)
    end
end

# Build a tuple of per-cycle values: `Stream` inputs are read with
# `take!(arg, uid)`, other channel-like inputs with `take!(arg)`, and any other
# argument is passed through unchanged. `uid` must be supplied by the caller;
# it was previously read from an unbound name inside this function.
function _stream_take_values!(args, uid)
    return ntuple(length(args)) do idx
        arg = args[idx]
        if arg isa Stream
            take!(arg, uid)
        elseif arg isa Union{AbstractChannel,RemoteChannel,StreamWrapper} # FIXME: Use trait query
            take!(arg)
        else
            arg
        end
    end
end
316+
# Assemble a `NamedTuple` from keyword names encoded as `Val{name}` instances
# and a matching tuple of values. The names are read from the `Val` type
# parameters at generation time, so the concrete `NamedTuple` type is known to
# the compiler.
@inline @generated function _stream_namedtuple(kwarg_names::Tuple,
                                               stream_kwarg_values::Tuple)
    # Inside the generator, both arguments are bound to their *types*.
    quoted_names = Any[QuoteNode(V.parameters[1]) for V in kwarg_names.parameters]
    names_tuple = Expr(:tuple, quoted_names...)
    nt_type = Expr(:curly, :NamedTuple, names_tuple, stream_kwarg_values)
    return Expr(:call, nt_type, :stream_kwarg_values)
end
322+
323+
# Table of stream objects keyed by `UInt`, wrapped in a `LockedObject`.
# Values are typed `Any` rather than `Stream`.
# NOTE(review): keys are presumably task uids (see `task_to_stream(uid::UInt)`) — confirm.
const EAGER_THUNK_STREAMS = LockedObject(Dict{UInt,Any}())
269324
function task_to_stream(uid::UInt)
270325
if myid() != 1
271326
return remotecall_fetch(task_to_stream, 1, uid)
@@ -310,7 +365,9 @@ function finalize_streaming!(tasks::Vector{Pair{EagerTaskSpec,EagerThunk}}, self
310365
# Adjust waiter count of Streams with dependencies
311366
for (uid, waiters) in stream_waiter_changes
312367
stream = task_to_stream(uid)
313-
add_waiters!(stream, waiters)
368+
if stream isa Stream # FIXME: Use trait query
369+
add_waiters!(stream, waiters)
370+
end
314371
end
315372
end
316373

0 commit comments

Comments
 (0)