|
11 | 11 | #include "stablehlo/dialect/StablehloOps.h"
|
12 | 12 | #include "xla/pjrt/pjrt_api.h"
|
13 | 13 | #include "xla/service/platform_util.h"
|
| 14 | +#include "llvm/Support/ThreadPool.h" |
14 | 15 |
|
15 | 16 | // All of these are created with calls to `new` and subsequently
|
16 | 17 | // passed to the VM as pointers-to-pointers so we balance it out
|
@@ -69,6 +70,10 @@ static int open_resources(ErlNifEnv* env) {
|
69 | 70 | if (!exla::nif::open_resource<mlir::MLIRContext*>(env, mod, "MLIRContext")) {
|
70 | 71 | return -1;
|
71 | 72 | }
|
| 73 | + |
| 74 | + if (!exla::nif::open_resource<llvm::StdThreadPool*>(env, mod, "TheadPool")) { |
| 75 | + return -1; |
| 76 | + } |
72 | 77 | return 1;
|
73 | 78 | }
|
74 | 79 |
|
@@ -150,12 +155,40 @@ ERL_NIF_TERM mlir_compile(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
|
150 | 155 | return exla::nif::ok(env, exla::nif::make<exla::ExlaExecutable*>(env, executable));
|
151 | 156 | }
|
152 | 157 |
|
| 158 | + |
| 159 | +ERL_NIF_TERM mlir_new_thread_pool(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { |
| 160 | + if (argc != 1) { |
| 161 | + return exla::nif::error(env, "Bad argument count."); |
| 162 | + } |
| 163 | + |
| 164 | + int concurrency; |
| 165 | + |
| 166 | + if (!exla::nif::get(env, argv[0], &concurrency)) { |
| 167 | + return exla::nif::error(env, "Unable to get concurrency."); |
| 168 | + } |
| 169 | + |
| 170 | + llvm::ThreadPoolStrategy strategy = llvm::hardware_concurrency(concurrency); |
| 171 | + llvm::StdThreadPool* pool = new llvm::StdThreadPool(strategy); |
| 172 | + |
| 173 | + auto ret = exla::nif::make<llvm::StdThreadPool*>(env, pool); |
| 174 | + return exla::nif::ok(env, ret); |
| 175 | +} |
| 176 | + |
153 | 177 | ERL_NIF_TERM mlir_new_context(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
|
154 |
| - if (argc != 0) { |
| 178 | + if (argc != 1) { |
155 | 179 | return exla::nif::error(env, "Bad argument count.");
|
156 | 180 | }
|
157 | 181 |
|
158 |
| - mlir::MLIRContext* context = new mlir::MLIRContext(); |
| 182 | + llvm::StdThreadPool** thread_pool; |
| 183 | + |
| 184 | + if (!exla::nif::get<llvm::StdThreadPool*>(env, argv[0], thread_pool)) { |
| 185 | + return exla::nif::error(env, "Unable to get thread pool."); |
| 186 | + } |
| 187 | + |
| 188 | + mlir::MLIRContext* context = new mlir::MLIRContext(mlir::MLIRContext::Threading::DISABLED); |
| 189 | + |
| 190 | + auto interface_ptr = reinterpret_cast<llvm::ThreadPoolInterface*>(*thread_pool); |
| 191 | + context->setThreadPool(*interface_ptr); |
159 | 192 | context->getOrLoadDialect<mlir::func::FuncDialect>();
|
160 | 193 | context->getOrLoadDialect<mlir::stablehlo::StablehloDialect>();
|
161 | 194 | context->getOrLoadDialect<mlir::mhlo::MhloDialect>();
|
@@ -909,7 +942,8 @@ ERL_NIF_TERM start_log_sink(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[])
|
909 | 942 |
|
910 | 943 | static ErlNifFunc exla_funcs[] = {
|
911 | 944 | // MLIR Builder
|
912 |
| - {"mlir_new_context", 0, mlir_new_context}, |
| 945 | + {"mlir_new_thread_pool", 1, mlir_new_thread_pool}, |
| 946 | + {"mlir_new_context", 1, mlir_new_context}, |
913 | 947 | {"mlir_new_module", 1, mlir_new_module},
|
914 | 948 | {"mlir_create_function", 5, mlir_create_function},
|
915 | 949 | {"mlir_get_function_arguments", 1, mlir_get_function_arguments},
|
|
0 commit comments