Skip to content

Commit 9a9a568

Browse files
committed
Simplify return of compile
1 parent 44d6410 commit 9a9a568

File tree

4 files changed

+48
-62
lines changed

4 files changed

+48
-62
lines changed

exla/lib/exla/defn.ex

Lines changed: 25 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ defmodule EXLA.Defn do
3232

3333
@doc false
3434
def __stream__(key, input, acc, vars, fun, [args], options) do
35+
{debug?, options} = Keyword.pop(options, :debug, false)
3536
{run_options, compile_options} = Keyword.pop(options, :run_options, [])
3637

3738
{client_name, compile_options} =
@@ -51,24 +52,26 @@ defmodule EXLA.Defn do
5152
comp_fun =
5253
&to_stream_computation(client, input_length, acc_length, &1, &2, &3, &4, compile_options)
5354

54-
{executable, used_inputs, {output, acc_output}, outfeed, extra, debug?} =
55+
{executable, used_inputs, {output, acc_output}, outfeed, input_typespecs} =
5556
compile(
5657
client,
57-
{:stream, key},
58+
key,
5859
vars,
5960
fun,
6061
compile_options,
6162
used_buffers,
6263
used_inputs,
6364
_stream = true,
65+
debug?,
6466
comp_fun
6567
)
6668

67-
{input_typespecs, input_indexes} = extra
69+
# Now discard the infeed from used inputs, similar to how it is done to buffers.
70+
# Note we discard all lazy transfers too, as they are not possible with streams.
71+
used_inputs = for {i, nil} <- used_inputs, i >= used_buffers, do: {i, nil}, into: %{}
6872

69-
# Also discard the stream inputs from used inputs, similar to how it is done to buffers
70-
# Note we discard all lazy transfers too, as they are not possible with streams
71-
used_inputs = Enum.sort(for {i, nil} <- used_inputs, i >= used_buffers, do: i)
73+
# And capture the typespecs for the infeed.
74+
input_typespecs = Enum.take_while(input_typespecs, fn {i, _} -> i < input_length end)
7275

7376
# Execution of streams requires the coordination of
7477
# multiple processes which is outlined below.
@@ -120,7 +123,6 @@ defmodule EXLA.Defn do
120123
outfeed_pid,
121124
input,
122125
input_typespecs,
123-
input_indexes,
124126
output,
125127
output_typespecs,
126128
acc_output
@@ -151,9 +153,6 @@ defmodule EXLA.Defn do
151153
{input_typespecs, used_typespecs} =
152154
Enum.split_while(used_typespecs, fn {i, _} -> i < input_length end)
153155

154-
# Get all input indexes and shape
155-
input_indexes = Enum.map(input_typespecs, &elem(&1, 0))
156-
157156
# Drop all accumulator entries from used_typespecs as we will handle it separately.
158157
{acc_typespecs, used_typespecs} = Enum.split(used_typespecs, acc_length)
159158

@@ -166,13 +165,10 @@ defmodule EXLA.Defn do
166165
# The input will be read as part of the infeed.
167166
acc_typespecs_l = Enum.map(acc_typespecs, &elem(&1, 1))
168167
acc_typespec = List.to_tuple(acc_typespecs_l)
169-
170168
flag_typespec = Typespec.tensor({:pred, 8}, {})
171169

172170
args = EXLA.MLIR.Function.get_arguments(builder)
173-
174171
{token, [flag]} = Value.infeed(root_token, [flag_typespec])
175-
176172
init = [flag, token | args]
177173

178174
arg_typespecs = Enum.map(init, &Value.get_typespec/1)
@@ -186,11 +182,9 @@ defmodule EXLA.Defn do
186182
{body_computation, [_flag, token | args]} = Function.push_region(builder, arg_typespecs)
187183

188184
{acc, constant} = Enum.split(args, acc_length)
189-
190-
{indices, input_typespecs} = Enum.unzip(input_typespecs)
185+
{input_indices, input_typespecs} = Enum.unzip(input_typespecs)
191186
{token, input} = Value.infeed(token, input_typespecs)
192-
193-
input_params = Enum.zip(indices, input)
187+
input_params = Enum.zip(input_indices, input)
194188

