|
| 1 | +#include "elixir_callback_bridge.h" |
| 2 | + |
| 3 | +#include <cstring> |
| 4 | + |
| 5 | +namespace exla { |
| 6 | + |
| 7 | +namespace callback_bridge { |
| 8 | + |
| 9 | +struct BridgeState { |
| 10 | + ErlNifPid dispatcher_pid; |
| 11 | + bool dispatcher_set = false; |
| 12 | +}; |
| 13 | + |
| 14 | +BridgeState *GetBridgeState() { |
| 15 | + static BridgeState *state = new BridgeState(); |
| 16 | + return state; |
| 17 | +} |
| 18 | + |
| 19 | +fine::Ok<> start_elixir_callback_bridge(ErlNifEnv *env, |
| 20 | + ErlNifPid dispatcher_pid) { |
| 21 | + (void)env; |
| 22 | + auto state = GetBridgeState(); |
| 23 | + state->dispatcher_pid = dispatcher_pid; |
| 24 | + state->dispatcher_set = true; |
| 25 | + return fine::Ok(); |
| 26 | +} |
| 27 | + |
| 28 | +fine::Ok<> elixir_callback_reply(ErlNifEnv *env, |
| 29 | + fine::ResourcePtr<Pending> pending, |
| 30 | + fine::Atom status, fine::Term result) { |
| 31 | + deliver_reply(env, pending, status, result); |
| 32 | + return fine::Ok(); |
| 33 | +} |
| 34 | + |
| 35 | +fine::Ok<> clear_elixir_callback_bridge(ErlNifEnv *env, |
| 36 | + ErlNifPid dispatcher_pid) { |
| 37 | + (void)env; |
| 38 | + auto state = GetBridgeState(); |
| 39 | + |
| 40 | + if (state->dispatcher_set && |
| 41 | + std::memcmp(&state->dispatcher_pid, &dispatcher_pid, sizeof(ErlNifPid)) == |
| 42 | + 0) { |
| 43 | + state->dispatcher_set = false; |
| 44 | + } |
| 45 | + |
| 46 | + return fine::Ok(); |
| 47 | +} |
| 48 | + |
| 49 | +void deliver_reply(ErlNifEnv *env, fine::ResourcePtr<Pending> pending, |
| 50 | + fine::Atom status, fine::Term result_term) { |
| 51 | + Result cb_result; |
| 52 | + |
| 53 | + if (status == "ok") { |
| 54 | + // Successful reply: result_term is a list of binaries that we decode into |
| 55 | + // raw byte vectors via Fine and copy directly into the registered output |
| 56 | + // buffers. |
| 57 | + try { |
| 58 | + auto payloads = fine::decode<std::vector<ErlNifBinary>>(env, result_term); |
| 59 | + |
| 60 | + std::lock_guard<std::mutex> lock(pending->mu); |
| 61 | + |
| 62 | + if (payloads.size() != pending->outputs.size()) { |
| 63 | + cb_result.ok = false; |
| 64 | + cb_result.error = |
| 65 | + "mismatched number of callback outputs vs registered buffers"; |
| 66 | + } else { |
| 67 | + cb_result.ok = true; |
| 68 | + |
| 69 | + for (size_t i = 0; i < payloads.size(); ++i) { |
| 70 | + const ErlNifBinary &bytes = payloads[i]; |
| 71 | + auto &out_buf = pending->outputs[i]; |
| 72 | + |
| 73 | + if (bytes.size != out_buf.size) { |
| 74 | + cb_result.ok = false; |
| 75 | + cb_result.error = |
| 76 | + "callback returned binary of unexpected size for result buffer"; |
| 77 | + break; |
| 78 | + } |
| 79 | + |
| 80 | + if (out_buf.size > 0) { |
| 81 | + std::memcpy(out_buf.data, bytes.data, out_buf.size); |
| 82 | + } |
| 83 | + } |
| 84 | + } |
| 85 | + } catch (const std::exception &e) { |
| 86 | + cb_result.ok = false; |
| 87 | + cb_result.error = |
| 88 | + std::string("failed to decode Elixir callback outputs: ") + e.what(); |
| 89 | + } |
| 90 | + } else { |
| 91 | + // Error reply: result_term is expected to be {kind_atom, message :: binary} |
| 92 | + cb_result.ok = false; |
| 93 | + |
| 94 | + try { |
| 95 | + auto decoded = |
| 96 | + fine::decode<std::tuple<fine::Atom, ErlNifBinary>>(env, result_term); |
| 97 | + fine::Atom kind = std::get<0>(decoded); |
| 98 | + ErlNifBinary msg_bin = std::get<1>(decoded); |
| 99 | + |
| 100 | + cb_result.error = |
| 101 | + "elixir callback returned " + kind.to_string() + ": " + |
| 102 | + std::string(reinterpret_cast<const char *>(msg_bin.data), |
| 103 | + msg_bin.size); |
| 104 | + } catch (const std::exception &) { |
| 105 | + cb_result.error = "elixir callback returned error"; |
| 106 | + } |
| 107 | + } |
| 108 | + |
| 109 | + { |
| 110 | + std::lock_guard<std::mutex> lock(pending->mu); |
| 111 | + pending->result = std::move(cb_result); |
| 112 | + pending->done = true; |
| 113 | + } |
| 114 | + |
| 115 | + pending->cv.notify_one(); |
| 116 | +} |
| 117 | + |
| 118 | +Result InvokeElixirCallback( |
| 119 | + xla::ffi::Span<const int64_t> callback_id_words, uint64_t callback_id_size, |
| 120 | + xla::ffi::Span<const int64_t> callback_server_pid_words, |
| 121 | + uint64_t callback_server_pid_size, const std::vector<Arg> &inputs, |
| 122 | + const std::vector<OutputBuffer> &outputs) { |
| 123 | + auto state = GetBridgeState(); |
| 124 | + |
| 125 | + if (!state->dispatcher_set) { |
| 126 | + Result res; |
| 127 | + res.ok = false; |
| 128 | + res.error = "EXLA elixir callback dispatcher is not set"; |
| 129 | + return res; |
| 130 | + } |
| 131 | + |
| 132 | + auto pending = fine::make_resource<Pending>(outputs); |
| 133 | + |
| 134 | + ErlNifEnv *msg_env = enif_alloc_env(); |
| 135 | + |
| 136 | + // Reinterpret the 64-bit words as a contiguous byte buffer and use the |
| 137 | + // original (unpadded) sizes when decoding the callback id and callback |
| 138 | + // server pid terms. |
| 139 | + if (callback_id_size > callback_id_words.size() * sizeof(int64_t)) { |
| 140 | + Result res; |
| 141 | + res.ok = false; |
| 142 | + res.error = "inconsistent callback id size"; |
| 143 | + return res; |
| 144 | + } |
| 145 | + |
| 146 | + if (callback_server_pid_size > |
| 147 | + callback_server_pid_words.size() * sizeof(int64_t)) { |
| 148 | + Result res; |
| 149 | + res.ok = false; |
| 150 | + res.error = "inconsistent callback server pid size"; |
| 151 | + return res; |
| 152 | + } |
| 153 | + |
| 154 | + const unsigned char *id_bytes = |
| 155 | + reinterpret_cast<const unsigned char *>(callback_id_words.begin()); |
| 156 | + |
| 157 | + ERL_NIF_TERM callback_id_term; |
| 158 | + if (!enif_binary_to_term(msg_env, id_bytes, callback_id_size, |
| 159 | + &callback_id_term, 0)) { |
| 160 | + Result res; |
| 161 | + res.ok = false; |
| 162 | + res.error = "failed to decode callback id term"; |
| 163 | + return res; |
| 164 | + } |
| 165 | + |
| 166 | + const unsigned char *pid_bytes = reinterpret_cast<const unsigned char *>( |
| 167 | + callback_server_pid_words.begin()); |
| 168 | + |
| 169 | + ERL_NIF_TERM callback_server_pid_term; |
| 170 | + if (!enif_binary_to_term(msg_env, pid_bytes, callback_server_pid_size, |
| 171 | + &callback_server_pid_term, 0)) { |
| 172 | + Result res; |
| 173 | + res.ok = false; |
| 174 | + res.error = "failed to decode callback server pid term"; |
| 175 | + return res; |
| 176 | + } |
| 177 | + |
| 178 | + ErlNifPid callback_server_pid; |
| 179 | + if (!enif_get_local_pid(msg_env, callback_server_pid_term, |
| 180 | + &callback_server_pid)) { |
| 181 | + Result res; |
| 182 | + res.ok = false; |
| 183 | + res.error = "failed to decode callback server pid"; |
| 184 | + return res; |
| 185 | + } |
| 186 | + |
| 187 | + // Encode arguments as [{bin, %EXLA.Typespec{}}, ...]. We currently send |
| 188 | + // plain binaries because the BEAM callback needs to own the data lifetime. |
| 189 | + std::vector<std::tuple<fine::Term, |
| 190 | + std::tuple<xla::ffi::DataType, std::vector<int64_t>>>> |
| 191 | + args_terms; |
| 192 | + args_terms.reserve(inputs.size()); |
| 193 | + |
| 194 | + for (const auto &tensor : inputs) { |
| 195 | + fine::Term bin_term = fine::make_new_binary( |
| 196 | + msg_env, reinterpret_cast<const char *>(tensor.data), |
| 197 | + tensor.size_bytes); |
| 198 | + |
| 199 | + // Build an %EXLA.Typespec{} directly from the ffi::DataType and dims via |
| 200 | + // Fine's encoder defined in exla_nif_util.h. |
| 201 | + auto arg_tuple = |
| 202 | + std::make_tuple(bin_term, std::make_tuple(tensor.dtype, tensor.dims)); |
| 203 | + |
| 204 | + args_terms.push_back(arg_tuple); |
| 205 | + } |
| 206 | + |
| 207 | + auto msg = std::make_tuple(fine::Atom("exla_elixir_call"), |
| 208 | + fine::Term(callback_id_term), args_terms, pending); |
| 209 | + |
| 210 | + // Use the dispatcher pid registered via start_elixir_callback_bridge/1. |
| 211 | + // We still are within the NIF thread that started the computation, |
| 212 | + // but we don't know its env, therefore we cannot use enif_whereis_pid. |
| 213 | + // enif_whereis_pid can be called with NULL, but only from non-ERTS |
| 214 | + // threads, and doing so here results in a segfault. |
| 215 | + enif_send(msg_env, &callback_server_pid, msg_env, fine::encode(msg_env, msg)); |
| 216 | + enif_free_env(msg_env); |
| 217 | + |
| 218 | + std::unique_lock<std::mutex> lock(pending->mu); |
| 219 | + pending->cv.wait(lock, [&pending] { return pending->done; }); |
| 220 | + |
| 221 | + return pending->result; |
| 222 | +} |
| 223 | + |
| 224 | +} // namespace callback_bridge |
| 225 | + |
| 226 | +} // namespace exla |
| 227 | + |
| 228 | + |
0 commit comments