Skip to content

Commit c7cdd84

Browse files
Create HLO token in the computation only when needed (#1494)
1 parent 773f4d6 commit c7cdd84

File tree

1 file changed

+97
-35
lines changed

1 file changed

+97
-35
lines changed

exla/lib/exla/defn.ex

Lines changed: 97 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -60,6 +60,7 @@ defmodule EXLA.Defn do
6060
compile_options,
6161
used_buffers,
6262
used_inputs,
63+
_stream = true,
6364
comp_fun
6465
)
6566

@@ -258,7 +259,7 @@ defmodule EXLA.Defn do
258259
callback = &to_root_computation(&1, &2, &3, &4, Keyword.put(compile_options, :client, client))
259260

260261
{executable, used_inputs, outputs, outfeed, :ok, debug?} =
261-
compile(client, key, vars, fun, compile_options, 0, [], callback)
262+
compile(client, key, vars, fun, compile_options, 0, [], _stream = false, callback)
262263

263264
fn [args] ->
264265
{time, lock} =
@@ -357,7 +358,17 @@ defmodule EXLA.Defn do
357358

358359
## Compile
359360

360-
defp compile(client, key, vars, fun, options, used_buffers, used_inputs, to_computation) do
361+
defp compile(
362+
client,
363+
key,
364+
vars,
365+
fun,
366+
options,
367+
used_buffers,
368+
used_inputs,
369+
stream?,
370+
to_computation
371+
) do
361372
{{expr_cache_fun, comp_cache_fun}, options} =
362373
case Keyword.pop(options, :cache, true) do
363374
{true, options} ->
@@ -385,7 +396,7 @@ defmodule EXLA.Defn do
385396