195189
{%Outfeed{token: token} = outfeed, acc} =
196190
case expr do
@@ -226,9 +220,7 @@ defmodule EXLA.Defn do
226220

227221
# Emit the stream hook to signal loop output
228222
{token, [flag]} = Value.infeed(token, [flag_typespec])
229-
230223
Value.return(flag.function, [flag, token | acc] ++ List.flatten(constant))
231-
232224
Function.pop_region(builder)
233225

234226
[_flag, out_token | results] = Value.while(builder, pred_computation, body_computation, init)
@@ -238,8 +230,7 @@ defmodule EXLA.Defn do
238230

239231
outfeed = outfeed |> Outfeed.with_token(out_token) |> Outfeed.close(builder)
240232
Value.func_return(builder, output)
241-
242-
{{input_typespecs, input_indexes}, outfeed}
233+
outfeed
243234
end
244235

245236
@doc false
@@ -249,6 +240,7 @@ defmodule EXLA.Defn do
249240

250241
@doc false
251242
def __compile__(key, vars, fun, options) do
243+
{debug?, options} = Keyword.pop(options, :debug, false)
252244
{run_options, compile_options} = Keyword.pop(options, :run_options, [])
253245

254246
{client_name, compile_options} =
@@ -258,8 +250,8 @@ defmodule EXLA.Defn do
258250

259251
callback = &to_root_computation(&1, &2, &3, &4, Keyword.put(compile_options, :client, client))
260252

261-
{executable, used_inputs, outputs, outfeed, :ok, debug?} =
262-
compile(client, key, vars, fun, compile_options, 0, [], _stream = false, callback)
253+
{executable, used_inputs, outputs, outfeed, _input_typespecs?} =
254+
compile(client, key, vars, fun, compile_options, 0, [], _stream = false, debug?, callback)
263255

264256
fn [args] ->
265257
{time, lock} =
@@ -306,10 +298,8 @@ defmodule EXLA.Defn do
306298

307299
{res, cache} = recur_flatten(expr, state, new_cache(outfeed))
308300
outfeed = cache |> get_outfeed() |> Outfeed.close(function)
309-
310301
Value.func_return(function, res)
311-
312-
{:ok, outfeed}
302+
outfeed
313303
end
314304

315305
defp maybe_outfeed(lock, executable, args, used_inputs, outputs, outfeed, run_options)
@@ -367,6 +357,7 @@ defmodule EXLA.Defn do
367357
used_buffers,
368358
used_inputs,
369359
stream?,
360+
debug?,
370361
to_computation
371362
) do
372363
{{expr_cache_fun, comp_cache_fun}, options} =
@@ -379,8 +370,6 @@ defmodule EXLA.Defn do
379370
{{cache_fun, cache_fun}, options}
380371
end
381372

