diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp
index f270e8ef23..4aa9d436db 100644
--- a/deps/ReactantExtra/API.cpp
+++ b/deps/ReactantExtra/API.cpp
@@ -71,6 +71,7 @@
 // shardy
 #include "shardy/dialect/sdy/ir/dialect.h"
 #include "shardy/integrations/c/attributes.h"
+#include "xla/pjrt/mlir_to_hlo.h"
 
 // IFRT
 #include "xla/python/ifrt/array.h"
@@ -160,6 +161,19 @@ extern "C" MlirAttribute mlirComplexAttrDoubleGetChecked(MlirLocation loc,
 // TODO mlirComplexAttrGetnValue
 // TODO extern "C" MlirTypeID mlirComplexAttrGetTypeID(void) { return
 // wrap(complex::NumberAttr::getTypeID()); }
+
+extern "C" void ReactantFuncSetResultAttr(MlirOperation op, intptr_t pos,
+                                          MlirStringRef name, MlirAttribute attr) {
+  llvm::cast<mlir::FunctionOpInterface>(unwrap(op))
+      .setResultAttr(pos, unwrap(name), unwrap(attr));
+}
+
+extern "C" void ReactantFuncSetArgAttr(MlirOperation op, intptr_t pos,
+                                       MlirStringRef name, MlirAttribute attr) {
+  llvm::cast<mlir::FunctionOpInterface>(unwrap(op))
+      .setArgAttr(pos, unwrap(name), unwrap(attr));
+}
+
 #pragma endregion
 
 // auxiliar functions
@@ -438,11 +452,27 @@ extern "C" PjRtClient *BufferToClient(PjRtBuffer *Buffer) {
   return Buffer->client();
 }
 
+extern "C" absl::Span<const int64_t> BufferShape(PjRtBuffer *Buffer) {
+  return Buffer->dimensions();
+}
+
+extern "C" int64_t BufferNDimensions(PjRtBuffer *Buffer) {
+  return Buffer->dimensions().length();
+}
+
+extern "C" xla::PrimitiveType BufferPrimitiveType(PjRtBuffer *Buffer) {
+  return Buffer->element_type();
+}
+
+extern "C" void PjRtBufferFree(PjRtBuffer *Buffer) { delete Buffer; }
+
 extern "C" PjRtClient *DeviceToClient(PjRtDevice *Device) {
   return Device->client();
 }
 
-extern "C" void PjRtBufferFree(PjRtBuffer *Buffer) { delete Buffer; }
+extern "C" PjRtClient *PjRtLoadedExecutableGetClient(PjRtLoadedExecutable *exec) {
+  return exec->client();
+}
 
 // https://openxla.org/xla/shapes
 // This minor-to-major dimension order of 0 up to N-1 is akin to column-major
@@ -593,33 +623,60 @@ extern "C" MlirModule ConvertLLVMStrToMLIR(const char *lmod, MlirContext cctx) {
   return wrap(res);
 }
 
-/* Note that this */
 extern "C" xla::PjRtLoadedExecutable *ClientCompile(PjRtClient *client,
                                                     MlirModule cmod,
-                                                    int *global_ordinals,
-                                                    int num_global_ordinals,
+                                                    int64_t device_id,
+                                                    bool is_sharded,
+                                                    // const int64_t *mesh_shape,
+                                                    // int64_t num_mesh_shape,
+                                                    const int64_t *mesh_ids,
+                                                    int64_t num_mesh_ids,
                                                     const char* xla_gpu_cuda_data_dir) {
   auto program =
       std::make_unique<xla::ifrt::HloProgram>(cast<ModuleOp>(*unwrap(cmod)));
 
   CompileOptions options;
+  options.executable_build_options.mutable_debug_options()->set_xla_gpu_cuda_data_dir(xla_gpu_cuda_data_dir);
 
-  // https://github.com/pytorch/xla/blob/8b2414094578e829b99a8383877c86d357eeb682/torch_xla/csrc/runtime/pjrt_computation_client.cc#L601
-  int device_count = client->addressable_device_count();
+  auto cmodop = cast<ModuleOp>(*unwrap(cmod));
 
-  options.executable_build_options.set_num_replicas(device_count);
-  options.executable_build_options.set_num_partitions(1);
-  options.executable_build_options.mutable_debug_options()->set_xla_gpu_cuda_data_dir(xla_gpu_cuda_data_dir);
+  if (is_sharded) {
+    assert(device_id < 0);
 
-  xla::DeviceAssignment device_assignment(device_count, 1);
-  for (int64_t device_id = 0; device_id < num_global_ordinals; ++device_id) {
-    int ordinal = global_ordinals[device_id];
-    if (ordinal < 0) {
-      continue;
+    options.executable_build_options.set_num_replicas(1);
+    options.executable_build_options.set_num_partitions(num_mesh_ids);
+
+    options.executable_build_options.set_use_spmd_partitioning(true);
+    options.executable_build_options.set_use_shardy_partitioner(true);
+
+    // auto partitioning for GPUs is not available in open source version of XLA
+    // options.executable_build_options.set_use_auto_spmd_partitioning(true);
+    // std::vector<int64_t> mesh_shape_vec(mesh_shape, mesh_shape + num_mesh_shape);
+    // options.executable_build_options.set_auto_spmd_partitioning_mesh_shape(mesh_shape_vec);
+    // std::vector<int64_t> mesh_ids_vec(mesh_ids, mesh_ids + num_mesh_ids);
+    // options.executable_build_options.set_auto_spmd_partitioning_mesh_ids(mesh_ids_vec);
+
+    xla::DeviceAssignment device_assignment(1, num_mesh_ids);
+    for (int64_t i = 0; i < num_mesh_ids; ++i) {
+      int64_t mesh_id = mesh_ids[i];
+      assert(mesh_id >= 0);
+      device_assignment(0, mesh_id) = i;
     }
-    device_assignment(ordinal, 0) = device_id;
+    options.executable_build_options.set_device_assignment(device_assignment);
+
+    // https://github.com/openxla/xla/blob/b3c641b05692f3712fb3c272e38665fdfa28bdf8/xla/python/py_client.cc#L460
+    xla::ExportShardyForHloRoundTrip(cmodop);
+  } else {
+    assert(device_id >= 0);
+
+    options.executable_build_options.set_num_replicas(1);
+    options.executable_build_options.set_num_partitions(1);
+    options.executable_build_options.set_device_ordinal(device_id);
+
+    xla::DeviceAssignment device_assignment(1, 1);
+    device_assignment(0, 0) = device_id;
+    options.executable_build_options.set_device_assignment(device_assignment);
   }
-  options.executable_build_options.set_device_assignment(device_assignment);
 
   auto addressable_devices = client->addressable_devices();
   if (!addressable_devices.empty()) {
@@ -633,8 +690,7 @@ extern "C" xla::PjRtLoadedExecutable *ClientCompile(PjRtClient *client,
       options.executable_build_options.set_device_memory_size(*stats->bytes_limit);
     }
   }
-  auto exec =
-      MyValueOrThrow(client->Compile(cast<ModuleOp>(*unwrap(cmod)), options));
+  auto exec = MyValueOrThrow(client->Compile(cmodop, options));
   return exec.release();
 }
 
@@ -694,23 +750,33 @@ extern "C" void XLAExecuteSharded(xla::PjRtLoadedExecutable *exec, int num_args,
   }
 }
 
-extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int num_args,
-                           PjRtBuffer **op_args, uint8_t *is_arg_donatable,
+extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int op_args_len,
+                           PjRtBuffer **op_args,
+                           const int64_t *mesh_ids, int64_t num_mesh_ids,
+                           uint8_t *is_arg_donatable,
                            int num_results, PjRtBuffer **op_results,
                            uint8_t *futures, FutureType **future_results) {
   auto client = exec->client();
-  int num_devices = client->addressable_device_count();
 
-  // Ensure argument_handles is structured as num_devices x num_args
-  std::vector<std::vector<PjRtBuffer *>> argument_handles(num_devices);
+  // Ensure argument_handles is structured as num_mesh_ids x num_args
+  std::vector<std::vector<PjRtBuffer *>> argument_handles(num_mesh_ids);
+  int num_args = op_args_len / num_mesh_ids;
 
   // Distribute arguments across devices
-  for (int device_idx = 0; device_idx < num_devices; ++device_idx) {
-    argument_handles[device_idx].reserve(num_args);
+  for (int device_idx = 0; device_idx < num_mesh_ids; ++device_idx) {
+    int64_t mesh_id = mesh_ids[device_idx];
+
+    // Validate mesh_id
+    if (mesh_id < 0 || mesh_id >= num_mesh_ids) {
+      ReactantThrowError(("Invalid mesh_id " + std::to_string(mesh_id) +
+                          " at device_idx " + std::to_string(device_idx))
+                             .c_str());
+    }
+
+    argument_handles[mesh_id].reserve(num_args);
     for (int arg_idx = 0; arg_idx < num_args; ++arg_idx) {
       // Assuming op_args is a flat array of size num_devices * num_args
       // where arguments for each device are contiguous
-      argument_handles[device_idx].push_back(op_args[device_idx * num_args + arg_idx]);
+      argument_handles[mesh_id].push_back(op_args[mesh_id * num_args + arg_idx]);
     }
   }
 
@@ -722,41 +788,40 @@ extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int num_args,
   }
 
   options.untuple_result = true;
-  std::optional<std::vector<FutureType>> returned_futures;
+  std::optional<std::vector<FutureType>> returned_futures = std::vector<FutureType>();
   auto results = MyValueOrThrow(
       exec->Execute(static_cast<absl::Span<const std::vector<PjRtBuffer *>>>(
                         argument_handles),
                     options, returned_futures));
 
-  assert(results.size() == num_devices);
+  assert(results.size() == num_mesh_ids);
 
-  for (int device_idx = 0; device_idx < num_devices; ++device_idx) {
-    if (results[device_idx].size() != num_results) {
-      llvm::errs() << " results[" << device_idx << "].size()=" << results[device_idx].size()
+  for (int device_idx = 0; device_idx < num_mesh_ids; ++device_idx) {
+    int64_t mesh_id = mesh_ids[device_idx];
+    if (results[mesh_id].size() != num_results) {
+      llvm::errs() << " results[" << mesh_id << "].size()=" << results[mesh_id].size()
                    << " num_results=" << num_results << "\n";
     }
-    assert(results[device_idx].size() == num_results);
+    assert(results[mesh_id].size() == num_results);
  }
 
   // Handle returned futures
-  if (returned_futures) {
+  if (returned_futures.has_value()) {
     *futures = true;
-    assert(returned_futures->size() == num_devices * num_results);
-    for (int device_idx = 0; device_idx < num_devices; ++device_idx) {
-      for (int result_idx = 0; result_idx < num_results; ++result_idx) {
-        int flat_index = device_idx * num_results + result_idx;
-        future_results[flat_index] = new FutureType((*returned_futures)[flat_index]);
-      }
-    }
+    assert(returned_futures->size() == num_mesh_ids);
   } else {
     *futures = false;
   }
 
   // Copy results into the output buffers
-  for (int device_idx = 0; device_idx < num_devices; ++device_idx) {
+  for (int device_idx = 0; device_idx < num_mesh_ids; ++device_idx) {
+    int64_t mesh_id = mesh_ids[device_idx];
     for (int result_idx = 0; result_idx < num_results; ++result_idx) {
-      int flat_index = device_idx * num_results + result_idx;
-      op_results[flat_index] = results[device_idx][result_idx].release();
+      int flat_index = mesh_id * num_results + result_idx;
+      op_results[flat_index] = results[mesh_id][result_idx].release();
+      if (returned_futures.has_value()) {
+        future_results[flat_index] = new FutureType((*returned_futures)[mesh_id]);
+      }
     }
   }
 }
@@ -784,10 +849,16 @@ extern "C" void RegisterDialects(MlirContext cctx) {
 #include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h"
+#include "xla/service/spmd/shardy/sdy_round_trip/pipelines.h"
+
 extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
   mlir::DialectRegistry &registry = *unwrap(creg);
   prepareRegistry(registry);
 
+  mlir::registerLLVMDialectImport(registry);
+  mlir::registerNVVMDialectImport(registry);
+  mlir::LLVM::registerInlinerInterface(registry);
+
   mlir::registerenzymePasses();
   enzyme::registerenzymexlaPasses();
@@ -803,10 +874,6 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
   mlir::affine::registerAffinePasses();
   mlir::registerReconcileUnrealizedCasts();
 
-  mlir::registerLLVMDialectImport(registry);
-  mlir::registerNVVMDialectImport(registry);
-  mlir::LLVM::registerInlinerInterface(registry);
-
   /*
   registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) {
     LLVM::LLVMFunctionType::attachInterface(*ctx);
@@ -827,6 +894,10 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
   mlir::transform::registerInterpreterPass();
   mlir::enzyme::registerGenerateApplyPatternsPass();
   mlir::enzyme::registerRemoveTransformPass();
+
+  // xla + shardy specific passes
+  xla::sdy::registerSdyRoundTripExportPipeline();
+  xla::sdy::registerSdyRoundTripImportPipeline();
 }
 
 /// Returns an unused symbol in `module` for `oldSymbolName` by trying numeric
@@ -881,12 +952,6 @@ static mlir::LogicalResult updateSymbolAndAllUses(mlir::SymbolOpInterface op,
   return success();
 }
 
-extern "C" void ReactantFuncSetArgAttr(MlirOperation op, intptr_t pos,
-                                       MlirStringRef name, MlirAttribute attr) {
-  llvm::cast<mlir::FunctionOpInterface>(unwrap(op))
-      .setArgAttr(pos, unwrap(name), unwrap(attr));
-}
-
 extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC,
                                       const char *entryfn) {
   auto prevMod = cast<ModuleOp>(*unwrap(prevModC));
diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD
index 04818b586a..2cc70c43ce 100644
--- a/deps/ReactantExtra/BUILD
+++ b/deps/ReactantExtra/BUILD
@@ -459,6 +459,11 @@ cc_library(
 "-Wl,-exported_symbol,_XLAExecuteSharded",
 "-Wl,-exported_symbol,_ClientGetPlatformName",
 "-Wl,-exported_symbol,_RegisterEnzymeXLACPUHandler",
+"-Wl,-exported_symbol,_PjRtLoadedExecutableGetClient",
+"-Wl,-exported_symbol,_ReactantFuncSetResultAttr",
+"-Wl,-exported_symbol,_BufferShape",
+"-Wl,-exported_symbol,_BufferNDimensions",
+"-Wl,-exported_symbol,_BufferPrimitiveType",
 ]}),
     deps = [
         "@enzyme//:EnzymeMLIR",
diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE
index 17331fbce3..1cf8489877 100644
--- a/deps/ReactantExtra/WORKSPACE
+++ b/deps/ReactantExtra/WORKSPACE
@@ -9,7 +9,7 @@ http_archive(
     urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
 )
 
-ENZYMEXLA_COMMIT = "b8b5037d0d3c108eb374218961631740daa10e05"
+ENZYMEXLA_COMMIT = "8d3ed1d53a499841d21b0e90f3201674acfee18a"
 ENZYMEXLA_SHA256 = ""
 
 http_archive(
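
---

For reviewers, a minimal caller-side sketch of how the revised entry points fit together on the sharded path. This is not part of the diff: the RunSharded wrapper, its arguments, and the empty xla_gpu_cuda_data_dir are hypothetical, and FutureType is assumed here to be the xla::PjRtFuture<> alias that API.cpp already uses; the extern declarations below simply mirror the signatures introduced above.

// Hypothetical usage sketch -- not part of this change. Assumes the PJRT and
// MLIR C API headers are available and that `client`, `cmod`, and the input
// buffers already exist.
#include <cstdint>
#include <vector>

#include "mlir-c/IR.h"             // MlirModule
#include "xla/pjrt/pjrt_client.h"  // xla::PjRtClient / PjRtBuffer / PjRtLoadedExecutable

using FutureType = xla::PjRtFuture<>;  // assumed to match API.cpp's alias

// Mirrors the exported signatures from API.cpp above.
extern "C" xla::PjRtLoadedExecutable *
ClientCompile(xla::PjRtClient *client, MlirModule cmod, int64_t device_id,
              bool is_sharded, const int64_t *mesh_ids, int64_t num_mesh_ids,
              const char *xla_gpu_cuda_data_dir);
extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int op_args_len,
                           xla::PjRtBuffer **op_args, const int64_t *mesh_ids,
                           int64_t num_mesh_ids, uint8_t *is_arg_donatable,
                           int num_results, xla::PjRtBuffer **op_results,
                           uint8_t *futures, FutureType **future_results);

// Compile `cmod` sharded over a two-device mesh and run it once. `args` must
// be flat and device-major: all arguments for mesh id 0, then all arguments
// for mesh id 1 (XLAExecute recovers num_args as op_args_len / num_mesh_ids).
void RunSharded(xla::PjRtClient *client, MlirModule cmod,
                std::vector<xla::PjRtBuffer *> &args, int num_args,
                int num_results) {
  const int64_t mesh_ids[] = {0, 1};
  const int64_t num_mesh_ids = 2;

  // The sharded path requires a negative device_id (ClientCompile asserts
  // device_id < 0 when is_sharded is true); mesh_ids drive the 1 x
  // num_mesh_ids DeviceAssignment instead.
  xla::PjRtLoadedExecutable *exec =
      ClientCompile(client, cmod, /*device_id=*/-1, /*is_sharded=*/true,
                    mesh_ids, num_mesh_ids, /*xla_gpu_cuda_data_dir=*/"");

  std::vector<uint8_t> donatable(num_args, 0);  // donate nothing
  std::vector<xla::PjRtBuffer *> results(num_mesh_ids * num_results);
  std::vector<FutureType *> result_futures(num_mesh_ids * num_results);
  uint8_t has_futures = 0;

  // Outputs come back flat, indexed as mesh_id * num_results + result_idx,
  // with one future per (device, result) slot when has_futures is set.
  XLAExecute(exec, static_cast<int>(args.size()), args.data(), mesh_ids,
             num_mesh_ids, donatable.data(), num_results, results.data(),
             &has_futures, result_futures.data());
}

The single-device path is the mirror image: pass a non-negative device_id with is_sharded = false (mesh_ids is then unused) and ClientCompile falls back to a 1x1 DeviceAssignment, so existing callers only need to switch from num_args to the flattened op_args_len plus the mesh arguments.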