386397
{eval_time, {expr, {ref, outputs, {used_inputs, defined_hooks}}}} =
387398
:timer.tc(fn ->
388-
expr_cache_fun.({key, args_key}, fn ->
399+
expr_cache_fun.({key, args_key, lazy_transfers}, fn ->
389400
expr = fun.(vars)
390401
inputs_and_hooks = Outfeed.used_inputs_and_hooks(expr, used_inputs, lazy_transfers)
391402
{expr, {make_ref(), Nx.to_template(expr), inputs_and_hooks}}
@@ -432,10 +443,16 @@ defmodule EXLA.Defn do
432443
end)
433444

434445
EXLA.MLIR.Module.new(comp_arg_typespecs, out_typespecs, fn builder ->
446+
# Only create the token when we know it will actually be
447+
# used, that is: streaming, lazy transfers or hooks
435448
outfeed =
436-
outfeed
437-
|> Outfeed.with_token(Value.create_token(builder))
438-
|> Outfeed.add_infeeds(builder, reverse_infeeds)
449+
if stream? or reverse_infeeds != [] or hooks != %{} or defined_hooks != %{} do
450+
outfeed
451+
|> Outfeed.with_token(Value.create_token(builder))
452+
|> Outfeed.add_infeeds(builder, reverse_infeeds)
453+
else
454+
outfeed
455+
end
439456

440457
expr = Nx.Defn.Composite.traverse(expr || fun.(vars), &Nx.devectorize/1)
441458

@@ -520,19 +537,30 @@ defmodule EXLA.Defn do
520537
cache
521538
) do
522539
[initial_arg, _arg, pred, body] = args
523-
initial_with_token = {get_token(cache), initial_arg}
524540

525-
{initial, cache} = recur_composite(initial_with_token, state, cache)
541+
initial =
542+
if token = get_token(cache) do
543+
{token, initial_arg}
544+
else
545+
initial_arg
546+
end
547+
548+
{initial, cache} = recur_composite(initial, state, cache)
526549

527550
{pred_computation, cache} = mlir_while_computation(pred, initial, {:pred, 8}, state, cache)
528551
{body_computation, cache} = mlir_while_computation(body, initial, :with_token, state, cache)
529552

530-
[token | results] =
553+
results =
531554
Value.while(function, pred_computation, body_computation, List.flatten(initial))
532555

533-
result = wrap_tuple_result(results, initial_arg)
534-
535-
{result, update_token(cache, token)}
556+
if get_token(cache) do
557+
[token | results] = results
558+
result = wrap_tuple_result(results, initial_arg)
559+
{result, update_token(cache, token)}
560+
else
561+
result = wrap_tuple_result(results, initial_arg)
562+
{result, cache}
563+
end
536564
end
537565

538566
defp cached_recur_operator(:cond, %T{data: %Expr{args: args}} = t, state, cache) do
@@ -688,16 +716,19 @@ defmodule EXLA.Defn do
688716
{computation, cache}
689717

690718
%{} ->
691-
{computation, cache} = token_computation("optional", call_args, expr, state, cache)
719+
{computation, cache} = optional_computation("optional", call_args, expr, state, cache)
692720
{computation, Map.put(cache, key, computation)}
693721
end
694722

695-
typespecs = [Typespec.token() | container_to_typespecs(expr)]
696-
697-
[token | result] =
698-
Value.call(state.builder, [get_token(cache) | call_args], call_body, typespecs)
699-
700-
{wrap_tuple_result(result, expr), update_token(cache, token)}
723+
if token = get_token(cache) do
724+
typespecs = [Typespec.token() | container_to_typespecs(expr)]
725+
[token | result] = Value.call(state.builder, [token | call_args], call_body, typespecs)
726+
{wrap_tuple_result(result, expr), update_token(cache, token)}
727+
else
728+
typespecs = container_to_typespecs(expr)
729+
result = Value.call(state.builder, call_args, call_body, typespecs)
730+
{wrap_tuple_result(result, expr), cache}
731+
end
701732
end
702733

703734
defp cached_recur_operator(:attach_token, %T{data: %Expr{args: [token, expr]}}, state, cache) do
@@ -1553,7 +1584,17 @@ defmodule EXLA.Defn do
15531584
defp mlir_while_computation(expr, initial, type, state, cache) do
15541585
arg_typespecs = Enum.map(List.flatten(initial), &Value.get_typespec/1)
15551586

1556-
{region, [arg_token | arg_params]} = Function.push_region(state.builder, arg_typespecs)
1587+
{region, args} = Function.push_region(state.builder, arg_typespecs)
1588+
1589+
outer_token = get_token(cache)
1590+
1591+
{inner_token, arg_params} =
1592+
if outer_token do
1593+
[arg_token | arg_params] = args
1594+
{arg_token, arg_params}
1595+
else
1596+
{nil, args}
1597+
end
15571598

15581599
params = Enum.with_index(arg_params, &{&2, &1})
15591600

@@ -1570,11 +1611,15 @@ defmodule EXLA.Defn do
15701611
expr
15711612
end
15721613

1573-
{res, comp_cache} = recur_composite(expr, & &1, state, reset_token(cache, arg_token))
1614+
{res, comp_cache} = recur_composite(expr, & &1, state, reset_token(cache, inner_token))
15741615

15751616
res =
15761617
if type == :with_token do
1577-
[get_token(comp_cache) | List.flatten(res)]
1618+
if outer_token do
1619+
[get_token(comp_cache) | List.flatten(res)]
1620+
else
1621+
List.flatten(res)
1622+
end
15781623
else
15791624
Enum.map(res, &to_type(&1, type))
15801625
end
@@ -1585,21 +1630,34 @@ defmodule EXLA.Defn do
15851630
{region, merge_outfeed(cache, comp_cache)}
15861631
end
15871632

1588-
defp token_computation(name, args, expr, %{builder: %Function{}} = state, cache) do
1633+
defp optional_computation(name, args, expr, %{builder: %Function{}} = state, cache) do
15891634
%Function{module: module, name: name} = subbuilder(state.builder, name)
15901635

1591-
token_typespec = Typespec.token()
15921636
arg_typespecs = Enum.map(args, &Value.get_typespec/1)
15931637
out_typespecs = container_to_typespecs(expr)
15941638

1595-
function =
1596-
EXLA.MLIR.Module.add_function(module, name, [token_typespec | arg_typespecs], [
1597-
token_typespec | out_typespecs
1598-
])
1639+
outer_token = get_token(cache)
1640+
token_typespec = Typespec.token()
1641+
1642+
{arg_typespecs, out_typespecs} =
1643+
if outer_token do
1644+
{[token_typespec | arg_typespecs], [token_typespec | out_typespecs]}
1645+
else
1646+
{arg_typespecs, out_typespecs}
1647+
end
15991648

1600-
[arg_token | tail] = EXLA.MLIR.Function.get_arguments(function)
1649+
function = EXLA.MLIR.Module.add_function(module, name, arg_typespecs, out_typespecs)
1650+
args = EXLA.MLIR.Function.get_arguments(function)
16011651

1602-
params = Enum.with_index(tail, fn param, i -> {i, param} end)
1652+
{inner_token, args} =
1653+
if outer_token do
1654+
[arg_token | args] = args
1655+
{arg_token, args}
1656+
else
1657+
{nil, args}
1658+
end
1659+
1660+
params = Enum.with_index(args, fn param, i -> {i, param} end)
16031661

16041662
state = %{
16051663
state
@@ -1608,9 +1666,13 @@ defmodule EXLA.Defn do
16081666
scope_ids: Tree.scope_ids(expr)
16091667
}
16101668

1611-
{res, comp_cache} = recur_composite(expr, state, reset_token(cache, arg_token))
1669+
{res, comp_cache} = recur_composite(expr, state, reset_token(cache, inner_token))
16121670

1613-
Value.return(function, [get_token(comp_cache) | List.flatten(res)])
1671+
if outer_token do
1672+
Value.return(function, [get_token(comp_cache) | List.flatten(res)])
1673+
else
1674+
Value.return(function, List.flatten(res))
1675+
end
16141676

16151677
{function, merge_outfeed(cache, comp_cache)}
16161678
end
@@ -1786,10 +1848,10 @@ defmodule EXLA.Defn do
17861848

17871849
out_typespecs = container_to_typespecs(on_true)
17881850

1789-
in_token = get_token(cache)
1851+
outer_token = get_token(cache)
17901852

17911853
result_typespecs =
1792-
if in_token do
1854+
if outer_token do
17931855
[Typespec.token() | out_typespecs]
17941856
else
17951857
out_typespecs
@@ -1799,7 +1861,7 @@ defmodule EXLA.Defn do
17991861
{false_computation, cache} = to_mlir_if_branch(on_false, false_ids, state, cache)
18001862
if_results = Value.if_op(pred_op, true_computation, false_computation, result_typespecs)
18011863

1802-
if in_token do
1864+
if outer_token do
18031865
[token | results] = if_results
18041866
{wrap_tuple_result(results, on_true), update_token(cache, token)}
18051867
else

0 commit comments

Comments
 (0)