Skip to content

Commit 116b124

Browse files
authored
fix: use shared mlir context thread pool (#1534)
1 parent 8a9c2b3 commit 116b124

File tree

4 files changed

+52
-8
lines changed

4 files changed

+52
-8
lines changed

exla/c_src/exla/exla.cc

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "stablehlo/dialect/StablehloOps.h"
1212
#include "xla/pjrt/pjrt_api.h"
1313
#include "xla/service/platform_util.h"
14+
#include "llvm/Support/ThreadPool.h"
1415

1516
// All of these are created with calls to `new` and subsequently
1617
// passed to the VM as pointers-to-pointers so we balance it out
@@ -69,6 +70,10 @@ static int open_resources(ErlNifEnv* env) {
6970
if (!exla::nif::open_resource<mlir::MLIRContext*>(env, mod, "MLIRContext")) {
7071
return -1;
7172
}
73+
74+
if (!exla::nif::open_resource<llvm::StdThreadPool*>(env, mod, "TheadPool")) {
75+
return -1;
76+
}
7277
return 1;
7378
}
7479

@@ -150,12 +155,40 @@ ERL_NIF_TERM mlir_compile(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
150155
return exla::nif::ok(env, exla::nif::make<exla::ExlaExecutable*>(env, executable));
151156
}
152157

158+
159+
ERL_NIF_TERM mlir_new_thread_pool(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
160+
if (argc != 1) {
161+
return exla::nif::error(env, "Bad argument count.");
162+
}
163+
164+
int concurrency;
165+
166+
if (!exla::nif::get(env, argv[0], &concurrency)) {
167+
return exla::nif::error(env, "Unable to get concurrency.");
168+
}
169+
170+
llvm::ThreadPoolStrategy strategy = llvm::hardware_concurrency(concurrency);
171+
llvm::StdThreadPool* pool = new llvm::StdThreadPool(strategy);
172+
173+
auto ret = exla::nif::make<llvm::StdThreadPool*>(env, pool);
174+
return exla::nif::ok(env, ret);
175+
}
176+
153177
ERL_NIF_TERM mlir_new_context(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
154-
if (argc != 0) {
178+
if (argc != 1) {
155179
return exla::nif::error(env, "Bad argument count.");
156180
}
157181

158-
mlir::MLIRContext* context = new mlir::MLIRContext();
182+
llvm::StdThreadPool** thread_pool;
183+
184+
if (!exla::nif::get<llvm::StdThreadPool*>(env, argv[0], thread_pool)) {
185+
return exla::nif::error(env, "Unable to get thread pool.");
186+
}
187+
188+
mlir::MLIRContext* context = new mlir::MLIRContext(mlir::MLIRContext::Threading::DISABLED);
189+
190+
auto interface_ptr = reinterpret_cast<llvm::ThreadPoolInterface*>(*thread_pool);
191+
context->setThreadPool(*interface_ptr);
159192
context->getOrLoadDialect<mlir::func::FuncDialect>();
160193
context->getOrLoadDialect<mlir::stablehlo::StablehloDialect>();
161194
context->getOrLoadDialect<mlir::mhlo::MhloDialect>();
@@ -909,7 +942,8 @@ ERL_NIF_TERM start_log_sink(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[])
909942

910943
static ErlNifFunc exla_funcs[] = {
911944
// MLIR Builder
912-
{"mlir_new_context", 0, mlir_new_context},
945+
{"mlir_new_thread_pool", 1, mlir_new_thread_pool},
946+
{"mlir_new_context", 1, mlir_new_context},
913947
{"mlir_new_module", 1, mlir_new_module},
914948
{"mlir_create_function", 5, mlir_create_function},
915949
{"mlir_get_function_arguments", 1, mlir_get_function_arguments},

exla/lib/exla/application.ex

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@ defmodule EXLA.Application do
1010
_ -> :os.set_signal(:sigchld, :default)
1111
end
1212

13+
pool_size = System.schedulers_online()
14+
1315
children = [
1416
EXLA.Logger,
1517
{NimblePool,
16-
worker: {EXLA.MLIR.ContextPool, :pool_state},
17-
pool_size: System.schedulers_online(),
18+
worker: {EXLA.MLIR.ContextPool, %{pool_size: pool_size}},
19+
pool_size: pool_size,
1820
name: EXLA.MLIR.ContextPool,
1921
lazy: true},
2022
EXLA.Client,

exla/lib/exla/mlir/context_pool.ex

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,15 @@ defmodule EXLA.MLIR.ContextPool do
1313
end
1414

1515
@impl NimblePool
16-
def init_worker(pool_state) do
17-
{:ok, context} = EXLA.NIF.mlir_new_context()
16+
def init_pool(%{pool_size: pool_size}) do
17+
{:ok, thread_pool} = EXLA.NIF.mlir_new_thread_pool(pool_size)
18+
19+
{:ok, %{thread_pool: thread_pool}}
20+
end
21+
22+
@impl NimblePool
23+
def init_worker(%{thread_pool: thread_pool} = pool_state) do
24+
{:ok, context} = EXLA.NIF.mlir_new_context(thread_pool)
1825
{:ok, context, pool_state}
1926
end
2027

exla/lib/exla/nif.ex

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ defmodule EXLA.NIF do
77
:erlang.load_nif(path, 0)
88
end
99

10-
def mlir_new_context, do: :erlang.nif_error(:undef)
10+
def mlir_new_thread_pool(_concurrency), do: :erlang.nif_error(:undef)
11+
def mlir_new_context(_thread_pool_ref), do: :erlang.nif_error(:undef)
1112

1213
def mlir_new_module(_context), do: :erlang.nif_error(:undef)
1314

0 commit comments

Comments
 (0)