Skip to content

Commit 6246a4a

Browse files
authored
refactor: elixir_call -> runtime_call (#1645)
1 parent bf638fa commit 6246a4a

File tree

17 files changed

+117
-106
lines changed

17 files changed

+117
-106
lines changed

exla/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ $(EXLA_SO): $(EXLA_CACHE_SO)
8484

8585
SOURCES = $(EXLA_DIR)/exla.cc $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_mlir.cc $(EXLA_DIR)/ipc.cc
8686
SOURCES += $(wildcard $(EXLA_DIR)/custom_calls/*.cc)
87-
HEADERS = $(EXLA_DIR)/exla_mlir.h $(EXLA_DIR)/custom_calls/qr.h $(EXLA_DIR)/custom_calls/eigh.h $(EXLA_DIR)/custom_calls/elixir_callback_bridge.h $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h $(EXLA_DIR)/ipc.h
87+
HEADERS = $(EXLA_DIR)/exla_mlir.h $(EXLA_DIR)/custom_calls/qr.h $(EXLA_DIR)/custom_calls/eigh.h $(EXLA_DIR)/custom_calls/runtime_callback_bridge.h $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h $(EXLA_DIR)/ipc.h
8888
OBJECTS = $(patsubst $(EXLA_DIR)/%.cc,$(EXLA_CACHE_OBJ_DIR)/%.o,$(SOURCES)) $(EXLA_CACHE_OBJ_DIR)/exla_cuda.o
8989

9090

exla/c_src/exla/custom_calls/elixir_callback.cc renamed to exla/c_src/exla/custom_calls/runtime_callback.cc

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#include "elixir_callback_bridge.h"
1+
#include "runtime_callback_bridge.h"
22

33
#include <cstring>
44
#include <vector>
@@ -11,7 +11,7 @@ namespace ffi = xla::ffi;
1111

1212
namespace {
1313

14-
ffi::Error exla_elixir_callback_impl(
14+
ffi::Error exla_runtime_callback_impl(
1515
ffi::RemainingArgs args, ffi::Span<const int64_t> callback_id_words,
1616
uint64_t callback_id_size,
1717
ffi::Span<const int64_t> callback_server_pid_words,
@@ -65,7 +65,7 @@ ffi::Error exla_elixir_callback_impl(
6565
// Call back into Elixir through the bridge. On success, the bridge writes
6666
// results directly into the provided output buffers.
6767
exla::callback_bridge::Result result =
68-
exla::callback_bridge::InvokeElixirCallback(
68+
exla::callback_bridge::InvokeRuntimeCallback(
6969
callback_id_words, callback_id_size, callback_server_pid_words,
7070
callback_server_pid_size, inputs, outputs);
7171

@@ -79,7 +79,7 @@ ffi::Error exla_elixir_callback_impl(
7979
} // namespace
8080

8181
XLA_FFI_DEFINE_HANDLER_SYMBOL(
82-
exla_elixir_callback, exla_elixir_callback_impl,
82+
exla_runtime_callback, exla_runtime_callback_impl,
8383
ffi::Ffi::Bind()
8484
.RemainingArgs()
8585
.Attr<ffi::Span<const int64_t>>("callback_id")
@@ -88,7 +88,5 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
8888
.Attr<uint64_t>("callback_server_pid_size")
8989
.RemainingRets());
9090

91-
XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "exla_elixir_callback", "Host",
92-
exla_elixir_callback);
93-
94-
91+
XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "exla_runtime_callback", "Host",
92+
exla_runtime_callback);

exla/c_src/exla/custom_calls/elixir_callback_bridge.cc renamed to exla/c_src/exla/custom_calls/runtime_callback_bridge.cc

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#include "elixir_callback_bridge.h"
1+
#include "runtime_callback_bridge.h"
22

33
#include <cstring>
44

@@ -16,24 +16,24 @@ BridgeState *GetBridgeState() {
1616
return state;
1717
}
1818

19-
fine::Ok<> start_elixir_callback_bridge(ErlNifEnv *env,
20-
ErlNifPid dispatcher_pid) {
19+
fine::Ok<> start_runtime_callback_bridge(ErlNifEnv *env,
20+
ErlNifPid dispatcher_pid) {
2121
(void)env;
2222
auto state = GetBridgeState();
2323
state->dispatcher_pid = dispatcher_pid;
2424
state->dispatcher_set = true;
2525
return fine::Ok();
2626
}
2727

28-
fine::Ok<> elixir_callback_reply(ErlNifEnv *env,
29-
fine::ResourcePtr<Pending> pending,
30-
fine::Atom status, fine::Term result) {
28+
fine::Ok<> runtime_callback_reply(ErlNifEnv *env,
29+
fine::ResourcePtr<Pending> pending,
30+
fine::Atom status, fine::Term result) {
3131
deliver_reply(env, pending, status, result);
3232
return fine::Ok();
3333
}
3434

35-
fine::Ok<> clear_elixir_callback_bridge(ErlNifEnv *env,
36-
ErlNifPid dispatcher_pid) {
35+
fine::Ok<> clear_runtime_callback_bridge(ErlNifEnv *env,
36+
ErlNifPid dispatcher_pid) {
3737
(void)env;
3838
auto state = GetBridgeState();
3939

@@ -115,7 +115,7 @@ void deliver_reply(ErlNifEnv *env, fine::ResourcePtr<Pending> pending,
115115
pending->cv.notify_one();
116116
}
117117

118-
Result InvokeElixirCallback(
118+
Result InvokeRuntimeCallback(
119119
xla::ffi::Span<const int64_t> callback_id_words, uint64_t callback_id_size,
120120
xla::ffi::Span<const int64_t> callback_server_pid_words,
121121
uint64_t callback_server_pid_size, const std::vector<Arg> &inputs,
@@ -204,10 +204,10 @@ Result InvokeElixirCallback(
204204
args_terms.push_back(arg_tuple);
205205
}
206206

207-
auto msg = std::make_tuple(fine::Atom("exla_elixir_call"),
207+
auto msg = std::make_tuple(fine::Atom("exla_runtime_call"),
208208
fine::Term(callback_id_term), args_terms, pending);
209209

210-
// Use the dispatcher pid registered via start_elixir_callback_bridge/1.
210+
// Use the dispatcher pid registered via start_runtime_callback_bridge/1.
211211
// We still are within the NIF thread that started the computation,
212212
// but we don't know its env, therefore we cannot use enif_whereis_pid.
213213
// enif_whereis_pid can be called with NULL, but only from non-ERTS

exla/c_src/exla/custom_calls/elixir_callback_bridge.h renamed to exla/c_src/exla/custom_calls/runtime_callback_bridge.h

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ struct Pending {
5656
// Called from the Elixir side to deliver a reply for a given pending handle.
5757
// We receive the reply as a status atom (e.g. :ok or :error) and a result
5858
// term. For the :ok case the result is a list of binaries that we decode as
59-
// ElixirCallbackTensor outputs via Fine's decoding machinery.
59+
// RuntimeCallbackTensor outputs via Fine's decoding machinery.
6060
void deliver_reply(ErlNifEnv *env, fine::ResourcePtr<Pending> pending,
6161
fine::Atom status, fine::Term result);
6262

@@ -70,21 +70,21 @@ void deliver_reply(ErlNifEnv *env, fine::ResourcePtr<Pending> pending,
7070
//
7171
// It returns a Result that either indicates success (data has
7272
// been written into the registered output buffers) or an error message.
73-
Result InvokeElixirCallback(
73+
Result InvokeRuntimeCallback(
7474
xla::ffi::Span<const int64_t> callback_id_words, uint64_t callback_id_size,
7575
xla::ffi::Span<const int64_t> callback_server_pid_words,
7676
uint64_t callback_server_pid_size, const std::vector<Arg> &inputs,
7777
const std::vector<OutputBuffer> &outputs);
7878

79-
fine::Ok<> start_elixir_callback_bridge(ErlNifEnv *env,
80-
ErlNifPid dispatcher_pid);
79+
fine::Ok<> start_runtime_callback_bridge(ErlNifEnv *env,
80+
ErlNifPid dispatcher_pid);
8181

82-
fine::Ok<> elixir_callback_reply(ErlNifEnv *env,
83-
fine::ResourcePtr<Pending> pending,
84-
fine::Atom status, fine::Term result);
82+
fine::Ok<> runtime_callback_reply(ErlNifEnv *env,
83+
fine::ResourcePtr<Pending> pending,
84+
fine::Atom status, fine::Term result);
8585

86-
fine::Ok<> clear_elixir_callback_bridge(ErlNifEnv *env,
87-
ErlNifPid dispatcher_pid);
86+
fine::Ok<> clear_runtime_callback_bridge(ErlNifEnv *env,
87+
ErlNifPid dispatcher_pid);
8888

8989
} // namespace callback_bridge
9090

exla/c_src/exla/exla.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#include <tuple>
77
#include <unordered_map>
88

9-
#include "custom_calls/elixir_callback_bridge.h"
9+
#include "custom_calls/runtime_callback_bridge.h"
1010
#include "exla_client.h"
1111
#include "exla_cuda.h"
1212
#include "exla_log_sink.h"
@@ -544,13 +544,13 @@ FINE_NIF(get_per_device_memory, 0);
544544

545545
// Elixir callback bridge NIF registrations
546546

547-
using callback_bridge::clear_elixir_callback_bridge;
548-
using callback_bridge::elixir_callback_reply;
549-
using callback_bridge::start_elixir_callback_bridge;
547+
using callback_bridge::clear_runtime_callback_bridge;
548+
using callback_bridge::runtime_callback_reply;
549+
using callback_bridge::start_runtime_callback_bridge;
550550

551-
FINE_NIF(start_elixir_callback_bridge, 0);
552-
FINE_NIF(elixir_callback_reply, ERL_NIF_DIRTY_JOB_IO_BOUND);
553-
FINE_NIF(clear_elixir_callback_bridge, 0);
551+
FINE_NIF(start_runtime_callback_bridge, 0);
552+
FINE_NIF(runtime_callback_reply, ERL_NIF_DIRTY_JOB_IO_BOUND);
553+
FINE_NIF(clear_runtime_callback_bridge, 0);
554554

555555
// Logging
556556

exla/c_src/exla/exla_client.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ ExlaExecutable::~ExlaExecutable() {
108108
// Notify the callback server that this executable has been dropped so it
109109
// can clean up any associated state.
110110
ERL_NIF_TERM msg =
111-
fine::encode(env, fine::Atom("exla_elixir_call_executable_dropped"));
111+
fine::encode(env, fine::Atom("exla_runtime_call_executable_dropped"));
112112
enif_send(nullptr, &callback_server_pid_.value(), env, msg);
113113
enif_free_env(env);
114114
}

exla/lib/exla/callback_server.ex

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
defmodule EXLA.CallbackServer do
22
@moduledoc """
3-
Dispatcher and registry for `Nx.elixir_call/3` callbacks used by EXLA.
3+
Dispatcher and registry for `Nx.runtime_call/3` callbacks used by EXLA.
44
55
This server has two responsibilities:
66
@@ -10,16 +10,16 @@ defmodule EXLA.CallbackServer do
1010
1111
The native side is expected to:
1212
13-
* Lower `:elixir_call` nodes to a CPU-only host `CustomCall` named
14-
`"exla_elixir_callback"` with a callback id encoded in its attributes.
13+
* Lower `:runtime_call` nodes to a CPU-only host `CustomCall` named
14+
`"exla_runtime_callback"` with a callback id encoded in its attributes.
1515
1616
* Run a bridge thread that sends messages of the form:
1717
18-
{:exla_elixir_call, callback_id :: term(), args :: [Nx.Tensor.t()], reply_tag :: term()}
18+
{:exla_runtime_call, callback_id :: term(), args :: [Nx.Tensor.t()], reply_tag :: term()}
1919
2020
to this process and waits on a native future associated with `reply_tag`.
2121
22-
* Provide a NIF `EXLA.NIF.elixir_callback_reply/2` that completes the
22+
* Provide a NIF `EXLA.NIF.runtime_callback_reply/2` that completes the
2323
native future when we send the reply back.
2424
"""
2525

@@ -42,7 +42,7 @@ defmodule EXLA.CallbackServer do
4242
Starts the callback server and registers it as the EXLA dispatcher process.
4343
4444
The EXLA NIF is notified of the dispatcher PID so it can route
45-
`:exla_elixir_call` messages to this process.
45+
`:exla_runtime_call` messages to this process.
4646
"""
4747
def start_link(_init_arg) do
4848
GenServer.start_link(__MODULE__, :ok)
@@ -51,7 +51,7 @@ defmodule EXLA.CallbackServer do
5151
@doc """
5252
Registers a callback function, its output template, argument template, and options.
5353
54-
The `id` is typically the underlying `Nx.Defn.Expr` id of the `:elixir_call`
54+
The `id` is typically the underlying `Nx.Defn.Expr` id of the `:runtime_call`
5555
node, which the EXLA compiler also encodes into the host `CustomCall` so the
5656
native side can reference the right callback.
5757
"""
@@ -69,15 +69,15 @@ defmodule EXLA.CallbackServer do
6969
@impl true
7070
def init(:ok) do
7171
# Inform native side that this process is the dispatcher for elixir callbacks
72-
_ = EXLA.NIF.start_elixir_callback_bridge(self())
72+
_ = EXLA.NIF.start_runtime_callback_bridge(self())
7373

7474
{:ok, %__MODULE__{}}
7575
end
7676

7777
@impl true
7878
def terminate(_reason, _state) do
7979
try do
80-
EXLA.NIF.clear_elixir_callback_bridge(self())
80+
EXLA.NIF.clear_runtime_callback_bridge(self())
8181
rescue
8282
_ -> :ok
8383
end
@@ -94,7 +94,7 @@ defmodule EXLA.CallbackServer do
9494
end
9595

9696
@impl true
97-
def handle_info({:exla_elixir_call, callback_id, args_spec, reply_tag}, %__MODULE__{} = state) do
97+
def handle_info({:exla_runtime_call, callback_id, args_spec, reply_tag}, %__MODULE__{} = state) do
9898
reply_payload =
9999
try do
100100
case Map.fetch(state.callbacks, callback_id) do
@@ -121,7 +121,7 @@ defmodule EXLA.CallbackServer do
121121
{:noreply, state}
122122
end
123123

124-
def handle_info(:exla_elixir_call_executable_dropped, state) do
124+
def handle_info(:exla_runtime_call_executable_dropped, state) do
125125
{:stop, :normal, state}
126126
end
127127

@@ -173,7 +173,7 @@ defmodule EXLA.CallbackServer do
173173
# Shape mismatch between callback result and output template.
174174
defp encode_reply({:error, {:shape_mismatch, left, right}}) do
175175
msg =
176-
"expected the elixir_call function to match the given output template " <>
176+
"expected the runtime_call function to match the given output template " <>
177177
"#{inspect(right)}, got: #{inspect(left)}"
178178

179179
{:error, {:argument_error, msg}}
@@ -182,7 +182,7 @@ defmodule EXLA.CallbackServer do
182182
# Callback returned something that isn't a tensor/tuple matching the template.
183183
defp encode_reply({:error, {:invalid_result, left, right}}) do
184184
msg =
185-
"expected the elixir_call function to return a value compatible with the output " <>
185+
"expected the runtime_call function to return a value compatible with the output " <>
186186
"template #{inspect(right)}, got: #{inspect(left)}"
187187

188188
{:error, {:argument_error, msg}}
@@ -202,7 +202,7 @@ defmodule EXLA.CallbackServer do
202202

203203
# Unknown callback id from native.
204204
defp encode_reply({:error, :unknown_callback}) do
205-
msg = "unknown EXLA elixir_call callback id"
205+
msg = "unknown EXLA runtime_call callback id"
206206
{:error, {:runtime_error, msg}}
207207
end
208208

@@ -257,7 +257,7 @@ defmodule EXLA.CallbackServer do
257257

258258
defp send_reply(reply_tag, {status, result}) do
259259
try do
260-
EXLA.NIF.elixir_callback_reply(reply_tag, status, result)
260+
EXLA.NIF.runtime_callback_reply(reply_tag, status, result)
261261
rescue
262262
_ ->
263263
Logger.error(

exla/lib/exla/defn.ex

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -582,7 +582,7 @@ defmodule EXLA.Defn do
582582
end
583583

584584
defp cached_recur_operator(
585-
:elixir_call,
585+
:runtime_call,
586586
%T{data: %Expr{id: id, args: [tensor_expr, opts, fun, out_template]}} = expr,
587587
%{client: %EXLA.Client{platform: :host}, callback_server_pid: callback_server_pid} =
588588
state,
@@ -608,19 +608,19 @@ defmodule EXLA.Defn do
608608
typespecs = container_to_typespecs(out_template)
609609

610610
results =
611-
Value.elixir_call(arg_values, typespecs, callback_server_pid, id)
611+
Value.runtime_call(arg_values, typespecs, callback_server_pid, id)
612612

613613
{wrap_tuple_result(results, expr), cache}
614614
end
615615

616616
defp cached_recur_operator(
617-
:elixir_call,
617+
:runtime_call,
618618
_expr,
619619
%{client: %EXLA.Client{platform: platform}},
620620
_cache
621621
) do
622622
raise """
623-
Nx.elixir_call/3 is currently only supported for EXLA CPU (platform: :host),
623+
Nx.runtime_call/3 is currently only supported for EXLA CPU (platform: :host),
624624
but the active EXLA client is configured for platform #{inspect(platform)}.
625625
Please run on the :host client or wait for future segmentation-based support.
626626
"""

exla/lib/exla/mlir/value.ex

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -836,11 +836,11 @@ defmodule EXLA.MLIR.Value do
836836
Builds a StableHLO `custom_call` that targets the EXLA Elixir callback bridge.
837837
838838
The `callback_id` is typically the underlying `Nx.Defn.Expr` id of the
839-
`:elixir_call` node. It is encoded as a binary (via `:erlang.term_to_binary/1`)
839+
`:runtime_call` node. It is encoded as a binary (via `:erlang.term_to_binary/1`)
840840
and then represented as a list of 64-bit words in the custom call attributes,
841841
similar to how we encode the callback server PID.
842842
"""
843-
def elixir_call(
843+
def runtime_call(
844844
[%Value{function: func} | _] = operands,
845845
typespecs,
846846
callback_server_pid,
@@ -855,7 +855,7 @@ defmodule EXLA.MLIR.Value do
855855
term_to_int64_list(callback_id)
856856

857857
attributes = [
858-
call_target_name: attr_string("exla_elixir_callback"),
858+
call_target_name: attr_string("exla_runtime_callback"),
859859
# api_version 4 enables the typed FFI API used by our callback handler.
860860
api_version: attr_i32(4),
861861
backend_config:

exla/lib/exla/nif.ex

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,9 @@ defmodule EXLA.NIF do
8080
def get_per_device_memory(_client), do: err!()
8181

8282
# Elixir callback bridge
83-
def start_elixir_callback_bridge(_dispatcher_pid), do: err!()
84-
def clear_elixir_callback_bridge(_dispatcher_pid), do: err!()
85-
def elixir_callback_reply(_reply_tag, _status, _result), do: err!()
83+
def start_runtime_callback_bridge(_dispatcher_pid), do: err!()
84+
def clear_runtime_callback_bridge(_dispatcher_pid), do: err!()
85+
def runtime_callback_reply(_reply_tag, _status, _result), do: err!()
8686

8787
defp err!(), do: :erlang.nif_error(:undef)
8888
end

0 commit comments

Comments
 (0)