11defmodule EXLA.CallbackServer do
22 @ moduledoc """
3- Dispatcher and registry for `Nx.elixir_call /3` callbacks used by EXLA.
3+ Dispatcher and registry for `Nx.runtime_call /3` callbacks used by EXLA.
44
55 This server has two responsibilities:
66
@@ -10,16 +10,16 @@ defmodule EXLA.CallbackServer do
1010
1111 The native side is expected to:
1212
13- * Lower `:elixir_call ` nodes to a CPU-only host `CustomCall` named
14- `"exla_elixir_callback "` with a callback id encoded in its attributes.
13+ * Lower `:runtime_call ` nodes to a CPU-only host `CustomCall` named
14+ `"exla_runtime_callback "` with a callback id encoded in its attributes.
1515
1616 * Run a bridge thread that sends messages of the form:
1717
18- {:exla_elixir_call , callback_id :: term(), args :: [Nx.Tensor.t()], reply_tag :: term()}
18+ {:exla_runtime_call , callback_id :: term(), args :: [Nx.Tensor.t()], reply_tag :: term()}
1919
2020 to this process and waits on a native future associated with `reply_tag`.
2121
22- * Provide a NIF `EXLA.NIF.elixir_callback_reply /2` that completes the
22+ * Provide a NIF `EXLA.NIF.runtime_callback_reply /2` that completes the
2323 native future when we send the reply back.
2424 """
2525
@@ -42,7 +42,7 @@ defmodule EXLA.CallbackServer do
4242 Starts the callback server and registers it as the EXLA dispatcher process.
4343
4444 The EXLA NIF is notified of the dispatcher PID so it can route
45- `:exla_elixir_call ` messages to this process.
45+ `:exla_runtime_call ` messages to this process.
4646 """
4747 def start_link ( _init_arg ) do
4848 GenServer . start_link ( __MODULE__ , :ok )
@@ -51,7 +51,7 @@ defmodule EXLA.CallbackServer do
5151 @ doc """
5252 Registers a callback function, its output template, argument template, and options.
5353
54- The `id` is typically the underlying `Nx.Defn.Expr` id of the `:elixir_call `
54+ The `id` is typically the underlying `Nx.Defn.Expr` id of the `:runtime_call `
5555 node, which the EXLA compiler also encodes into the host `CustomCall` so the
5656 native side can reference the right callback.
5757 """
@@ -69,15 +69,15 @@ defmodule EXLA.CallbackServer do
6969 @ impl true
7070 def init ( :ok ) do
7171 # Inform native side that this process is the dispatcher for elixir callbacks
72- _ = EXLA.NIF . start_elixir_callback_bridge ( self ( ) )
72+ _ = EXLA.NIF . start_runtime_callback_bridge ( self ( ) )
7373
7474 { :ok , % __MODULE__ { } }
7575 end
7676
7777 @ impl true
7878 def terminate ( _reason , _state ) do
7979 try do
80- EXLA.NIF . clear_elixir_callback_bridge ( self ( ) )
80+ EXLA.NIF . clear_runtime_callback_bridge ( self ( ) )
8181 rescue
8282 _ -> :ok
8383 end
@@ -94,7 +94,7 @@ defmodule EXLA.CallbackServer do
9494 end
9595
9696 @ impl true
97- def handle_info ( { :exla_elixir_call , callback_id , args_spec , reply_tag } , % __MODULE__ { } = state ) do
97+ def handle_info ( { :exla_runtime_call , callback_id , args_spec , reply_tag } , % __MODULE__ { } = state ) do
9898 reply_payload =
9999 try do
100100 case Map . fetch ( state . callbacks , callback_id ) do
@@ -121,7 +121,7 @@ defmodule EXLA.CallbackServer do
121121 { :noreply , state }
122122 end
123123
124- def handle_info ( :exla_elixir_call_executable_dropped , state ) do
124+ def handle_info ( :exla_runtime_call_executable_dropped , state ) do
125125 { :stop , :normal , state }
126126 end
127127
@@ -173,7 +173,7 @@ defmodule EXLA.CallbackServer do
173173 # Shape mismatch between callback result and output template.
174174 defp encode_reply ( { :error , { :shape_mismatch , left , right } } ) do
175175 msg =
176- "expected the elixir_call function to match the given output template " <>
176+ "expected the runtime_call function to match the given output template " <>
177177 "#{ inspect ( right ) } , got: #{ inspect ( left ) } "
178178
179179 { :error , { :argument_error , msg } }
@@ -182,7 +182,7 @@ defmodule EXLA.CallbackServer do
182182 # Callback returned something that isn't a tensor/tuple matching the template.
183183 defp encode_reply ( { :error , { :invalid_result , left , right } } ) do
184184 msg =
185- "expected the elixir_call function to return a value compatible with the output " <>
185+ "expected the runtime_call function to return a value compatible with the output " <>
186186 "template #{ inspect ( right ) } , got: #{ inspect ( left ) } "
187187
188188 { :error , { :argument_error , msg } }
@@ -202,7 +202,7 @@ defmodule EXLA.CallbackServer do
202202
203203 # Unknown callback id from native.
204204 defp encode_reply ( { :error , :unknown_callback } ) do
205- msg = "unknown EXLA elixir_call callback id"
205+ msg = "unknown EXLA runtime_call callback id"
206206 { :error , { :runtime_error , msg } }
207207 end
208208
@@ -257,7 +257,7 @@ defmodule EXLA.CallbackServer do
257257
258258 defp send_reply ( reply_tag , { status , result } ) do
259259 try do
260- EXLA.NIF . elixir_callback_reply ( reply_tag , status , result )
260+ EXLA.NIF . runtime_callback_reply ( reply_tag , status , result )
261261 rescue
262262 _ ->
263263 Logger . error (
0 commit comments