Skip to content

Add NIF for loading custom plugins #1519

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 80 additions & 2 deletions exla/c_src/exla/exla.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <string>
#include <dlfcn.h>

#include "exla_client.h"
#include "exla_cuda.h"
Expand All @@ -11,11 +12,36 @@
#include "stablehlo/dialect/StablehloOps.h"
#include "xla/pjrt/pjrt_api.h"
#include "xla/service/platform_util.h"
#include "xla/service/custom_call_target_registry.h"

// All of these are created with calls to `new` and subsequently
// passed to the VM as pointers-to-pointers so we balance it out
// with calls to delete rather than just using the default destructor.

// We need to hold a reference to the `dlopen` handle for as long
// as EXLA is running, so we have this resource which holds the handle,
// then we define a custom free which calls `dlclose`. Then it's up to
// the caller to keep this resource in scope so it's not garbage collected
// Payload of the "ExlaPlugin" NIF resource: the handle returned by
// `dlopen` for a loaded plugin library. `free_exla_plugin` hands it to
// `dlclose` when the resource is destroyed.
typedef struct {
void * handle;
} ExlaPlugin;

// Signature every plugin custom call must have: an array of output
// pointers followed by an array of read-only input pointers.
typedef void (*ExlaCustomCallFunction)(void *out[], const void *in[]);

// One entry of the `exla_custom_calls` table a plugin library exports:
// the name the call is registered under and the function implementing it.
// The loader walks the table until it reaches an entry whose `name` is NULL.
typedef struct {
const char* name;
ExlaCustomCallFunction func;
} ExlaPluginCustomCall;

static ErlNifResourceType* exla_plugin_resource_type;

// Destructor callback for the "ExlaPlugin" resource type: closes the
// `dlopen` handle stored in the resource. A null resource pointer is
// ignored.
void free_exla_plugin(ErlNifEnv* env, void* obj) {
  auto* loaded_plugin = static_cast<ExlaPlugin*>(obj);
  if (loaded_plugin == nullptr) {
    return;
  }

  dlclose(loaded_plugin->handle);
}

