Skip to content

Commit 2e140b7

Browse files
committed
Clean up internal compile interface
1 parent e9b3d73 commit 2e140b7

File tree

1 file changed

+14
-45
lines changed

1 file changed

+14
-45
lines changed

exla/lib/exla/defn.ex

Lines changed: 14 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,8 @@ defmodule EXLA.Defn do
3232

3333
@doc false
3434
def __stream__(key, input, acc, vars, fun, [args], options) do
35-
{debug?, options} = Keyword.pop(options, :debug, false)
3635
{run_options, compile_options} = Keyword.pop(options, :run_options, [])
37-
38-
{client_name, compile_options} =
39-
Keyword.pop_lazy(compile_options, :client, &EXLA.Client.default_name/0)
40-
41-
client = EXLA.Client.fetch!(client_name)
36+
debug? = Keyword.get(compile_options, :debug, false)
4237
compile_options = Keyword.put(compile_options, :lazy_transfers, :never)
4338

4439
input_length = length(Nx.Defn.Composite.flatten_list([input]))
@@ -50,21 +45,10 @@ defmodule EXLA.Defn do
5045
used_inputs = Enum.to_list(input_length..(input_length + acc_length - 1)//1)
5146

5247
comp_fun =
53-
&to_stream_computation(client, input_length, acc_length, &1, &2, &3, &4, compile_options)
48+
&to_stream_computation(input_length, acc_length, &1, &2, &3, &4, &5, compile_options)
5449

5550
{executable, {used_inputs, {output, acc_output}, outfeed, input_typespecs}} =
56-
compile(
57-
client,
58-
key,
59-
vars,
60-
fun,
61-
compile_options,
62-
used_buffers,
63-
used_inputs,
64-
_stream = true,
65-
debug?,
66-
comp_fun
67-
)
51+
compile(key, vars, fun, compile_options, used_buffers, used_inputs, true, comp_fun)
6852

6953
# Now discard the infeed from used inputs, similar to how it is done to buffers.
7054
# Note we discard all lazy transfers too, as they are not possible with streams.
@@ -136,13 +120,13 @@ defmodule EXLA.Defn do
136120
end
137121

138122
defp to_stream_computation(
139-
client,
140123
input_length,
141124
acc_length,
142125
%Function{} = builder,
143126
expr,
144127
used_typespecs,
145128
outfeed,
129+
client,
146130
options
147131
) do
148132
%{token: root_token, infeeds: []} = outfeed
@@ -237,18 +221,12 @@ defmodule EXLA.Defn do
237221

238222
@doc false
239223
def __compile__(key, vars, fun, options) do
240-
{debug?, options} = Keyword.pop(options, :debug, false)
241224
{run_options, compile_options} = Keyword.pop(options, :run_options, [])
242-
243-
{client_name, compile_options} =
244-
Keyword.pop_lazy(compile_options, :client, &EXLA.Client.default_name/0)
245-
246-
client = EXLA.Client.fetch!(client_name)
247-
248-
callback = &to_root_computation(&1, &2, &3, &4, Keyword.put(compile_options, :client, client))
225+
debug? = Keyword.get(compile_options, :debug, false)
226+
callback = &to_root_computation(&1, &2, &3, &4, &5, compile_options)
249227

250228
{executable, {used_inputs, outputs, outfeed, _input_typespecs?}} =
251-
compile(client, key, vars, fun, compile_options, 0, [], _stream = false, debug?, callback)
229+
compile(key, vars, fun, compile_options, 0, [], _stream = false, callback)
252230

253231
fn [args] ->
254232
{time, lock} =
@@ -270,14 +248,12 @@ defmodule EXLA.Defn do
270248
end
271249
end
272250

273-
defp to_root_computation(%Function{} = function, expr, used_typespecs, outfeed, options) do
251+
defp to_root_computation(%Function{} = function, expr, used_typespecs, outfeed, client, options) do
274252
params =
275253
Enum.zip_with(used_typespecs, Function.get_arguments(function), fn {pos, _typespec}, arg ->
276254
{pos, arg}
277255
end)
278256

279-
client = Keyword.fetch!(options, :client)
280-
281257
unless client do
282258
raise ArgumentError, "missing client"
283259
end
@@ -342,22 +318,15 @@ defmodule EXLA.Defn do
342318

343319
## Compile
344320

345-
defp compile(
346-
client,
347-
key,
348-
vars,
349-
fun,
350-
options,
351-
used_buffers,
352-
used_inputs,
353-
stream?,
354-
debug?,
355-
to_computation
356-
) do
321+
defp compile(key, vars, fun, options, used_buffers, used_inputs, stream?, to_computation) do
357322
{cache, options} = Keyword.pop(options, :cache, true)
358323
{hooks, options} = Keyword.pop(options, :hooks, %{})
324+
{debug?, options} = Keyword.pop(options, :debug, false)
359325
{lazy_transfers, options} = Keyword.pop(options, :lazy_transfers, :opt_in)
360326

327+
{client_name, options} = Keyword.pop_lazy(options, :client, &EXLA.Client.default_name/0)
328+
client = EXLA.Client.fetch!(client_name)
329+
361330
{args_key, reverse_args_identifiers} =
362331
Enum.map_reduce(vars, [], fn var, acc ->
363332
Nx.Defn.Composite.traverse(var, acc, fn
@@ -453,7 +422,7 @@ defmodule EXLA.Defn do
453422
end
454423

455424
expr = Nx.Defn.Composite.traverse(expr || fun.(vars), &Nx.devectorize/1)
456-
outfeed = to_computation.(builder, expr, inputs_and_typespecs, outfeed)
425+
outfeed = to_computation.(builder, expr, inputs_and_typespecs, outfeed, client)
457426

458427
{xla_time, executable} =
459428
:timer.tc(fn ->

0 commit comments

Comments
 (0)