|
47 | 47 |
|
48 | 48 | #include "xla/mlir/utils/type_util.h" |
49 | 49 | #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" |
50 | | -#include "xla/pjrt/cpu/cpu_client.h" |
51 | | -#include "xla/pjrt/gpu/se_gpu_pjrt_client.h" |
52 | | -#include "xla/pjrt/pjrt_api.h" |
53 | | -#include "xla/pjrt/pjrt_c_api_client.h" |
54 | | -#include "xla/pjrt/pjrt_executable.h" |
55 | | -#include "xla/pjrt/status_casters.h" |
56 | 50 |
|
57 | 51 | #include "tsl/platform/init_main.h" |
58 | 52 | #include "tsl/profiler/lib/profiler_session.h" |
|
69 | 63 |
|
70 | 64 | #include "llvm-c/TargetMachine.h" |
71 | 65 |
|
| 66 | +// PJRT |
| 67 | +#include "xla/pjrt/cpu/cpu_client.h" |
| 68 | +#include "xla/pjrt/distributed/client.h" |
| 69 | +#include "xla/pjrt/distributed/distributed.h" |
| 70 | +#include "xla/pjrt/distributed/service.h" |
| 71 | +#include "xla/pjrt/gpu/se_gpu_pjrt_client.h" |
| 72 | +#include "xla/pjrt/pjrt_api.h" |
| 73 | +#include "xla/pjrt/pjrt_c_api_client.h" |
| 74 | +#include "xla/pjrt/pjrt_executable.h" |
| 75 | +#include "xla/pjrt/status_casters.h" |
| 76 | + |
72 | 77 | // shardy |
73 | 78 | #include "shardy/dialect/sdy/ir/dialect.h" |
74 | 79 | #include "shardy/integrations/c/attributes.h" |
75 | 80 | #include "xla/pjrt/mlir_to_hlo.h" |
76 | 81 |
|
77 | 82 | // IFRT |
78 | 83 | #include "xla/python/ifrt/array.h" |
| 84 | +#include "xla/python/ifrt/basic_device_list.h" |
79 | 85 | #include "xla/python/ifrt/client.h" |
80 | 86 | #include "xla/python/ifrt/compiler.h" |
81 | 87 | #include "xla/python/ifrt/device.h" |
82 | 88 | #include "xla/python/ifrt/device_list.h" |
83 | | -#include "xla/python/ifrt/basic_device_list.h" |
84 | 89 | #include "xla/python/ifrt/dtype.h" |
85 | 90 | #include "xla/python/ifrt/executable.h" |
86 | 91 | #include "xla/python/ifrt/hlo/hlo_program.h" |
@@ -129,6 +134,48 @@ void registerGenerateApplyPatternsPass(); |
129 | 134 | } // namespace enzyme |
130 | 135 | } // namespace mlir |
131 | 136 |
|
| 137 | +namespace reactant { |
| 138 | + |
| 139 | +template <typename T> struct unwrap_type { |
| 140 | + typedef T type; |
| 141 | +}; |
| 142 | +template <typename T> struct unwrap_type<std::shared_ptr<T>> { |
| 143 | + typedef T type; |
| 144 | +}; |
| 145 | +template <typename T> struct unwrap_type<tsl::RCReference<T>> { |
| 146 | + typedef T type; |
| 147 | +}; |
| 148 | + |
| 149 | +template <typename T> using unwrap_type_t = typename unwrap_type<T>::type; |
| 150 | + |
| 151 | +template <typename T> struct HeldValue { |
| 152 | +public: |
| 153 | + HeldValue(T &obj) : holded(obj) {} |
| 154 | + ~HeldValue() = default; |
| 155 | + |
| 156 | + unwrap_type_t<T> *ptr() const { return holded.get(); } |
| 157 | + |
| 158 | + T obj() const { return holded; } |
| 159 | + |
| 160 | + T value() const { return holded; } |
| 161 | + |
| 162 | + unwrap_type_t<T> *operator->() const { return ptr(); } |
| 163 | + |
| 164 | +private: |
| 165 | + T holded; |
| 166 | +}; |
| 167 | + |
| 168 | +template <typename T> HeldValue<T> *capture(T obj) { |
| 169 | + return new HeldValue<T>(obj); |
| 170 | +} |
| 171 | + |
| 172 | +} // namespace reactant |
| 173 | + |
| 174 | +using reactant::HeldValue; |
| 175 | +using HeldPjRtClient = HeldValue<std::shared_ptr<xla::PjRtClient>>; |
| 176 | +using HeldPjRtBuffer = HeldValue<std::shared_ptr<xla::PjRtBuffer>>; |
| 177 | +using HeldIfrtArray = HeldValue<tsl::RCReference<xla::ifrt::Array>>; |
| 178 | + |
132 | 179 | extern "C" void (*ReactantThrowError)(const char *) = nullptr; |
133 | 180 |
|
134 | 181 | // Utilities for `StatusOr`. |
@@ -312,9 +359,23 @@ extern "C" PjRtClient *MakeCPUClient(uint8_t asynchronous, int node_id, |
312 | 359 | extern "C" PjRtClient * |
313 | 360 | MakeGPUClient(int node_id, int num_nodes, int *allowed_devices, |
314 | 361 | int num_allowed_devices, double memory_fraction, bool preallocate, |
315 | | - const char *platform_name, const char **error) { |
| 362 | + const char *platform_name, const char **error, |
| 363 | + void *distributed_runtime_client) { |
316 | 364 | GpuClientOptions options; |
317 | | - // options.kv_store = "etcd"; |
| 365 | + |
| 366 | + if (num_nodes > 1) { |
| 367 | + if (distributed_runtime_client == nullptr) { |
| 368 | + *error = |
| 369 | + "`distributed_runtime_client` must be non-null if `num_nodes` > 1"; |
| 370 | + return nullptr; |
| 371 | + } |
| 372 | + auto typed_distributed_runtime_client = static_cast< |
| 373 | + HeldValue<std::shared_ptr<xla::DistributedRuntimeClient>> *>( |
| 374 | + distributed_runtime_client); |
| 375 | + options.kv_store = GetDistributedKeyValueStore( |
| 376 | + typed_distributed_runtime_client->obj(), /*key_prefix=*/""); |
| 377 | + } |
| 378 | + |
318 | 379 | // options.allocator_config = |
319 | 380 | options.allocator_config.preallocate = preallocate; |
320 | 381 | options.allocator_config.memory_fraction = memory_fraction; |
@@ -429,6 +490,21 @@ extern "C" const char *DeviceGetKind(PjRtDevice *device) { |
429 | 490 | return cstr_from_string(device->device_kind()); |
430 | 491 | } |
431 | 492 |
|
| 493 | +extern "C" void ClientGetDevices(PjRtClient *client, PjRtDevice **out_devices) { |
| 494 | + auto span = client->devices(); |
| 495 | + for (int i = 0; i < span.size(); i++) { |
| 496 | + out_devices[i] = span[i]; |
| 497 | + } |
| 498 | +} |
| 499 | + |
| 500 | +extern "C" void ClientGetAddressableDevices(PjRtClient *client, |
| 501 | + PjRtDevice **out_devices) { |
| 502 | + auto span = client->addressable_devices(); |
| 503 | + for (int i = 0; i < span.size(); i++) { |
| 504 | + out_devices[i] = span[i]; |
| 505 | + } |
| 506 | +} |
| 507 | + |
432 | 508 | // To keep in sync with JLAllocatorStats in src/XLA.jl |
433 | 509 | struct JLAllocatorStats { |
434 | 510 | int64_t num_allocs; |
@@ -1196,62 +1272,19 @@ extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC, |
1196 | 1272 | return wrap(entryFn); |
1197 | 1273 | } |
1198 | 1274 |
|
1199 | | -namespace reactant { |
1200 | | - |
1201 | | -template <typename T> struct unwrap_type { |
1202 | | - typedef T type; |
1203 | | -}; |
1204 | | -template <typename T> struct unwrap_type<std::shared_ptr<T>> { |
1205 | | - typedef T type; |
1206 | | -}; |
1207 | | -template <typename T> struct unwrap_type<tsl::RCReference<T>> { |
1208 | | - typedef T type; |
1209 | | -}; |
1210 | | - |
1211 | | -template <typename T> using unwrap_type_t = typename unwrap_type<T>::type; |
1212 | | - |
1213 | | -template <typename T> struct HeldValue { |
1214 | | -public: |
1215 | | - HeldValue(T &obj) : holded(obj) {} |
1216 | | - ~HeldValue() = default; |
1217 | | - |
1218 | | - unwrap_type_t<T> *ptr() const { return holded.get(); } |
1219 | | - |
1220 | | - T obj() const { return holded; } |
1221 | | - |
1222 | | - T value() const { return holded; } |
1223 | | - |
1224 | | - unwrap_type_t<T> *operator->() const { return ptr(); } |
1225 | | - |
1226 | | -private: |
1227 | | - T holded; |
1228 | | -}; |
1229 | | - |
1230 | | -template <typename T> HeldValue<T> *capture(T obj) { |
1231 | | - return new HeldValue<T>(obj); |
1232 | | -} |
1233 | | - |
1234 | | -} // namespace reactant |
1235 | | - |
1236 | | -using reactant::HeldValue; |
1237 | | -using HeldPjRtClient = HeldValue<std::shared_ptr<xla::PjRtClient>>; |
1238 | | -using HeldPjRtBuffer = HeldValue<std::shared_ptr<xla::PjRtBuffer>>; |
1239 | | -using HeldIfrtArray = HeldValue<tsl::RCReference<xla::ifrt::Array>>; |
1240 | | - |
1241 | 1275 | extern "C" HeldPjRtClient * |
1242 | 1276 | pjrt_make_cpu_client_shared(uint8_t asynchronous, int node_id, int num_nodes) { |
1243 | 1277 | PjRtClient *client = MakeCPUClient(asynchronous, node_id, num_nodes); |
1244 | 1278 | return reactant::capture(std::shared_ptr<PjRtClient>(client)); |
1245 | 1279 | } |
1246 | 1280 |
|
1247 | | -extern "C" HeldPjRtClient * |
1248 | | -pjrt_make_gpu_client_shared(int node_id, int num_nodes, int *allowed_devices, |
1249 | | - int num_allowed_devices, double memory_fraction, |
1250 | | - bool preallocate, const char *platform_name, |
1251 | | - const char **error) { |
1252 | | - PjRtClient *client = |
1253 | | - MakeGPUClient(node_id, num_nodes, allowed_devices, num_allowed_devices, |
1254 | | - memory_fraction, preallocate, platform_name, error); |
| 1281 | +extern "C" HeldPjRtClient *pjrt_make_gpu_client_shared( |
| 1282 | + int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices, |
| 1283 | + double memory_fraction, bool preallocate, const char *platform_name, |
| 1284 | + const char **error, void *distributed_runtime_client) { |
| 1285 | + PjRtClient *client = MakeGPUClient( |
| 1286 | + node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction, |
| 1287 | + preallocate, platform_name, error, distributed_runtime_client); |
1255 | 1288 | return reactant::capture(std::shared_ptr<PjRtClient>(client)); |
1256 | 1289 | } |
1257 | 1290 |
|
@@ -1617,10 +1650,10 @@ extern "C" ifrt::Client * |
1617 | 1650 | ifrt_make_gpu_client(int node_id, int num_nodes, int *allowed_devices, |
1618 | 1651 | int num_allowed_devices, double memory_fraction, |
1619 | 1652 | bool preallocate, const char *platform_name, |
1620 | | - const char **error) { |
| 1653 | + const char **error, void *distributed_runtime_client) { |
1621 | 1654 | return ifrt_pjrt_make_client(pjrt_make_gpu_client_shared( |
1622 | 1655 | node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction, |
1623 | | - preallocate, platform_name, error)); |
| 1656 | + preallocate, platform_name, error, distributed_runtime_client)); |
1624 | 1657 | } |
1625 | 1658 |
|
1626 | 1659 | extern "C" ifrt::Client *ifrt_make_tpu_client(const char *tpu_path, |
@@ -1815,3 +1848,75 @@ ifrt_hlo_sharding_to_string(ifrt::HloSharding *hlo_sharding) { |
1815 | 1848 | } |
1816 | 1849 |
|
1817 | 1850 | #pragma endregion |
| 1851 | + |
| 1852 | +#pragma region PjRtDistributed |
| 1853 | + |
| 1854 | +extern "C" HeldValue<std::shared_ptr<xla::DistributedRuntimeClient>> * |
| 1855 | +GetDistributedRuntimeClient(char *c_address, int32_t node_id, |
| 1856 | + int32_t rpc_timeout_in_seconds, |
| 1857 | + // int32_t init_timeout, |
| 1858 | + int32_t shutdown_timeout_in_minutes, |
| 1859 | + int32_t heartbeat_interval_in_seconds, |
| 1860 | + int max_missing_heartbeats, bool use_compression) { |
| 1861 | + xla::DistributedRuntimeClient::Options options; |
| 1862 | + options.node_id = node_id; |
| 1863 | + options.rpc_timeout = absl::Seconds(rpc_timeout_in_seconds); |
| 1864 | + // options.init_timeout = absl::Seconds(init_timeout); |
| 1865 | + options.shutdown_timeout = absl::Minutes(shutdown_timeout_in_minutes); |
| 1866 | + options.heartbeat_interval = absl::Seconds(heartbeat_interval_in_seconds); |
| 1867 | + options.max_missing_heartbeats = max_missing_heartbeats; |
| 1868 | + |
| 1869 | + std::string address = c_address; |
| 1870 | + |
| 1871 | + return reactant::capture( |
| 1872 | + xla::GetDistributedRuntimeClient(address, options, use_compression)); |
| 1873 | +} |
| 1874 | + |
| 1875 | +extern "C" void free_distributed_runtime_client( |
| 1876 | + HeldValue<std::shared_ptr<xla::DistributedRuntimeClient>> *client) { |
| 1877 | + delete client; |
| 1878 | +} |
| 1879 | + |
| 1880 | +extern "C" void distributed_runtime_client_connect( |
| 1881 | + HeldValue<std::shared_ptr<xla::DistributedRuntimeClient>> *client) { |
| 1882 | + auto status = client->obj()->Connect(); |
| 1883 | + if (!status.ok()) |
| 1884 | + ReactantThrowError(status.ToString().c_str()); |
| 1885 | +} |
| 1886 | + |
| 1887 | +extern "C" void distributed_runtime_client_shutdown( |
| 1888 | + HeldValue<std::shared_ptr<xla::DistributedRuntimeClient>> *client) { |
| 1889 | + auto status = client->obj()->Shutdown(); |
| 1890 | + if (!status.ok()) |
| 1891 | + ReactantThrowError(status.ToString().c_str()); |
| 1892 | +} |
| 1893 | + |
| 1894 | +extern "C" xla::DistributedRuntimeService *GetDistributedRuntimeService( |
| 1895 | + char *c_address, int num_nodes, int32_t heartbeat_interval_in_seconds, |
| 1896 | + int max_missing_heartbeats, int32_t cluster_register_timeout_in_minutes, |
| 1897 | + int32_t shutdown_timeout_in_minutes) { |
| 1898 | + xla::CoordinationServiceImpl::Options options; |
| 1899 | + options.num_nodes = num_nodes; |
| 1900 | + options.heartbeat_interval = absl::Seconds(heartbeat_interval_in_seconds); |
| 1901 | + options.max_missing_heartbeats = max_missing_heartbeats; |
| 1902 | + options.cluster_register_timeout = |
| 1903 | + absl::Minutes(cluster_register_timeout_in_minutes); |
| 1904 | + options.shutdown_timeout = absl::Minutes(shutdown_timeout_in_minutes); |
| 1905 | + |
| 1906 | + std::string address = c_address; |
| 1907 | + |
| 1908 | + return MyValueOrThrow(xla::GetDistributedRuntimeService(address, options)) |
| 1909 | + .release(); |
| 1910 | +} |
| 1911 | + |
| 1912 | +extern "C" void free_distributed_runtime_service( |
| 1913 | + HeldValue<std::shared_ptr<xla::DistributedRuntimeService>> *service) { |
| 1914 | + delete service; |
| 1915 | +} |
| 1916 | + |
| 1917 | +extern "C" void distributed_runtime_service_shutdown( |
| 1918 | + HeldValue<std::shared_ptr<xla::DistributedRuntimeService>> *service) { |
| 1919 | + service->obj()->Shutdown(); |
| 1920 | +} |
| 1921 | + |
| 1922 | +#pragma endregion |
0 commit comments