Skip to content

Commit bf638fa

Browse files
feat: Nx.elixir_call/3 (#1627)
Co-authored-by: José Valim <[email protected]>
1 parent b9f3c95 commit bf638fa

24 files changed

+1488
-42
lines changed

exla/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ $(EXLA_SO): $(EXLA_CACHE_SO)
8484

8585
SOURCES = $(EXLA_DIR)/exla.cc $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_mlir.cc $(EXLA_DIR)/ipc.cc
8686
SOURCES += $(wildcard $(EXLA_DIR)/custom_calls/*.cc)
87-
HEADERS = $(EXLA_DIR)/exla_mlir.h $(EXLA_DIR)/custom_calls/qr.h $(EXLA_DIR)/custom_calls/eigh.h $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h $(EXLA_DIR)/ipc.h
87+
HEADERS = $(EXLA_DIR)/exla_mlir.h $(EXLA_DIR)/custom_calls/qr.h $(EXLA_DIR)/custom_calls/eigh.h $(EXLA_DIR)/custom_calls/elixir_callback_bridge.h $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h $(EXLA_DIR)/ipc.h
8888
OBJECTS = $(patsubst $(EXLA_DIR)/%.cc,$(EXLA_CACHE_OBJ_DIR)/%.o,$(SOURCES)) $(EXLA_CACHE_OBJ_DIR)/exla_cuda.o
8989

9090

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
#include "elixir_callback_bridge.h"
2+
3+
#include <cstring>
4+
#include <vector>
5+
#include <string>
6+
7+
#include "xla/ffi/api/ffi.h"
8+
#include "xla/ffi/ffi_api.h"
9+
10+
namespace ffi = xla::ffi;
11+
12+
namespace {
13+
14+
ffi::Error exla_elixir_callback_impl(
15+
ffi::RemainingArgs args, ffi::Span<const int64_t> callback_id_words,
16+
uint64_t callback_id_size,
17+
ffi::Span<const int64_t> callback_server_pid_words,
18+
uint64_t callback_server_pid_size, ffi::RemainingRets rets) {
19+
// Collect all input tensors into lightweight payload views.
20+
std::vector<exla::callback_bridge::Arg> inputs;
21+
inputs.reserve(args.size());
22+
23+
for (size_t i = 0; i < args.size(); ++i) {
24+
auto maybe_buf_or = args.get<ffi::AnyBuffer>(i);
25+
if (!maybe_buf_or) {
26+
return maybe_buf_or.error();
27+
}
28+
29+
ffi::AnyBuffer buf = *maybe_buf_or;
30+
31+
exla::callback_bridge::Arg tensor;
32+
tensor.dtype = buf.element_type();
33+
34+
auto dims = buf.dimensions();
35+
tensor.dims.assign(dims.begin(), dims.end());
36+
37+
tensor.data = reinterpret_cast<const uint8_t *>(buf.untyped_data());
38+
tensor.size_bytes = buf.size_bytes();
39+
40+
inputs.push_back(std::move(tensor));
41+
}
42+
43+
// Prepare output buffer descriptors so the callback bridge can write results
44+
// directly into the final destination buffers.
45+
std::vector<exla::callback_bridge::OutputBuffer> outputs;
46+
outputs.reserve(rets.size());
47+
48+
for (size_t i = 0; i < rets.size(); ++i) {
49+
auto maybe_ret_or = rets.get<ffi::AnyBuffer>(i);
50+
if (!maybe_ret_or) {
51+
return maybe_ret_or.error();
52+
}
53+
54+
ffi::Result<ffi::AnyBuffer> ret = *maybe_ret_or;
55+
ffi::AnyBuffer out = *ret;
56+
57+
exla::callback_bridge::OutputBuffer buf;
58+
buf.data = static_cast<uint8_t *>(out.untyped_data());
59+
buf.size = ffi::ByteWidth(out.element_type()) *
60+
static_cast<size_t>(out.element_count());
61+
62+
outputs.push_back(buf);
63+
}
64+
65+
// Call back into Elixir through the bridge. On success, the bridge writes
66+
// results directly into the provided output buffers.
67+
exla::callback_bridge::Result result =
68+
exla::callback_bridge::InvokeElixirCallback(
69+
callback_id_words, callback_id_size, callback_server_pid_words,
70+
callback_server_pid_size, inputs, outputs);
71+
72+
if (!result.ok) {
73+
return ffi::Error(ffi::ErrorCode::kInternal, result.error);
74+
}
75+
76+
return ffi::Error::Success();
77+
}
78+
79+
} // namespace
80+
81+
XLA_FFI_DEFINE_HANDLER_SYMBOL(
82+
exla_elixir_callback, exla_elixir_callback_impl,
83+
ffi::Ffi::Bind()
84+
.RemainingArgs()
85+
.Attr<ffi::Span<const int64_t>>("callback_id")
86+
.Attr<uint64_t>("callback_id_size")
87+
.Attr<ffi::Span<const int64_t>>("callback_server_pid")
88+
.Attr<uint64_t>("callback_server_pid_size")
89+
.RemainingRets());
90+
91+
XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "exla_elixir_callback", "Host",
92+
exla_elixir_callback);
93+
94+
Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
#include "elixir_callback_bridge.h"
2+
3+
#include <cstring>
4+
5+
namespace exla {
6+
7+
namespace callback_bridge {
8+
9+
struct BridgeState {
10+
ErlNifPid dispatcher_pid;
11+
bool dispatcher_set = false;
12+
};
13+
14+
BridgeState *GetBridgeState() {
15+
static BridgeState *state = new BridgeState();
16+
return state;
17+
}
18+
19+
fine::Ok<> start_elixir_callback_bridge(ErlNifEnv *env,
20+
ErlNifPid dispatcher_pid) {
21+
(void)env;
22+
auto state = GetBridgeState();
23+
state->dispatcher_pid = dispatcher_pid;
24+
state->dispatcher_set = true;
25+
return fine::Ok();
26+
}
27+
28+
fine::Ok<> elixir_callback_reply(ErlNifEnv *env,
29+
fine::ResourcePtr<Pending> pending,
30+
fine::Atom status, fine::Term result) {
31+
deliver_reply(env, pending, status, result);
32+
return fine::Ok();
33+
}
34+
35+
fine::Ok<> clear_elixir_callback_bridge(ErlNifEnv *env,
36+
ErlNifPid dispatcher_pid) {
37+
(void)env;
38+
auto state = GetBridgeState();
39+
40+
if (state->dispatcher_set &&
41+
std::memcmp(&state->dispatcher_pid, &dispatcher_pid, sizeof(ErlNifPid)) ==
42+
0) {
43+
state->dispatcher_set = false;
44+
}
45+
46+
return fine::Ok();
47+
}
48+
49+
void deliver_reply(ErlNifEnv *env, fine::ResourcePtr<Pending> pending,
50+
fine::Atom status, fine::Term result_term) {
51+
Result cb_result;
52+
53+
if (status == "ok") {
54+
// Successful reply: result_term is a list of binaries that we decode into
55+
// raw byte vectors via Fine and copy directly into the registered output
56+
// buffers.
57+
try {
58+
auto payloads = fine::decode<std::vector<ErlNifBinary>>(env, result_term);
59+
60+
std::lock_guard<std::mutex> lock(pending->mu);
61+
62+
if (payloads.size() != pending->outputs.size()) {
63+
cb_result.ok = false;
64+
cb_result.error =
65+
"mismatched number of callback outputs vs registered buffers";
66+
} else {
67+
cb_result.ok = true;
68+
69+
for (size_t i = 0; i < payloads.size(); ++i) {
70+
const ErlNifBinary &bytes = payloads[i];
71+
auto &out_buf = pending->outputs[i];
72+
73+
if (bytes.size != out_buf.size) {
74+
cb_result.ok = false;
75+
cb_result.error =
76+
"callback returned binary of unexpected size for result buffer";
77+
break;
78+
}
79+
80+
if (out_buf.size > 0) {
81+
std::memcpy(out_buf.data, bytes.data, out_buf.size);
82+
}
83+
}
84+
}
85+
} catch (const std::exception &e) {
86+
cb_result.ok = false;
87+
cb_result.error =
88+
std::string("failed to decode Elixir callback outputs: ") + e.what();
89+
}
90+
} else {
91+
// Error reply: result_term is expected to be {kind_atom, message :: binary}
92+
cb_result.ok = false;
93+
94+
try {
95+
auto decoded =
96+
fine::decode<std::tuple<fine::Atom, ErlNifBinary>>(env, result_term);
97+
fine::Atom kind = std::get<0>(decoded);
98+
ErlNifBinary msg_bin = std::get<1>(decoded);
99+
100+
cb_result.error =
101+
"elixir callback returned " + kind.to_string() + ": " +
102+
std::string(reinterpret_cast<const char *>(msg_bin.data),
103+
msg_bin.size);
104+
} catch (const std::exception &) {
105+
cb_result.error = "elixir callback returned error";
106+
}
107+
}
108+
109+
{
110+
std::lock_guard<std::mutex> lock(pending->mu);
111+
pending->result = std::move(cb_result);
112+
pending->done = true;
113+
}
114+
115+
pending->cv.notify_one();
116+
}
117+
118+
Result InvokeElixirCallback(
119+
xla::ffi::Span<const int64_t> callback_id_words, uint64_t callback_id_size,
120+
xla::ffi::Span<const int64_t> callback_server_pid_words,
121+
uint64_t callback_server_pid_size, const std::vector<Arg> &inputs,
122+
const std::vector<OutputBuffer> &outputs) {
123+
auto state = GetBridgeState();
124+
125+
if (!state->dispatcher_set) {
126+
Result res;
127+
res.ok = false;
128+
res.error = "EXLA elixir callback dispatcher is not set";
129+
return res;
130+
}
131+
132+
auto pending = fine::make_resource<Pending>(outputs);
133+
134+
ErlNifEnv *msg_env = enif_alloc_env();
135+
136+
// Reinterpret the 64-bit words as a contiguous byte buffer and use the
137+
// original (unpadded) sizes when decoding the callback id and callback
138+
// server pid terms.
139+
if (callback_id_size > callback_id_words.size() * sizeof(int64_t)) {
140+
Result res;
141+
res.ok = false;
142+
res.error = "inconsistent callback id size";
143+
return res;
144+
}
145+
146+
if (callback_server_pid_size >
147+
callback_server_pid_words.size() * sizeof(int64_t)) {
148+
Result res;
149+
res.ok = false;
150+
res.error = "inconsistent callback server pid size";
151+
return res;
152+
}
153+
154+
const unsigned char *id_bytes =
155+
reinterpret_cast<const unsigned char *>(callback_id_words.begin());
156+
157+
ERL_NIF_TERM callback_id_term;
158+
if (!enif_binary_to_term(msg_env, id_bytes, callback_id_size,
159+
&callback_id_term, 0)) {
160+
Result res;
161+
res.ok = false;
162+
res.error = "failed to decode callback id term";
163+
return res;
164+
}
165+
166+
const unsigned char *pid_bytes = reinterpret_cast<const unsigned char *>(
167+
callback_server_pid_words.begin());
168+
169+
ERL_NIF_TERM callback_server_pid_term;
170+
if (!enif_binary_to_term(msg_env, pid_bytes, callback_server_pid_size,
171+
&callback_server_pid_term, 0)) {
172+
Result res;
173+
res.ok = false;
174+
res.error = "failed to decode callback server pid term";
175+
return res;
176+
}
177+
178+
ErlNifPid callback_server_pid;
179+
if (!enif_get_local_pid(msg_env, callback_server_pid_term,
180+
&callback_server_pid)) {
181+
Result res;
182+
res.ok = false;
183+
res.error = "failed to decode callback server pid";
184+
return res;
185+
}
186+
187+
// Encode arguments as [{bin, %EXLA.Typespec{}}, ...]. We currently send
188+
// plain binaries because the BEAM callback needs to own the data lifetime.
189+
std::vector<std::tuple<fine::Term,
190+
std::tuple<xla::ffi::DataType, std::vector<int64_t>>>>
191+
args_terms;
192+
args_terms.reserve(inputs.size());
193+
194+
for (const auto &tensor : inputs) {
195+
fine::Term bin_term = fine::make_new_binary(
196+
msg_env, reinterpret_cast<const char *>(tensor.data),
197+
tensor.size_bytes);
198+
199+
// Build an %EXLA.Typespec{} directly from the ffi::DataType and dims via
200+
// Fine's encoder defined in exla_nif_util.h.
201+
auto arg_tuple =
202+
std::make_tuple(bin_term, std::make_tuple(tensor.dtype, tensor.dims));
203+
204+
args_terms.push_back(arg_tuple);
205+
}
206+
207+
auto msg = std::make_tuple(fine::Atom("exla_elixir_call"),
208+
fine::Term(callback_id_term), args_terms, pending);
209+
210+
// Use the dispatcher pid registered via start_elixir_callback_bridge/1.
211+
// We still are within the NIF thread that started the computation,
212+
// but we don't know its env, therefore we cannot use enif_whereis_pid.
213+
// enif_whereis_pid can be called with NULL, but only from non-ERTS
214+
// threads, and doing so here results in a segfault.
215+
enif_send(msg_env, &callback_server_pid, msg_env, fine::encode(msg_env, msg));
216+
enif_free_env(msg_env);
217+
218+
std::unique_lock<std::mutex> lock(pending->mu);
219+
pending->cv.wait(lock, [&pending] { return pending->done; });
220+
221+
return pending->result;
222+
}
223+
224+
} // namespace callback_bridge
225+
226+
} // namespace exla
227+
228+

0 commit comments

Comments
 (0)