Skip to content

Commit 9ea612f

Browse files
committed
feat: JLL changes for #780
1 parent ea9c9e9 commit 9ea612f

File tree

2 files changed

+178
-61
lines changed

2 files changed

+178
-61
lines changed

deps/ReactantExtra/API.cpp

Lines changed: 166 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,6 @@
4747

4848
#include "xla/mlir/utils/type_util.h"
4949
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
50-
#include "xla/pjrt/cpu/cpu_client.h"
51-
#include "xla/pjrt/gpu/se_gpu_pjrt_client.h"
52-
#include "xla/pjrt/pjrt_api.h"
53-
#include "xla/pjrt/pjrt_c_api_client.h"
54-
#include "xla/pjrt/pjrt_executable.h"
55-
#include "xla/pjrt/status_casters.h"
5650

5751
#include "tsl/platform/init_main.h"
5852
#include "tsl/profiler/lib/profiler_session.h"
@@ -69,18 +63,29 @@
6963

7064
#include "llvm-c/TargetMachine.h"
7165

66+
// PJRT
67+
#include "xla/pjrt/cpu/cpu_client.h"
68+
#include "xla/pjrt/distributed/client.h"
69+
#include "xla/pjrt/distributed/distributed.h"
70+
#include "xla/pjrt/distributed/service.h"
71+
#include "xla/pjrt/gpu/se_gpu_pjrt_client.h"
72+
#include "xla/pjrt/pjrt_api.h"
73+
#include "xla/pjrt/pjrt_c_api_client.h"
74+
#include "xla/pjrt/pjrt_executable.h"
75+
#include "xla/pjrt/status_casters.h"
76+
7277
// shardy
7378
#include "shardy/dialect/sdy/ir/dialect.h"
7479
#include "shardy/integrations/c/attributes.h"
7580
#include "xla/pjrt/mlir_to_hlo.h"
7681

7782
// IFRT
7883
#include "xla/python/ifrt/array.h"
84+
#include "xla/python/ifrt/basic_device_list.h"
7985
#include "xla/python/ifrt/client.h"
8086
#include "xla/python/ifrt/compiler.h"
8187
#include "xla/python/ifrt/device.h"
8288
#include "xla/python/ifrt/device_list.h"
83-
#include "xla/python/ifrt/basic_device_list.h"
8489
#include "xla/python/ifrt/dtype.h"
8590
#include "xla/python/ifrt/executable.h"
8691
#include "xla/python/ifrt/hlo/hlo_program.h"
@@ -129,6 +134,48 @@ void registerGenerateApplyPatternsPass();
129134
} // namespace enzyme
130135
} // namespace mlir
131136

137+
namespace reactant {
138+
139+
template <typename T> struct unwrap_type {
140+
typedef T type;
141+
};
142+
template <typename T> struct unwrap_type<std::shared_ptr<T>> {
143+
typedef T type;
144+
};
145+
template <typename T> struct unwrap_type<tsl::RCReference<T>> {
146+
typedef T type;
147+
};
148+
149+
template <typename T> using unwrap_type_t = typename unwrap_type<T>::type;
150+
151+
template <typename T> struct HeldValue {
152+
public:
153+
HeldValue(T &obj) : holded(obj) {}
154+
~HeldValue() = default;
155+
156+
unwrap_type_t<T> *ptr() const { return holded.get(); }
157+
158+
T obj() const { return holded; }
159+
160+
T value() const { return holded; }
161+
162+
unwrap_type_t<T> *operator->() const { return ptr(); }
163+
164+
private:
165+
T holded;
166+
};
167+
168+
template <typename T> HeldValue<T> *capture(T obj) {
169+
return new HeldValue<T>(obj);
170+
}
171+
172+
} // namespace reactant
173+
174+
using reactant::HeldValue;
175+
using HeldPjRtClient = HeldValue<std::shared_ptr<xla::PjRtClient>>;
176+
using HeldPjRtBuffer = HeldValue<std::shared_ptr<xla::PjRtBuffer>>;
177+
using HeldIfrtArray = HeldValue<tsl::RCReference<xla::ifrt::Array>>;
178+
132179
extern "C" void (*ReactantThrowError)(const char *) = nullptr;
133180

