Skip to content

Commit ab82611

Browse files
committed
Allow jitted functions to work across nodes
1 parent fb2d4b3 commit ab82611

File tree

1 file changed

+22
-4
lines changed

1 file changed

+22
-4
lines changed

exla/lib/exla/executable.ex

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,19 @@ defmodule EXLA.Executable do
1111

1212
@doc """
1313
Runs the given executable with a list of lists as inputs and the given options.
14+
15+
Works across nodes.
1416
"""
15-
def run(%Executable{} = executable, [subinputs | _] = inputs, options \\ [])
17+
def run(executable, inputs, options \\ [])
18+
19+
def run(%Executable{ref: ref, client: client} = executable, inputs, options)
20+
when node(ref) != node() do
21+
client
22+
|> load(dump(executable))
23+
|> run(inputs, options)
24+
end
25+
26+
def run(%Executable{} = executable, [subinputs | _] = inputs, options)
1627
when is_list(subinputs) do
1728
%{client: client, device_id: device_id, output_typespecs: output_typespecs, ref: ref} =
1829
executable
@@ -25,17 +36,20 @@ defmodule EXLA.Executable do
2536
@doc """
2637
Dumps the executable to a data structure that can be serialized
2738
with `term_to_binary`.
39+
40+
Works across nodes.
2841
"""
2942
# If you change this function, you must bump the version in EXLA.Defn.Disk.
3043
def dump(%Executable{
31-
ref: executable,
44+
ref: ref,
3245
output_typespecs: output_typespecs,
3346
num_replicas: num_replicas,
3447
num_partitions: num_partitions,
3548
device_id: device_id
36-
}) do
49+
})
50+
when node(ref) == node() do
3751
serialized_exec =
38-
executable
52+
ref
3953
|> EXLA.NIF.serialize_executable()
4054
|> unwrap!()
4155
|> IO.iodata_to_binary()
@@ -49,6 +63,10 @@ defmodule EXLA.Executable do
4963
}
5064
end
5165

66+
def dump(%Executable{ref: ref} = executable) do
67+
:erpc.call(node(ref), __MODULE__, :dump, [executable])
68+
end
69+
5270
@doc """
5371
Loads a previously dumped executable.
5472
"""

0 commit comments

Comments
 (0)