382-
{debug?, options} = Keyword.pop(options, :debug, false)
383-
384373
{args_key, reverse_args_identifiers} =
385374
Enum.map_reduce(vars, [], fn var, acc ->
386375
Nx.Defn.Composite.traverse(var, acc, fn
@@ -396,7 +385,7 @@ defmodule EXLA.Defn do
396385

397386
{eval_time, {expr, {ref, outputs, {used_inputs, defined_hooks}}}} =
398387
:timer.tc(fn ->
399-
expr_cache_fun.({key, args_key, lazy_transfers}, fn ->
388+
expr_cache_fun.({key, stream?, args_key, lazy_transfers}, fn ->
400389
expr = fun.(vars)
401390
inputs_and_hooks = Outfeed.used_inputs_and_hooks(expr, used_inputs, lazy_transfers)
402391
{expr, {make_ref(), Nx.to_template(expr), inputs_and_hooks}}
@@ -412,12 +401,10 @@ defmodule EXLA.Defn do
412401
end
413402

414403
{hooks, options} = Keyword.pop(options, :hooks, %{})
415-
416404
outfeed = Outfeed.new(hooks, defined_hooks)
417-
418405
comp_key = {ref, client.name, outfeed.used_hooks, lazy_transfers, options}
419406

420-
{comp_time, {evaled, {xla_time, executable, extra, outfeed}}} =
407+
{comp_time, {evaled, {xla_time, executable, inputs_and_typespecs, outfeed}}} =
421408
:timer.tc(fn ->
422409
comp_cache_fun.(comp_key, fn ->
423410
{reverse_inputs_and_typespecs, reverse_infeeds} =
@@ -430,7 +417,7 @@ defmodule EXLA.Defn do
430417

431418
inputs_and_typespecs = Enum.reverse(reverse_inputs_and_typespecs)
432419

433-
comp_arg_typespecs =
420+
comp_typespecs =
434421
for {i, typespec} <- inputs_and_typespecs, i >= used_buffers, do: typespec
435422

436423
outputs =
@@ -451,7 +438,7 @@ defmodule EXLA.Defn do
451438
|> then(&Typespec.tensor(&1.type, &1.shape))
452439
end)
453440

454-
EXLA.MLIR.Module.new(comp_arg_typespecs, out_typespecs, fn builder ->
441+
EXLA.MLIR.Module.new(comp_typespecs, out_typespecs, fn builder ->
455442
# Only create the token when we know it will actually be
456443
# used, that is: streaming, lazy transfers or hooks
457444
outfeed =
@@ -464,25 +451,20 @@ defmodule EXLA.Defn do
464451
end
465452

466453
expr = Nx.Defn.Composite.traverse(expr || fun.(vars), &Nx.devectorize/1)
467-
468-
{extra, outfeed} =
469-
to_computation.(builder, expr, inputs_and_typespecs, outfeed)
454+
outfeed = to_computation.(builder, expr, inputs_and_typespecs, outfeed)
470455

471456
{xla_time, executable} =
472457
:timer.tc(fn ->
473-
typespecs =
474-
for {i, typespec} <- inputs_and_typespecs, i >= used_buffers, do: typespec
475-
476458
EXLA.MLIR.Module.compile(
477459
builder.module,
478460
client,
479-
typespecs,
461+
comp_typespecs,
480462
builder.return_typespecs,
481463
options
482464
)
483465
end)
484466

485-
{:ok, {xla_time, executable, extra, %{outfeed | infeeds: []}}}
467+
{:ok, {xla_time, executable, inputs_and_typespecs, %{outfeed | infeeds: []}}}
486468
end)
487469
end)
488470
end)
@@ -511,7 +493,7 @@ defmodule EXLA.Defn do
511493
end
512494

513495
outfeed = Outfeed.with_user_hooks(outfeed, hooks)
514-
{executable, used_inputs, outputs, outfeed, extra, debug?}
496+
{executable, used_inputs, outputs, outfeed, inputs_and_typespecs}
515497
end
516498

517499
defp us_to_ms(time), do: Float.round(time / 1000, 1)

exla/lib/exla/defn/buffers.ex

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,19 @@ defmodule EXLA.Defn.Buffers do
3232

3333
@doc """
3434
Splits the given args by value and returns them as is.
35-
36-
Entries with a map entry are discarded.
3735
"""
3836
def split_by_value(args, %{} = map, callback) do
3937
{_i, left, right} =
4038
Enum.reduce(args, {0, [], []}, fn arg, {i, left, right} ->
4139
case map do
42-
%{^i => nil} -> {i + 1, [callback.(arg, i, nil) | left], right}
43-
%{^i => value} -> {i + 1, left, [callback.(arg, i, value) | right]}
44-
%{} -> {i + 1, left, right}
40+
%{^i => nil} ->
41+
{i + 1, [callback.(arg, i, nil) | left], right}
42+
43+
%{^i => value} ->
44+
{i + 1, left, [callback.(arg, i, value) | right]}
45+
46+
%{} ->
47+
{i + 1, left, right}
4548
end
4649
end)
4750

exla/lib/exla/defn/outfeed.ex

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -120,15 +120,15 @@ defmodule EXLA.Defn.Outfeed do
120120
{infeeds, {compiled_hooks, token}} =
121121
entries
122122
|> List.keysort(1, :desc)
123-
|> Enum.map_reduce({compiled_hooks, token}, fn {pos, _, typespec},
124-
{compiled_hooks, token} ->
125-
next_flag = next_hook(compiled_hooks)
126-
compiled_hooks = Map.put(compiled_hooks, next_flag, {:infeed, pos, typespec})
123+
|> Enum.map_reduce({compiled_hooks, token}, fn
124+
{pos, _, typespec}, {compiled_hooks, token} ->
125+
next_flag = next_hook(compiled_hooks)
126+
compiled_hooks = Map.put(compiled_hooks, next_flag, {:infeed, pos, typespec})
127127

128-
token = Value.outfeed(Value.constant(builder, [next_flag], flag_typespec()), token)
129-
{token, [input]} = Value.infeed(token, [typespec])
128+
token = Value.outfeed(Value.constant(builder, [next_flag], flag_typespec()), token)
129+
{token, [input]} = Value.infeed(token, [typespec])
130130

131-
{{pos, input}, {compiled_hooks, token}}
131+
{{pos, input}, {compiled_hooks, token}}
132132
end)
133133

134134
%{outfeed | compiled_hooks: compiled_hooks, token: token, infeeds: infeeds}

exla/lib/exla/defn/stream.ex

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ defmodule EXLA.Defn.Stream do
22
@moduledoc false
33

44
keys =
5-
[:lock, :outfeed, :pid, :runner, :send, :send_typespecs, :send_indexes] ++
5+
[:lock, :outfeed, :pid, :runner, :send, :send_typespecs] ++
66
[:recv, :recv_length, :done, :client, :device_id]
77

88
@derive {Inspect, only: [:pid, :client, :device_id, :send, :recv]}
@@ -16,7 +16,6 @@ defmodule EXLA.Defn.Stream do
1616
outfeed,
1717
send,
1818
send_typespecs,
19-
send_indexes,
2019
recv,
2120
recv_typespecs,
2221
done
@@ -40,7 +39,6 @@ defmodule EXLA.Defn.Stream do
4039
lock: lock,
4140
send: send,
4241
send_typespecs: send_typespecs,
43-
send_indexes: send_indexes,
4442
recv: recv,
4543
recv_length: length(recv_typespecs),
4644
client: client,
@@ -64,15 +62,14 @@ defmodule EXLA.Defn.Stream do
6462
client: client,
6563
device_id: device_id,
6664
send: send,
67-
send_typespecs: send_typespecs,
68-
send_indexes: send_indexes
65+
send_typespecs: send_typespecs
6966
} = stream
7067

7168
if pid != self() do
7269
raise "EXLA streams require recv to be called from the process that started the stream"
7370
end
7471

75-
{template, buffers} = nx_to_io(data, send_indexes)
72+
{template, buffers} = nx_to_io(data, Enum.map(send_typespecs, &elem(&1, 0)))
7673

7774
unless Nx.compatible?(send, template) do
7875
raise ArgumentError, """
@@ -87,7 +84,11 @@ defmodule EXLA.Defn.Stream do
8784
end
8885

8986
pred = EXLA.Typespec.tensor({:pred, 8}, {})
90-
data_and_typespecs = Enum.zip(buffers, send_typespecs)
87+
88+
data_and_typespecs =
89+
Enum.zip_with(buffers, send_typespecs, fn buffer, {_index, typespec} ->
90+
{buffer, typespec}
91+
end)
9192

9293
:ok = EXLA.Client.to_infeed(client, device_id, [{<<1::8-native>>, pred}])
9394
:ok = EXLA.Client.to_infeed(client, device_id, data_and_typespecs)

0 commit comments

Comments
 (0)