134181
// Utilities for `StatusOr`.
@@ -312,9 +359,23 @@ extern "C" PjRtClient *MakeCPUClient(uint8_t asynchronous, int node_id,
312359
extern "C" PjRtClient *
313360
MakeGPUClient(int node_id, int num_nodes, int *allowed_devices,
314361
int num_allowed_devices, double memory_fraction, bool preallocate,
315-
const char *platform_name, const char **error) {
362+
const char *platform_name, const char **error,
363+
void *distributed_runtime_client) {
316364
GpuClientOptions options;
317-
// options.kv_store = "etcd";
365+
366+
if (num_nodes > 1) {
367+
if (distributed_runtime_client == nullptr) {
368+
*error =
369+
"`distributed_runtime_client` must be non-null if `num_nodes` > 1";
370+
return nullptr;
371+
}
372+
auto typed_distributed_runtime_client = static_cast<
373+
HeldValue<std::shared_ptr<xla::DistributedRuntimeClient>> *>(
374+
distributed_runtime_client);
375+
options.kv_store = GetDistributedKeyValueStore(
376+
typed_distributed_runtime_client->obj(), /*key_prefix=*/"");
377+
}
378+
318379
// options.allocator_config =
319380
options.allocator_config.preallocate = preallocate;
320381
options.allocator_config.memory_fraction = memory_fraction;
@@ -429,6 +490,21 @@ extern "C" const char *DeviceGetKind(PjRtDevice *device) {
429490
return cstr_from_string(device->device_kind());
430491
}
431492

493+
// Copies the pointers returned by `client->devices()` into `out_devices`.
//
// client:      PjRt client to query; it is dereferenced unconditionally.
// out_devices: destination array. Entries [0, client->devices().size()) are
//              written, so the caller must provide at least that many slots;
//              the count itself is not reported by this function.
extern "C" void ClientGetDevices(PjRtClient *client, PjRtDevice **out_devices) {
  auto span = client->devices();
  // Use an unsigned index that matches the span's size type instead of
  // comparing a signed `int` against `size()`.
  const size_t num_devices = span.size();
  for (size_t i = 0; i < num_devices; i++) {
    out_devices[i] = span[i];
  }
}
499+
500+
// Copies the pointers returned by `client->addressable_devices()` into
// `out_devices`.
//
// client:      PjRt client to query; it is dereferenced unconditionally.
// out_devices: destination array. Entries
//              [0, client->addressable_devices().size()) are written, so the
//              caller must provide at least that many slots; the count itself
//              is not reported by this function.
extern "C" void ClientGetAddressableDevices(PjRtClient *client,
                                            PjRtDevice **out_devices) {
  auto span = client->addressable_devices();
  // Use an unsigned index that matches the span's size type instead of
  // comparing a signed `int` against `size()`.
  const size_t num_devices = span.size();
  for (size_t i = 0; i < num_devices; i++) {
    out_devices[i] = span[i];
  }
}
507+
432508
// To keep in sync with JLAllocatorStats in src/XLA.jl
433509
struct JLAllocatorStats {
434510
int64_t num_allocs;
@@ -1196,62 +1272,19 @@ extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC,
11961272
return wrap(entryFn);
11971273
}
11981274

1199-
namespace reactant {
1200-
1201-
template <typename T> struct unwrap_type {
1202-
typedef T type;
1203-
};
1204-
template <typename T> struct unwrap_type<std::shared_ptr<T>> {
1205-
typedef T type;
1206-
};
1207-
template <typename T> struct unwrap_type<tsl::RCReference<T>> {
1208-
typedef T type;
1209-
};
1210-
1211-
template <typename T> using unwrap_type_t = typename unwrap_type<T>::type;
1212-
1213-
template <typename T> struct HeldValue {
1214-
public:
1215-
HeldValue(T &obj) : holded(obj) {}
1216-
~HeldValue() = default;
1217-
1218-
unwrap_type_t<T> *ptr() const { return holded.get(); }
1219-
1220-
T obj() const { return holded; }
1221-
1222-
T value() const { return holded; }
1223-
1224-
unwrap_type_t<T> *operator->() const { return ptr(); }
1225-
1226-
private:
1227-
T holded;
1228-
};
1229-
1230-
template <typename T> HeldValue<T> *capture(T obj) {
1231-
return new HeldValue<T>(obj);
1232-
}
1233-
1234-
} // namespace reactant
1235-
1236-
using reactant::HeldValue;
1237-
using HeldPjRtClient = HeldValue<std::shared_ptr<xla::PjRtClient>>;
1238-
using HeldPjRtBuffer = HeldValue<std::shared_ptr<xla::PjRtBuffer>>;
1239-
using HeldIfrtArray = HeldValue<tsl::RCReference<xla::ifrt::Array>>;
1240-
12411275
// Creates a CPU PjRt client through MakeCPUClient and returns it wrapped in a
// heap-allocated HeldPjRtClient. The raw client pointer is placed under
// shared_ptr ownership before being captured.
extern "C" HeldPjRtClient *
pjrt_make_cpu_client_shared(uint8_t asynchronous, int node_id, int num_nodes) {
  std::shared_ptr<PjRtClient> shared_client(
      MakeCPUClient(asynchronous, node_id, num_nodes));
  return reactant::capture(shared_client);
}
12461280

1247-
extern "C" HeldPjRtClient *
1248-
pjrt_make_gpu_client_shared(int node_id, int num_nodes, int *allowed_devices,
1249-
int num_allowed_devices, double memory_fraction,
1250-
bool preallocate, const char *platform_name,
1251-
const char **error) {
1252-
PjRtClient *client =
1253-
MakeGPUClient(node_id, num_nodes, allowed_devices, num_allowed_devices,
1254-
memory_fraction, preallocate, platform_name, error);
1281+
// Creates a GPU PjRt client through MakeGPUClient and returns it wrapped in a
// heap-allocated HeldPjRtClient. All arguments, including `error` and
// `distributed_runtime_client`, are forwarded to MakeGPUClient untouched.
//
// Note: whatever MakeGPUClient returns is wrapped as-is. When it returns
// nullptr (it does so, after setting `*error`, if `num_nodes` > 1 and
// `distributed_runtime_client` is null), the result is still a non-null
// HeldPjRtClient whose stored shared_ptr is empty.
extern "C" HeldPjRtClient *pjrt_make_gpu_client_shared(
    int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
    double memory_fraction, bool preallocate, const char *platform_name,
    const char **error, void *distributed_runtime_client) {
  std::shared_ptr<PjRtClient> shared_client(MakeGPUClient(
      node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction,
      preallocate, platform_name, error, distributed_runtime_client));
  return reactant::capture(shared_client);
}
12571290

@@ -1617,10 +1650,10 @@ extern "C" ifrt::Client *
16171650
ifrt_make_gpu_client(int node_id, int num_nodes, int *allowed_devices,
16181651
int num_allowed_devices, double memory_fraction,
16191652
bool preallocate, const char *platform_name,
1620-
const char **error) {
1653+
const char **error, void *distributed_runtime_client) {
16211654
return ifrt_pjrt_make_client(pjrt_make_gpu_client_shared(
16221655
node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction,
1623-
preallocate, platform_name, error));
1656+
preallocate, platform_name, error, distributed_runtime_client));
16241657
}
16251658

16261659
extern "C" ifrt::Client *ifrt_make_tpu_client(const char *tpu_path,
@@ -1815,3 +1848,75 @@ ifrt_hlo_sharding_to_string(ifrt::HloSharding *hlo_sharding) {
18151848
}
18161849

18171850
#pragma endregion
1851+
1852+
#pragma region PjRtDistributed
1853+
1854+
// Builds an xla::DistributedRuntimeClient for `c_address` and returns it
// wrapped in a heap-allocated HeldValue (release it with
// free_distributed_runtime_client).
//
// The integer arguments are converted to durations using the unit in their
// name: `rpc_timeout_in_seconds` and `heartbeat_interval_in_seconds` via
// absl::Seconds, `shutdown_timeout_in_minutes` via absl::Minutes.
// `use_compression` is forwarded to xla::GetDistributedRuntimeClient.
// `c_address` is converted to std::string without a null check.
extern "C" HeldValue<std::shared_ptr<xla::DistributedRuntimeClient>> *
GetDistributedRuntimeClient(char *c_address, int32_t node_id,
                            int32_t rpc_timeout_in_seconds,
                            // int32_t init_timeout,
                            int32_t shutdown_timeout_in_minutes,
                            int32_t heartbeat_interval_in_seconds,
                            int max_missing_heartbeats, bool use_compression) {
  const std::string address_str(c_address);

  xla::DistributedRuntimeClient::Options client_options;
  client_options.node_id = node_id;
  client_options.max_missing_heartbeats = max_missing_heartbeats;
  client_options.rpc_timeout = absl::Seconds(rpc_timeout_in_seconds);
  // client_options.init_timeout = absl::Seconds(init_timeout);
  client_options.shutdown_timeout = absl::Minutes(shutdown_timeout_in_minutes);
  client_options.heartbeat_interval =
      absl::Seconds(heartbeat_interval_in_seconds);

  auto runtime_client = xla::GetDistributedRuntimeClient(
      address_str, client_options, use_compression);
  return reactant::capture(runtime_client);
}
1874+
1875+
// Destroys the HeldValue wrapper, which drops the shared_ptr reference it
// stores. Deleting a null pointer is a no-op.
extern "C" void free_distributed_runtime_client(
    HeldValue<std::shared_ptr<xla::DistributedRuntimeClient>> *client) {
  delete client;
}
1879+
1880+
// Calls Connect() on the held distributed runtime client. If the returned
// status is not OK, its string form is passed to ReactantThrowError.
extern "C" void distributed_runtime_client_connect(
    HeldValue<std::shared_ptr<xla::DistributedRuntimeClient>> *client) {
  const auto runtime_client = client->obj();
  const auto connect_status = runtime_client->Connect();
  if (connect_status.ok())
    return;
  ReactantThrowError(connect_status.ToString().c_str());
}
1886+
1887+
// Calls Shutdown() on the held distributed runtime client. If the returned
// status is not OK, its string form is passed to ReactantThrowError.
extern "C" void distributed_runtime_client_shutdown(
    HeldValue<std::shared_ptr<xla::DistributedRuntimeClient>> *client) {
  const auto runtime_client = client->obj();
  const auto shutdown_status = runtime_client->Shutdown();
  if (shutdown_status.ok())
    return;
  ReactantThrowError(shutdown_status.ToString().c_str());
}
1893+
1894+
extern "C" xla::DistributedRuntimeService *GetDistributedRuntimeService(
1895+
char *c_address, int num_nodes, int32_t heartbeat_interval_in_seconds,
1896+
int max_missing_heartbeats, int32_t cluster_register_timeout_in_minutes,
1897+
int32_t shutdown_timeout_in_minutes) {
1898+
xla::CoordinationServiceImpl::Options options;
1899+
options.num_nodes = num_nodes;
1900+
options.heartbeat_interval = absl::Seconds(heartbeat_interval_in_seconds);
1901+
options.max_missing_heartbeats = max_missing_heartbeats;
1902+
options.cluster_register_timeout =
1903+
absl::Minutes(cluster_register_timeout_in_minutes);
1904+
options.shutdown_timeout = absl::Minutes(shutdown_timeout_in_minutes);
1905+
1906+
std::string address = c_address;
1907+
1908+
return MyValueOrThrow(xla::GetDistributedRuntimeService(address, options))
1909+
.release();
1910+
}
1911+
1912+
// Destroys the HeldValue wrapper, which drops the shared_ptr reference it
// stores. `service` must point to a
// HeldValue<std::shared_ptr<xla::DistributedRuntimeService>>; deleting a null
// pointer is a no-op.
extern "C" void free_distributed_runtime_service(
    HeldValue<std::shared_ptr<xla::DistributedRuntimeService>> *service) {
  delete service;
}
1916+
1917+
// Calls Shutdown() on the held distributed runtime service. `service` must
// point to a HeldValue<std::shared_ptr<xla::DistributedRuntimeService>>; the
// wrapper itself is left allocated.
extern "C" void distributed_runtime_service_shutdown(
    HeldValue<std::shared_ptr<xla::DistributedRuntimeService>> *service) {
  const auto held_service = service->obj();
  held_service->Shutdown();
}
1921+
1922+
#pragma endregion

deps/ReactantExtra/BUILD

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,15 @@ cc_library(
481481
"-Wl,-exported_symbol,_hlo_sharding_to_op_sharding",
482482
"-Wl,-exported_symbol,_hlo_sharding_to_string",
483483
"-Wl,-exported_symbol,_DeviceGetKind",
484+
"-Wl,-exported_symbol,_GetDistributedRuntimeClient",
485+
"-Wl,-exported_symbol,_free_distributed_runtime_client",
486+
"-Wl,-exported_symbol,_distributed_runtime_client_connect",
487+
"-Wl,-exported_symbol,_distributed_runtime_client_shutdown",
488+
"-Wl,-exported_symbol,_GetDistributedRuntimeService",
489+
"-Wl,-exported_symbol,_free_distributed_runtime_service",
490+
"-Wl,-exported_symbol,_distributed_runtime_service_shutdown",
491+
"-Wl,-exported_symbol,_ClientGetDevices",
492+
"-Wl,-exported_symbol,_ClientGetAddressableDevices",
484493
]}),
485494
deps = [
486495
"@enzyme//:EnzymeMLIR",
@@ -526,6 +535,9 @@ cc_library(
526535
"@xla//xla/pjrt:pjrt_api",
527536
"@xla//xla/pjrt:pjrt_c_api_client",
528537
"@xla//xla/pjrt/cpu:cpu_client",
538+
"@xla//xla/pjrt/distributed:distributed",
539+
"@xla//xla/pjrt/distributed:client",
540+
"@xla//xla/pjrt/distributed:service",
529541

530542
"@xla//xla:xla_proto_cc",
531543
"@xla//xla:xla_proto_cc_impl",

0 commit comments

Comments
 (0)