Skip to content

Commit c5f1651

Browse files
committed
Add cache: binary() to EXLA
1 parent c210892 commit c5f1651

File tree

7 files changed

+249
-165
lines changed

7 files changed

+249
-165
lines changed

exla/lib/exla.ex

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,11 @@ defmodule EXLA do
224224
It accepts the same options as `Nx.Defn.jit/2` plus:
225225
226226
* `:cache` - cache the results of compilation, defaults to `true`.
227+
You may disable it by setting it to `false`. You can also set it
228+
to a binary, representing a filesystem path to store the cache.
229+
EXLA will ensure the arguments and parameters across invocations
230+
have the same shape, but it is ultimately your responsibility
231+
to provide a unique cache path.
227232
228233
* `:client` - an atom representing the client to use. The default
229234
client is chosen in this order: `:cuda`, `:rocm`, `:tpu`, and `:host`.
@@ -275,22 +280,7 @@ defmodule EXLA do
275280
The backend will then block only when trying to read the data
276281
or when passing it to another operation.
277282
278-
## Options
279-
280-
It accepts the same option as `Nx.Defn.compile/3` plus:
281-
282-
* `:debug` - print compile and debugging information, defaults to `false`.
283-
284-
* `:cache` - cache the results of compilation, defaults to `true`.
285-
You can set it to false if you plan to compile the function only
286-
once and store the compile contents somewhere.
287-
288-
* `:client` - an atom representing the client to use. The default
289-
client is chosen on this order: `:cuda`, `:rocm`, `:tpu`, and `:host`.
290-
291-
* `:device_id` - the default device id to run the computation on.
292-
Defaults to the `:default_device_id` on the client
293-
283+
See `jit/2` for supported options.
294284
"""
295285
def compile(function, args, options \\ []) do
296286
Nx.Defn.compile(function, args, Keyword.put(options, :compiler, EXLA))

exla/lib/exla/defn.ex

Lines changed: 116 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ defmodule EXLA.Defn do
5252
comp_fun =
5353
&to_stream_computation(client, input_length, acc_length, &1, &2, &3, &4, compile_options)
5454

55-
{executable, used_inputs, {output, acc_output}, outfeed, input_typespecs} =
55+
{executable, {used_inputs, {output, acc_output}, outfeed, input_typespecs}} =
5656
compile(
5757
client,
5858
key,
@@ -84,9 +84,7 @@ defmodule EXLA.Defn do
8484
EXLA.Defn.Lock.lock(run_key(executable))
8585
end)
8686

87-
if debug? do
88-
Logger.debug("EXLA device #{executable.device_id} lock in #{us_to_ms(time)}ms")
89-
end
87+
debug? && Logger.debug("EXLA device #{executable.device_id} lock in #{us_to_ms(time)}ms")
9088

9189
{time, streams} =
9290
:timer.tc(fn ->
@@ -131,9 +129,8 @@ defmodule EXLA.Defn do
131129
[stream]
132130
end)
133131

134-
if debug? do
132+
debug? &&
135133
Logger.debug("EXLA stream start on device #{executable.device_id} in #{us_to_ms(time)}ms")
136-
end
137134

138135
streams
139136
end
@@ -250,7 +247,7 @@ defmodule EXLA.Defn do
250247

251248
callback = &to_root_computation(&1, &2, &3, &4, Keyword.put(compile_options, :client, client))
252249

253-
{executable, used_inputs, outputs, outfeed, _input_typespecs?} =
250+
{executable, {used_inputs, outputs, outfeed, _input_typespecs?}} =
254251
compile(client, key, vars, fun, compile_options, 0, [], _stream = false, debug?, callback)
255252

256253
fn [args] ->
@@ -259,18 +256,15 @@ defmodule EXLA.Defn do
259256
EXLA.Defn.Lock.lock(run_key(executable))
260257
end)
261258

262-
if debug? do
263-
Logger.debug("EXLA device #{executable.device_id} lock in #{us_to_ms(time)}ms")
264-
end
259+
debug? && Logger.debug("EXLA device #{executable.device_id} lock in #{us_to_ms(time)}ms")
265260

266261
{time, res} =
267262
:timer.tc(fn ->
268263
maybe_outfeed(lock, executable, args, used_inputs, outputs, outfeed, run_options)
269264
end)
270265

271-
if debug? do
266+
debug? &&
272267
Logger.debug("EXLA execution on device #{executable.device_id} in #{us_to_ms(time)}ms")
273-
end
274268

275269
res
276270
end
@@ -360,15 +354,9 @@ defmodule EXLA.Defn do
360354
debug?,
361355
to_computation
362356
) do
363-
{{expr_cache_fun, comp_cache_fun}, options} =
364-
case Keyword.pop(options, :cache, true) do
365-
{true, options} ->
366-
Keyword.pop(options, EXLA, {&EXLA.Defn.LockedCache.run/2, &EXLA.Defn.LockedCache.run/2})
367-
368-
{false, options} ->
369-
cache_fun = fn _key, fun -> fun.() end
370-
{{cache_fun, cache_fun}, options}
371-
end
357+
{cache, options} = Keyword.pop(options, :cache, true)
358+
{hooks, options} = Keyword.pop(options, :hooks, %{})
359+
{lazy_transfers, options} = Keyword.pop(options, :lazy_transfers, :opt_in)
372360

373361
{args_key, reverse_args_identifiers} =
374362
Enum.map_reduce(vars, [], fn var, acc ->
@@ -381,119 +369,134 @@ defmodule EXLA.Defn do
381369
end)
382370
end)
383371

384-
{lazy_transfers, options} = Keyword.pop(options, :lazy_transfers, :opt_in)
372+
disk_key = %{
373+
client: client.name,
374+
args: args_key,
375+
lazy_transfers: lazy_transfers,
376+
hooks: Map.keys(hooks),
377+
options: options
378+
}
385379

386-
{eval_time, {expr, {ref, outputs, {used_inputs, defined_hooks}}}} =
387-
:timer.tc(fn ->
388-
expr_cache_fun.({key, stream?, args_key, lazy_transfers}, fn ->
389-
expr = fun.(vars)
390-
inputs_and_hooks = Outfeed.used_inputs_and_hooks(expr, used_inputs, lazy_transfers)
391-
{expr, {make_ref(), Nx.to_template(expr), inputs_and_hooks}}
380+
EXLA.Defn.Disk.cache(cache, client, disk_key, debug?, fn ->
381+
{{expr_cache_fun, comp_cache_fun}, options} =
382+
if cache do
383+
Keyword.pop(options, EXLA, {&EXLA.Defn.LockedCache.run/2, &EXLA.Defn.LockedCache.run/2})
384+
else
385+
cache_fun = fn _key, fun -> fun.() end
386+
{{cache_fun, cache_fun}, Keyword.delete(options, EXLA)}
387+
end
388+
389+
{eval_time, {expr, {ref, outputs, {used_inputs, defined_hooks}}}} =
390+
:timer.tc(fn ->
391+
expr_cache_fun.({key, stream?, args_key, lazy_transfers}, fn ->
392+
expr = fun.(vars)
393+
inputs_and_hooks = Outfeed.used_inputs_and_hooks(expr, used_inputs, lazy_transfers)
394+
{expr, {make_ref(), Nx.to_template(expr), inputs_and_hooks}}
395+
end)
392396
end)
393-
end)
394397

395-
if debug? do
396-
hit_or_miss = if expr, do: "miss", else: "hit"
398+
if debug? do
399+
hit_or_miss = if expr, do: "miss", else: "hit"
397400

398-
Logger.debug(
399-
"EXLA defn evaluation #{inspect(key)} cache #{hit_or_miss} in #{us_to_ms(eval_time)}ms"
400-
)
401-
end
401+
Logger.debug(
402+
"EXLA defn evaluation #{inspect(key)} cache #{hit_or_miss} in #{us_to_ms(eval_time)}ms"
403+
)
404+
end
402405

403-
{hooks, options} = Keyword.pop(options, :hooks, %{})
404-
outfeed = Outfeed.new(hooks, defined_hooks)
405-
comp_key = {ref, client.name, outfeed.used_hooks, lazy_transfers, options}
406+
outfeed = Outfeed.new(hooks, defined_hooks)
407+
comp_key = {ref, client.name, outfeed.used_hooks, lazy_transfers, options}
406408

407-
{comp_time, {evaled, {xla_time, executable, inputs_and_typespecs, outfeed}}} =
408-
:timer.tc(fn ->
409-
comp_cache_fun.(comp_key, fn ->
410-
{reverse_inputs_and_typespecs, reverse_infeeds} =
411-
reverse_args_identifiers
412-
|> Enum.reverse()
413-
|> EXLA.Defn.Buffers.split_by_value(used_inputs, fn
414-
{type, shape, _names}, i, nil -> {i, Typespec.tensor(type, shape)}
415-
{type, shape, _names}, i, depth -> {i, depth, Typespec.tensor(type, shape)}
416-
end)
409+
{comp_time, {evaled, {xla_time, executable, inputs_and_typespecs, outfeed}}} =
410+
:timer.tc(fn ->
411+
comp_cache_fun.(comp_key, fn ->
412+
{reverse_inputs_and_typespecs, reverse_infeeds} =
413+
reverse_args_identifiers
414+
|> Enum.reverse()
415+
|> EXLA.Defn.Buffers.split_by_value(used_inputs, fn
416+
{type, shape, _names}, i, nil -> {i, Typespec.tensor(type, shape)}
417+
{type, shape, _names}, i, depth -> {i, depth, Typespec.tensor(type, shape)}
418+
end)
417419

418-
inputs_and_typespecs = Enum.reverse(reverse_inputs_and_typespecs)
419-
420-
comp_typespecs =
421-
for {i, typespec} <- inputs_and_typespecs, i >= used_buffers, do: typespec
422-
423-
outputs =
424-
if stream? do
425-
# The computation returns the final accumulator value
426-
{_chunk_result, acc} = outputs
427-
acc
428-
else
429-
outputs
430-
end
431-
432-
out_typespecs =
433-
[outputs]
434-
|> Nx.Defn.Composite.flatten_list()
435-
|> Enum.map(fn t ->
436-
t
437-
|> Nx.devectorize()
438-
|> then(&Typespec.tensor(&1.type, &1.shape))
439-
end)
420+
inputs_and_typespecs = Enum.reverse(reverse_inputs_and_typespecs)
421+
422+
comp_typespecs =
423+
for {i, typespec} <- inputs_and_typespecs, i >= used_buffers, do: typespec
440424

441-
EXLA.MLIR.Module.new(comp_typespecs, out_typespecs, fn builder ->
442-
# Only create the token when we know it will actually be
443-
# used, that is: streaming, lazy transfers or hooks
444-
outfeed =
445-
if stream? or reverse_infeeds != [] or hooks != %{} or defined_hooks != %{} do
446-
outfeed
447-
|> Outfeed.with_token(Value.create_token(builder))
448-
|> Outfeed.add_infeeds(builder, reverse_infeeds)
425+
outputs =
426+
if stream? do
427+
# The computation returns the final accumulator value
428+
{_chunk_result, acc} = outputs
429+
acc
449430
else
450-
outfeed
431+
outputs
451432
end
452433

453-
expr = Nx.Defn.Composite.traverse(expr || fun.(vars), &Nx.devectorize/1)
454-
outfeed = to_computation.(builder, expr, inputs_and_typespecs, outfeed)
455-
456-
{xla_time, executable} =
457-
:timer.tc(fn ->
458-
EXLA.MLIR.Module.compile(
459-
builder.module,
460-
client,
461-
comp_typespecs,
462-
builder.return_typespecs,
463-
options
464-
)
434+
out_typespecs =
435+
[outputs]
436+
|> Nx.Defn.Composite.flatten_list()
437+
|> Enum.map(fn t ->
438+
t
439+
|> Nx.devectorize()
440+
|> then(&Typespec.tensor(&1.type, &1.shape))
465441
end)
466442

467-
{:ok, {xla_time, executable, inputs_and_typespecs, %{outfeed | infeeds: []}}}
443+
EXLA.MLIR.Module.new(comp_typespecs, out_typespecs, fn builder ->
444+
# Only create the token when we know it will actually be
445+
# used, that is: streaming, lazy transfers or hooks
446+
outfeed =
447+
if stream? or reverse_infeeds != [] or hooks != %{} or defined_hooks != %{} do
448+
outfeed
449+
|> Outfeed.with_token(Value.create_token(builder))
450+
|> Outfeed.add_infeeds(builder, reverse_infeeds)
451+
else
452+
outfeed
453+
end
454+
455+
expr = Nx.Defn.Composite.traverse(expr || fun.(vars), &Nx.devectorize/1)
456+
outfeed = to_computation.(builder, expr, inputs_and_typespecs, outfeed)
457+
458+
{xla_time, executable} =
459+
:timer.tc(fn ->
460+
EXLA.MLIR.Module.compile(
461+
builder.module,
462+
client,
463+
comp_typespecs,
464+
builder.return_typespecs,
465+
options
466+
)
467+
end)
468+
469+
{:ok, {xla_time, executable, inputs_and_typespecs, %{outfeed | infeeds: []}}}
470+
end)
468471
end)
469472
end)
470-
end)
471473

472-
cond do
473-
not debug? ->
474-
:ok
474+
cond do
475+
not debug? ->
476+
:ok
475477

476-
evaled ->
477-
Logger.debug(
478-
"EXLA compilation #{inspect(key)} cache miss in #{us_to_ms(comp_time)}ms (#{us_to_ms(xla_time)}ms in XLA)"
479-
)
478+
evaled ->
479+
Logger.debug(
480+
"EXLA compilation #{inspect(key)} cache miss in #{us_to_ms(comp_time)}ms (#{us_to_ms(xla_time)}ms in XLA)"
481+
)
480482

481-
true ->
482-
Logger.debug("EXLA compilation #{inspect(key)} cache hit in #{us_to_ms(comp_time)}ms")
483-
end
483+
true ->
484+
Logger.debug("EXLA compilation #{inspect(key)} cache hit in #{us_to_ms(comp_time)}ms")
485+
end
484486

485-
if expr || evaled do
486-
measurements = %{
487-
eval_time: eval_time,
488-
compile_time: comp_time,
489-
total_time: eval_time + comp_time
490-
}
487+
if expr || evaled do
488+
measurements = %{
489+
eval_time: eval_time,
490+
compile_time: comp_time,
491+
total_time: eval_time + comp_time
492+
}
491493

492-
:telemetry.execute([:exla, :compilation], measurements, %{key: key})
493-
end
494+
:telemetry.execute([:exla, :compilation], measurements, %{key: key})
495+
end
494496

495-
outfeed = Outfeed.with_user_hooks(outfeed, hooks)
496-
{executable, used_inputs, outputs, outfeed, inputs_and_typespecs}
497+
outfeed = Outfeed.with_user_hooks(outfeed, hooks)
498+
{executable, {used_inputs, outputs, outfeed, inputs_and_typespecs}}
499+
end)
497500
end
498501

499502
# Converts a duration given in microseconds into milliseconds,
# rounded to one decimal place (used for the debug log messages).
defp us_to_ms(microseconds) do
  Float.round(microseconds / 1000, 1)
end

exla/lib/exla/defn/disk.ex

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
defmodule EXLA.Defn.Disk do
  @moduledoc false

  # Disk-backed cache for compiled EXLA executables.
  #
  # A cache file is laid out as the 4-byte tag "EXLA", one @version byte and an
  # Erlang external-term blob holding `{keys, dumped_executable, value}`.

  # Written right after the "EXLA" tag; files carrying any other byte are discarded.
  @version 1

  require Logger

  # `cache` is a boolean: no disk cache is involved, so `callback` is run directly
  # and its result returned untouched.
  def cache(cache, _client, _key, _debug?, callback) when is_boolean(cache) do
    callback.()
  end

  # `cache` is a filesystem path.
  #
  # When the file holds an entry stored under exactly the same `keys`, the dumped
  # executable is loaded into `client` and returned together with the stored value.
  # Otherwise `callback` (which must return `{executable, value}`) is invoked and
  # its result is written to `cache`, creating parent directories as needed.
  #
  # Raises if the cache file cannot be written (`File.mkdir_p!/1`, `File.write!/2`).
  def cache(cache, client, keys, debug?, callback) when is_binary(cache) do
    case read(cache, client, keys, debug?) do
      nil ->
        {executable, value} = callback.()
        blob = :erlang.term_to_binary({keys, EXLA.Executable.dump(executable), value})
        File.mkdir_p!(Path.dirname(cache))
        File.write!(cache, <<"EXLA", @version, blob::binary>>)
        {executable, value}

      cached ->
        cached
    end
  end

  # Reads the cache file at `path`. Returns `{loaded_executable, value}` on a hit
  # and `nil` whenever the entry is missing, outdated, mismatched or unreadable.
  defp read(path, client, keys, debug?) do
    case File.read(path) do
      {:ok, <<"EXLA", @version, blob::binary>>} ->
        decode(path, blob, client, keys, debug?)

      {:ok, <<"EXLA", _::binary>>} ->
        Logger.warning(
          "Discarding EXLA disk cache at #{inspect(path)} because it is for an older EXLA version"
        )

        nil

      # Anything else (empty file, write interrupted before the header was
      # complete, unrelated contents) used to raise a CaseClauseError on every
      # call. Treat it as a miss so the entry gets rebuilt.
      {:ok, _unrecognized} ->
        Logger.warning(
          "Discarding EXLA disk cache at #{inspect(path)} because it is not a valid EXLA cache file"
        )

        nil

      {:error, _} ->
        debug? && Logger.debug("EXLA disk cache not found at #{inspect(path)}")
        nil
    end
  end

  # Decodes the term blob of a cache file with a matching header and compares
  # the stored keys against the expected `keys`.
  defp decode(path, blob, client, keys, debug?) do
    decoded =
      try do
        :erlang.binary_to_term(blob)
      rescue
        # binary_to_term/1 raises ArgumentError on a truncated or corrupt blob.
        ArgumentError -> :error
      end

    case decoded do
      {^keys, executable, value} ->
        debug? && Logger.debug("EXLA disk cache found at #{inspect(path)}")
        {EXLA.Executable.load(client, executable), value}

      {disk_keys, _executable, _value} when is_map(disk_keys) ->
        mismatched = for {key, value} <- disk_keys, keys[key] != value, do: key

        Logger.warning("""
        EXLA disk cache does not match configuration.

        Expected: #{inspect(Map.take(keys, mismatched))}

        Found: #{inspect(Map.take(disk_keys, mismatched))}
        """)

        nil

      _corrupted ->
        Logger.warning(
          "Discarding EXLA disk cache at #{inspect(path)} because its contents are corrupted"
        )

        nil
    end
  end
end

0 commit comments

Comments
 (0)