void free_exla_executable(ErlNifEnv* env, void* obj) {
exla::ExlaExecutable** executable = reinterpret_cast<exla::ExlaExecutable**>(obj);
if (*executable != nullptr) {
Expand Down Expand Up @@ -65,10 +91,17 @@ static int open_resources(ErlNifEnv* env) {
if (!exla::nif::open_resource<exla::MLIRModule*>(env, mod, "ExlaMLIRModule")) {
return -1;
}

if (!exla::nif::open_resource<mlir::MLIRContext*>(env, mod, "MLIRContext")) {
return -1;
}

// Just a C Resource
ErlNifResourceFlags flags = ErlNifResourceFlags(ERL_NIF_RT_CREATE | ERL_NIF_RT_TAKEOVER);
exla_plugin_resource_type = enif_open_resource_type(env, mod, "ExlaPlugin", free_exla_plugin, flags, NULL);
if (!exla_plugin_resource_type) {
return -1;
}

return 1;
}

Expand Down Expand Up @@ -911,6 +944,48 @@ ERL_NIF_TERM start_log_sink(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[])
return exla::nif::ok(env);
}

// Plugins

// NIF: loads the shared library at the given path and registers every
// custom call listed in its exported `exla_custom_calls` table.
//
// Arguments (argv):
//   argv[0] - path of the shared library to open.
//
// Returns `{:ok, resource}` where `resource` is an "ExlaPlugin" resource
// owning the `dlopen` handle, or an error term when the argument is
// invalid, the library cannot be opened, the `exla_custom_calls` symbol is
// missing, or the resource cannot be allocated. The library is closed via
// `free_exla_plugin` once the resource is destroyed, so the caller has to
// keep the returned term alive for as long as the custom calls are in use.
ERL_NIF_TERM load_custom_call_plugin_library(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
  if (argc != 1) {
    return exla::nif::error(env, "Bad argument count.");
  }

  std::string library_path;

  if (!exla::nif::get(env, argv[0], library_path)) {
    return exla::nif::error(env, "Unable to get library path.");
  }

  void* handle = dlopen(library_path.c_str(), RTLD_NOW);
  if (!handle) {
    return exla::nif::error(env, "Unable to open library.");
  }

  const ExlaPluginCustomCall* custom_calls =
      reinterpret_cast<const ExlaPluginCustomCall*>(dlsym(handle, "exla_custom_calls"));

  if (!custom_calls) {
    dlclose(handle);
    return exla::nif::error(env, "Unable to find exla_custom_calls");
  }

  // Allocate the owning resource before registering anything, so that an
  // allocation failure cannot leave the registry pointing into a library
  // we are about to close.
  ExlaPlugin* plugin = static_cast<ExlaPlugin*>(
      enif_alloc_resource(exla_plugin_resource_type, sizeof(ExlaPlugin)));
  if (plugin == nullptr) {
    dlclose(handle);
    return exla::nif::error(env, "Unable to allocate plugin resource.");
  }
  plugin->handle = handle;

  // Register each entry explicitly. The XLA_*_REGISTER_CUSTOM_CALL_TARGET*
  // macros declare a `static` registrar object, which inside a loop is
  // constructed only on the first iteration, so they would register just
  // the first entry of the table.
  for (const ExlaPluginCustomCall* call = custom_calls; call->name != NULL; ++call) {
    // TODO: GPU flags
    xla::CustomCallTargetRegistry::Global()->Register(
        call->name, reinterpret_cast<void*>(call->func), "Host");
  }

  // The term now holds the only reference to the resource.
  ERL_NIF_TERM result = enif_make_resource(env, plugin);
  enif_release_resource(plugin);

  return exla::nif::ok(env, result);
}

static ErlNifFunc exla_funcs[] = {
// MLIR Builder
{"mlir_new_context", 0, mlir_new_context},
Expand Down Expand Up @@ -947,6 +1022,9 @@ static ErlNifFunc exla_funcs[] = {
{"start_log_sink", 1, start_log_sink},
// Serialization
{"serialize_executable", 1, serialize_executable},
{"deserialize_executable", 2, deserialize_executable}};
{"deserialize_executable", 2, deserialize_executable},
// Plugins
{"load_custom_call_plugin_library", 1, load_custom_call_plugin_library}
};

ERL_NIF_INIT(Elixir.EXLA.NIF, exla_funcs, &load, NULL, &upgrade, NULL);
21 changes: 21 additions & 0 deletions exla/lib/exla/mlir/value.ex
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,27 @@ defmodule EXLA.MLIR.Value do
{q, r}
end

# Emits a `stablehlo.custom_call` targeting the custom call registered as
# `registered_name`.
#
# All `args` must belong to the same function as the first one. Each operand
# is passed followed by an s64 vector constant holding its dimension sizes,
# so the operand list is `[arg1, dims1, arg2, dims2, ...]`.
# `result_typespec` is forwarded to `op/5` as the result type.
def plugin_custom_call(registered_name, [%Value{function: func} | _] = args, result_typespec) do
  operands =
    Enum.flat_map(args, fn %Value{function: ^func} = value ->
      %{shape: op_shape} = get_typespec(value)

      # The shape is a tuple, so its rank is tuple_size/1; length/1 only
      # accepts lists and raises on a tuple.
      dims_typespec = Typespec.tensor({:s, 64}, {tuple_size(op_shape)})

      [value, constant(func, Tuple.to_list(op_shape), dims_typespec)]
    end)

  # TODO: GPU
  attributes = [
    call_target_name: attr_string(registered_name),
    backend_config: attr_string("Host")
  ]

  op(func, "stablehlo.custom_call", operands, result_typespec, attributes: attributes)
end

def get_tuple_element(%Value{function: func} = operand, index, typespec) do
result_types = typespecs_to_mlir_types([typespec])
attributes = [index: attr_i32(index)]
Expand Down
2 changes: 2 additions & 0 deletions exla/lib/exla/nif.ex
Original file line number Diff line number Diff line change
Expand Up @@ -112,4 +112,6 @@ defmodule EXLA.NIF do
def get_c_api_client(_device_type), do: :erlang.nif_error(:undef)

def load_pjrt_plugin(_device_type, _library_path), do: :erlang.nif_error(:undef)

def load_custom_call_plugin_library(_library_path), do: :erlang.nif_error(:undef)
end
29 changes: 29 additions & 0 deletions exla/lib/exla/plugin.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
defmodule EXLA.Plugin do
@moduledoc """
Plugin system for registering custom calls.
"""

# Loads the plugin library at `library_path` and registers its custom calls.
#
# Raises ArgumentError when the file does not exist, and raises with the
# NIF's error reason when the library cannot be loaded. A path that has
# already been registered is not loaded again.
def register(library_path) do
unless File.exists?(library_path) do
raise ArgumentError, "#{library_path} does not exist"
end

# The persistent term keyed by this module and the path doubles as the
# "already loaded" marker.
case :persistent_term.get({__MODULE__, library_path}, nil) do
nil ->
ref =
library_path
|> EXLA.NIF.load_custom_call_plugin_library()
|> unwrap!()

# we need to keep the ref from getting garbage collected so
# we can use the symbols within it at anytime
:persistent_term.put({__MODULE__, library_path}, ref)

_ref ->
:ok
end
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we use a process instead such that, if someone does Application.stop(:exla) the process is shutdown as well as all plugins? 🤔

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Regarding how to store things, I would rather do an ETS than a process for storing the custom call registry if persistent term is not what we want, to keep close to the same read concurrency.

Another possible alternative would be a GenServer that manages the persistent term: on terminate, it would clean up the persistent term state.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this manager alternative, we'd have non-process functions that read first and write if needed to the persistent term, and the only purpose for the GenServer would be to ensure cleanup upon termination.

PS: Upon writing this I went reading and found https://hexdocs.pm/elixir/1.12/Application.html#c:prep_stop/1 which would serve this purpose nicely.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prep_stop also works, but you would need to iterate over all persistent terms to find the keys relevant to us. ETS would be better.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there's something like EXLA.CustomCalls.cleanup/0 we could call, then that function will already know about all of the keys that should be cleaned up.

I see no issue with using ETS however, as the speed difference here only matters at defn compile time and not runtime

end

# Unwraps a tagged NIF result: returns the value of an `{:ok, value}` tuple
# and raises a RuntimeError built from the reason of an `{:error, reason}`
# tuple.
defp unwrap!({:ok, ref}), do: ref
defp unwrap!({:error, reason}), do: raise("#{reason}")
end
21 changes: 21 additions & 0 deletions exla/test/exla/plugin_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
defmodule EXLA.PluginTest do
  use ExUnit.Case

  # Relative paths, resolved against the current working directory.
  @missing_library "test/support/c/doesnotexist.so"
  @plugin_library "test/support/c/libcustom_plugin.so"

  describe "register/1" do
    test "raises if file does not exist" do
      register_missing = fn -> EXLA.Plugin.register(@missing_library) end

      assert_raise ArgumentError, ~r/does not exist/, register_missing
    end

    test "does not crash on invalid files" do
      # This source file exists but is not a loadable shared library.
      not_a_library = __ENV__.file
      register_invalid = fn -> EXLA.Plugin.register(not_a_library) end

      assert_raise RuntimeError, ~r/Unable to open/, register_invalid
    end

    test "registers a plugin" do
      assert :ok = EXLA.Plugin.register(@plugin_library)
    end
  end
end
27 changes: 27 additions & 0 deletions exla/test/support/c/custom_plugin.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#include <cstdint>
#include <stddef.h>

typedef void (*ExlaCustomCallFunction)(void *out[], const void *in[]);

typedef struct {
const char* name;
ExlaCustomCallFunction func;
} ExlaPluginCustomCall;

// Test custom call: writes `operand[i] + 1` into the output buffer.
//
//   in[0]  - int64 operand buffer.
//   in[1]  - int64 buffer of the operand's dimension sizes; only the first
//            entry is read and used as the element count.
//   out[0] - int64 output buffer, written for that many elements.
void custom_increment(void *out[], const void *in[]) {
  const int64_t *source = static_cast<const int64_t *>(in[0]);
  const int64_t element_count = static_cast<const int64_t *>(in[1])[0];
  int64_t *destination = static_cast<int64_t *>(out[0]);

  for (int64_t remaining = element_count; remaining > 0; --remaining) {
    *destination++ = *source++ + 1;
  }
}

// Table of custom calls exported by this plugin. The EXLA loader looks the
// `exla_custom_calls` symbol up by name with `dlsym`, hence the C linkage.
// The {NULL, NULL} entry terminates the table.
extern "C" ExlaPluginCustomCall exla_custom_calls[] = {
{"custom_increment", custom_increment},
{NULL, NULL}
};
Binary file added exla/test/support/c/libcustom_plugin.so
Binary file not shown.
Loading