Skip to content

Commit c82702b

Browse files
committed
Remove Nx.Defn.stream and Nx.Stream
1 parent 3a92566 commit c82702b

File tree

14 files changed

+9
-1106
lines changed

14 files changed

+9
-1106
lines changed

exla/lib/exla.ex

Lines changed: 0 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -297,65 +297,6 @@ defmodule EXLA do
297297
Nx.Defn.compile(function, args, Keyword.put(options, :compiler, EXLA))
298298
end
299299

300-
@doc """
301-
Starts streaming the given anonymous function with just-in-time
302-
compilation.
303-
304-
At least two arguments are expected:
305-
306-
1. The first argument is a tensor template of the data to
307-
be streamed in
308-
309-
2. The second argument is a tensor with the stream initial state
310-
311-
The streaming function must return a two-element tuple: the
312-
first element is the data to be sent and the second is the
313-
accumulator.
314-
315-
For each streamed chunk, you must call `Nx.Stream.send/2` and
316-
`Nx.Stream.recv/1`. You don't need to call `recv` immediately
317-
after `send`, but doing so can be a useful mechanism to provide
318-
backpressure. Once all chunks are sent, you must use `Nx.Stream.done/1`
319-
to receive the accumulated result. Let's see an example:
320-
321-
defmodule Streamed do
322-
import Nx.Defn
323-
324-
defn sum(tensor, acc) do
325-
{acc, tensor + acc}
326-
end
327-
end
328-
329-
Now let's invoke it:
330-
331-
stream = EXLA.stream(&Streamed.sum/2, [Nx.template({}, {:s, 32}), 0])
332-
333-
for i <- 1..5 do
334-
Nx.Stream.send(stream, i)
335-
IO.inspect {:chunk, Nx.Stream.recv(stream)}
336-
end
337-
338-
IO.inspect {:result, Nx.Stream.done(stream)}
339-
340-
It will print:
341-
342-
{:chunk, 0}
343-
{:chunk, 1}
344-
{:chunk, 2}
345-
{:chunk, 3}
346-
{:chunk, 4}
347-
{:result, 5}
348-
349-
**Note:** While any process can call `Nx.Stream.send/2`, EXLA
350-
expects the process that starts the streaming to be the one
351-
calling `Nx.Stream.recv/1` and `Nx.Stream.done/1`.
352-
353-
See `jit/2` for supported options.
354-
"""
355-
def stream(function, args, options \\ []) do
356-
Nx.Defn.stream(function, args, Keyword.put(options, :compiler, EXLA))
357-
end
358-
359300
@doc ~S'''
360301
Takes in a function, the argument templates and the compilation
361302
options and returns the textual representation of the MLIR module.
@@ -442,31 +383,6 @@ defmodule EXLA do
442383
{:cached?, bool} -> bool
443384
end
444385

445-
@doc """
446-
Checks if the JIT compilation of stream with
447-
args is cached.
448-
449-
Note that hooks are part of the cache, and
450-
therefore they must be included in the options.
451-
452-
## Examples
453-
454-
iex> left = Nx.tensor(1, type: {:u, 8})
455-
iex> right = Nx.tensor([1, 2, 3], type: {:u, 16})
456-
iex> fun = fn x, acc -> {acc, Nx.add(x, acc)} end
457-
iex> stream = EXLA.stream(fun, [left, right])
458-
iex> Nx.Stream.done(stream)
459-
iex> EXLA.stream_cached?(fun, [left, right])
460-
true
461-
iex> EXLA.stream_cached?(fun, [left, Nx.tensor([1, 2, 3, 4], type: {:u, 16})])
462-
false
463-
"""
464-
def stream_cached?(function, args, options \\ []) do
465-
stream(function, args, [{EXLA, cached_check()} | options])
466-
catch
467-
{:cached?, bool} -> bool
468-
end
469-
470386
defp cached_check do
471387
expr_cache_fun = fn key, _callback ->
472388
if res = EXLA.Defn.LockedCache.get(key) do
@@ -489,9 +405,6 @@ defmodule EXLA do
489405
@impl true
490406
defdelegate __jit__(key, vars, fun, args, opts), to: EXLA.Defn
491407

492-
@impl true
493-
defdelegate __stream__(key, input, acc, vars, fun, args, opts), to: EXLA.Defn
494-
495408
@impl true
496409
defdelegate __partitions_options__(opts), to: EXLA.Defn
497410

exla/lib/exla/defn.ex

Lines changed: 6 additions & 199 deletions
Original file line numberDiff line numberDiff line change
@@ -30,190 +30,6 @@ defmodule EXLA.Defn do
3030
{EXLA.Backend, [client: client_name, device_id: device_id]}
3131
end
3232

33-
@doc false
34-
def __stream__(key, input, acc, vars, fun, [args], options) do
35-
{run_options, compile_options} = Keyword.pop(options, :run_options, [])
36-
debug? = Keyword.get(compile_options, :debug, false)
37-
compile_options = Keyword.put(compile_options, :lazy_transfers, :never)
38-
39-
input_length = length(Nx.Defn.Composite.flatten_list([input]))
40-
acc_length = length(Nx.Defn.Composite.flatten_list([acc]))
41-
42-
# The input vars should not be converted to buffers as they come from infeed
43-
# Accs are always considered as used
44-
used_buffers = input_length
45-
used_inputs = Enum.to_list(input_length..(input_length + acc_length - 1)//1)
46-
47-
comp_fun =
48-
&to_stream_computation(input_length, acc_length, &1, &2, &3, &4, &5, compile_options)
49-
50-
{executable, {used_inputs, {output, acc_output}, outfeed, input_typespecs}} =
51-
compile(key, vars, fun, compile_options, used_buffers, used_inputs, true, comp_fun)
52-
53-
# Now discard the infeed from used inputs, similar to how it is done to buffers.
54-
# Note we discard all lazy transfers too, as they are not possible with streams.
55-
used_inputs = for {i, nil} <- used_inputs, i >= used_buffers, do: {i, nil}, into: %{}
56-
57-
# And capture the typespecs for the infeed.
58-
input_typespecs = Enum.take_while(input_typespecs, fn {i, _} -> i < input_length end)
59-
60-
# Execution of streams requires the coordination of
61-
# multiple processes which is outlined below.
62-
63-
# First, we get a lock on the executable, because we want
64-
# to avoid transfer to the device unless we know we are
65-
# ready to use the device.
66-
{time, lock} =
67-
:timer.tc(fn ->
68-
EXLA.Defn.Lock.lock(run_key(executable))
69-
end)
70-
71-
debug? && Logger.debug("EXLA device #{executable.device_id} lock in #{us_to_ms(time)}ms")
72-
73-
{time, streams} =
74-
:timer.tc(fn ->
75-
buffers =
76-
EXLA.Defn.Buffers.filter_by_indexes(args, used_inputs, fn arg, _ ->
77-
EXLA.Defn.Buffers.from_nx!(arg, executable)
78-
end)
79-
80-
# Now that we have transferred to device, we spawn a runner process
81-
# to execute the stream. We use a runner instead of a task to avoid
82-
# leaking messages in the inbox. We also don't use a supervisor
83-
# to keep them linked, which is safe because the runner is not used
84-
# outside the scope of the current process.
85-
#
86-
# Finally, note the runner cannot start immediately, we need to
87-
# setup the outfeed reader and register the on_unlock callback
88-
# that cancels the stream atomically. This is done inside
89-
# EXLA.Defn.Stream.run.
90-
{:ok, runner} =
91-
EXLA.Defn.Runner.start_link(lock, fn ->
92-
EXLA.Executable.run(executable, [buffers], run_options)
93-
end)
94-
95-
# The outfeed reader will redirect all outputs with flag 1 to the current
96-
# process. Once flag 0 is emitted, we know the stream is done.
97-
{output_typespecs, outfeed} = Outfeed.configure_stream_hook(outfeed, self(), lock)
98-
{:ok, outfeed_pid} = Outfeed.start_child(executable, outfeed, Process.group_leader())
99-
100-
stream =
101-
EXLA.Defn.Stream.run(
102-
executable,
103-
lock,
104-
runner,
105-
outfeed_pid,
106-
input,
107-
input_typespecs,
108-
output,
109-
output_typespecs,
110-
acc_output
111-
)
112-
113-
[stream]
114-
end)
115-
116-
debug? &&
117-
Logger.debug("EXLA stream start on device #{executable.device_id} in #{us_to_ms(time)}ms")
118-
119-
streams
120-
end
121-
122-
defp to_stream_computation(
123-
input_length,
124-
acc_length,
125-
%Function{} = builder,
126-
expr,
127-
used_typespecs,
128-
outfeed,
129-
client,
130-
options
131-
) do
132-
%{token: root_token, infeeds: []} = outfeed
133-
134-
{input_typespecs, used_typespecs} =
135-
Enum.split_while(used_typespecs, fn {i, _} -> i < input_length end)
136-
137-
# Drop all accumulator entries from used_typespecs as we will handle it separately.
138-
{acc_typespecs, used_typespecs} = Enum.split(used_typespecs, acc_length)
139-
140-
# The stream loop will be a three element tuple:
141-
#
142-
# The result of calling infeed.
143-
# The looping accumulator.
144-
# The looping constants.
145-
#
146-
# The input will be read as part of the infeed.
147-
acc_typespecs_l = Enum.map(acc_typespecs, &elem(&1, 1))
148-
acc_typespec = List.to_tuple(acc_typespecs_l)
149-
flag_typespec = Typespec.tensor({:pred, 8}, {})
150-
151-
args = EXLA.MLIR.Function.get_arguments(builder)
152-
{token, [flag]} = Value.infeed(root_token, [flag_typespec])
153-
init = [flag, token | args]
154-
155-
arg_typespecs = Enum.map(init, &Value.get_typespec/1)
156-
{pred_computation, [flag | _]} = Function.push_region(builder, arg_typespecs)
157-
typespec = Typespec.tensor({:pred, 8}, {})
158-
r0 = Value.constant(builder, [1], typespec)
159-
pred_op = Value.equal(flag, r0, typespec)
160-
Value.return(builder, [pred_op])
161-
Function.pop_region(builder)
162-
163-
{body_computation, [_flag, token | args]} = Function.push_region(builder, arg_typespecs)
164-
165-
{acc, constant} = Enum.split(args, acc_length)
166-
{input_indices, input_typespecs} = Enum.unzip(input_typespecs)
167-
{token, input} = Value.infeed(token, input_typespecs)
168-
input_params = Enum.zip(input_indices, input)
169-
170-
{%Outfeed{token: token} = outfeed, acc} =
171-
case expr do
172-
{output_expr, acc_expr} ->
173-
acc_params =
174-
Enum.map(acc_typespecs, fn {pos, _typespec} ->
175-
{pos, Enum.fetch!(acc, pos - input_length)}
176-
end)
177-
178-
constant_params =
179-
Enum.with_index(used_typespecs, fn {pos, _typespec}, index ->
180-
{pos, Enum.fetch!(constant, index)}
181-
end)
182-
183-
state = %{
184-
client: client,
185-
builder: builder,
186-
precision: Keyword.get(options, :precision, :default),
187-
params: Map.new(input_params ++ acc_params ++ constant_params),
188-
scope_ids: Tree.scope_ids(expr)
189-
}
190-
191-
outfeed = Outfeed.with_token(outfeed, token)
192-
{output, cache} = recur_flatten(output_expr, state, new_cache(outfeed))
193-
{acc, cache} = recur_flatten(acc_expr, state, cache)
194-
outfeed = cache |> get_outfeed() |> Outfeed.add_stream_hook(builder, output)
195-
{outfeed, acc}
196-
197-
_ ->
198-
raise "expected the function given to Nx.stream/3 to return a two-element tuple, got: " <>
199-
inspect(expr)
200-
end
201-
202-
# Emit the stream hook to signal loop output
203-
{token, [flag]} = Value.infeed(token, [flag_typespec])
204-
Value.return(flag.function, [flag, token | acc] ++ List.flatten(constant))
205-
Function.pop_region(builder)
206-
207-
[_flag, out_token | results] = Value.while(builder, pred_computation, body_computation, init)
208-
209-
acc = Enum.take(results, acc_length)
210-
output = wrap_tuple_result(acc, acc_typespec)
211-
212-
outfeed = outfeed |> Outfeed.with_token(out_token) |> Outfeed.close(builder)
213-
Value.func_return(builder, output)
214-
outfeed
215-
end
216-
21733
@doc false
21834
def __jit__(key, vars, fun, args_list, options) do
21935
__compile__(key, vars, fun, options).(args_list)
@@ -223,10 +39,10 @@ defmodule EXLA.Defn do
22339
def __compile__(key, vars, fun, options) do
22440
{run_options, compile_options} = Keyword.pop(options, :run_options, [])
22541
debug? = Keyword.get(compile_options, :debug, false)
226-
callback = &to_root_computation(&1, &2, &3, &4, &5, compile_options)
42+
callback = &to_computation(&1, &2, &3, &4, &5, compile_options)
22743

22844
{executable, {used_inputs, outputs, outfeed, _input_typespecs?}} =
229-
compile(key, vars, fun, compile_options, 0, [], _stream = false, callback)
45+
compile(key, vars, fun, compile_options, 0, [], callback)
23046

23147
if compile_options[:module_compilation] == :to_mlir do
23248
throw({:mlir_module, executable.ref, MapSet.new(Map.keys(used_inputs)), outputs})
@@ -252,7 +68,7 @@ defmodule EXLA.Defn do
25268
end
25369
end
25470

255-
defp to_root_computation(%Function{} = function, expr, used_typespecs, outfeed, client, options) do
71+
defp to_computation(%Function{} = function, expr, used_typespecs, outfeed, client, options) do
25672
params =
25773
Enum.zip_with(used_typespecs, Function.get_arguments(function), fn {pos, _typespec}, arg ->
25874
{pos, arg}
@@ -322,7 +138,7 @@ defmodule EXLA.Defn do
322138

323139
## Compile
324140

325-
defp compile(key, vars, fun, options, used_buffers, used_inputs, stream?, to_computation) do
141+
defp compile(key, vars, fun, options, used_buffers, used_inputs, to_computation) do
326142
{cache, options} = Keyword.pop(options, :cache, true)
327143
{hooks, options} = Keyword.pop(options, :hooks, %{})
328144
{debug?, options} = Keyword.pop(options, :debug, false)
@@ -361,7 +177,7 @@ defmodule EXLA.Defn do
361177

362178
{eval_time, {expr, {ref, outputs, {used_inputs, defined_hooks}}}} =
363179
:timer.tc(fn ->
364-
expr_cache_fun.({key, stream?, args_key, lazy_transfers}, fn ->
180+
expr_cache_fun.({key, args_key, lazy_transfers}, fn ->
365181
expr = fun.(vars)
366182
inputs_and_hooks = Outfeed.used_inputs_and_hooks(expr, used_inputs, lazy_transfers)
367183
{expr, {make_ref(), Nx.to_template(expr), inputs_and_hooks}}
@@ -395,15 +211,6 @@ defmodule EXLA.Defn do
395211
comp_typespecs =
396212
for {i, typespec} <- inputs_and_typespecs, i >= used_buffers, do: typespec
397213

398-
outputs =
399-
if stream? do
400-
# The computation returns the final accumulator value
401-
{_chunk_result, acc} = outputs
402-
acc
403-
else
404-
outputs
405-
end
406-
407214
out_typespecs =
408215
[outputs]
409216
|> Nx.Defn.Composite.flatten_list()
@@ -417,7 +224,7 @@ defmodule EXLA.Defn do
417224
# Only create the token when we know it will actually be
418225
# used, that is: lazy transfers or hooks
419226
outfeed =
420-
if stream? or reverse_infeeds != [] or hooks != %{} or defined_hooks != %{} do
227+
if reverse_infeeds != [] or hooks != %{} or defined_hooks != %{} do
421228
outfeed
422229
|> Outfeed.with_token(Value.create_token(builder))
423230
|> Outfeed.add_infeeds(builder, reverse_infeeds)

exla/lib/exla/defn/outfeed.ex

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -153,23 +153,6 @@ defmodule EXLA.Defn.Outfeed do
153153
end
154154
end
155155

156-
@doc """
157-
Adds a stream hook.
158-
159-
Used by streams. Only one is allowed. Requires configuration.
160-
"""
161-
def add_stream_hook(%Outfeed{} = outfeed, builder, tuple) do
162-
{outfeed, flag, typespecs} = outfeed_flat_tuple(outfeed, builder, tuple)
163-
# We don't know the pid+ref pair for the stream, so we store it
164-
# under a special key called :stream and revert to the flag once configured
165-
put_in(outfeed.compiled_hooks[:stream], {flag, typespecs})
166-
end
167-
168-
def configure_stream_hook(%Outfeed{} = outfeed, pid, ref) when is_pid(pid) do
169-
{{flag, typespecs}, outfeed} = pop_in(outfeed.compiled_hooks[:stream])
170-
{typespecs, put_in(outfeed.compiled_hooks[flag], {:stream, typespecs, pid, ref})}
171-
end
172-
173156
@doc """
174157
Closes the outfeed at the end of a pipeline.
175158
@@ -254,10 +237,6 @@ defmodule EXLA.Defn.Outfeed do
254237
EXLA.Client.to_infeed(client, device_id, [{data, data_typespec}])
255238
loop(client, device_id, ref, typespec, hooks, compiled_hooks, infeeds)
256239

257-
{:stream, typespecs, recv_pid, recv_ref} ->
258-
:ok = EXLA.Client.from_outfeed(client, device_id, typespecs, recv_pid, recv_ref)
259-
loop(client, device_id, ref, typespec, hooks, compiled_hooks, infeeds)
260-
261240
{:function, typespecs, name, template} ->
262241
fun = Map.fetch!(hooks, name)
263242
length = length(typespecs)

0 commit comments

Comments
 (0)