Skip to content

Commit c325683

Browse files
authored
feat: API changes for multi-device execution [ReactantExtra JLL changes] (#692)
* feat: API changes for multi-device execution * chore: bump EnzymeJAX commit * fix: for now don't use futures * fix: revert no future * fix: rename attr setter * feat: expose some Buffer functions
1 parent d842b33 commit c325683

File tree

3 files changed

+124
-54
lines changed

3 files changed

+124
-54
lines changed

deps/ReactantExtra/API.cpp

Lines changed: 118 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
// shardy
7272
#include "shardy/dialect/sdy/ir/dialect.h"
7373
#include "shardy/integrations/c/attributes.h"
74+
#include "xla/pjrt/mlir_to_hlo.h"
7475

7576
// IFRT
7677
#include "xla/python/ifrt/array.h"
@@ -160,6 +161,19 @@ extern "C" MlirAttribute mlirComplexAttrDoubleGetChecked(MlirLocation loc,
160161
// TODO mlirComplexAttrGetnValue
161162
// TODO extern "C" MlirTypeID mlirComplexAttrGetTypeID(void) { return
162163
// wrap(complex::NumberAttr::getTypeID()); }
164+
165+
extern "C" void ReactantFuncSetResultAttr(MlirOperation op, intptr_t pos,
166+
MlirStringRef name, MlirAttribute attr) {
167+
llvm::cast<mlir::FunctionOpInterface>(unwrap(op))
168+
.setResultAttr(pos, unwrap(name), unwrap(attr));
169+
}
170+
171+
extern "C" void ReactantFuncSetArgAttr(MlirOperation op, intptr_t pos,
172+
MlirStringRef name, MlirAttribute attr) {
173+
llvm::cast<mlir::FunctionOpInterface>(unwrap(op))
174+
.setArgAttr(pos, unwrap(name), unwrap(attr));
175+
}
176+
163177
#pragma endregion
164178

165179
// auxiliar functions
@@ -438,11 +452,27 @@ extern "C" PjRtClient *BufferToClient(PjRtBuffer *Buffer) {
438452
return Buffer->client();
439453
}
440454

455+
extern "C" absl::Span<const int64_t> BufferShape(PjRtBuffer *Buffer) {
456+
return Buffer->dimensions();
457+
}
458+
459+
extern "C" int64_t BufferNDimensions(PjRtBuffer *Buffer) {
460+
return Buffer->dimensions().length();
461+
}
462+
463+
extern "C" xla::PrimitiveType BufferPrimitiveType(PjRtBuffer *Buffer) {
464+
return Buffer->element_type();
465+
}
466+
467+
extern "C" void PjRtBufferFree(PjRtBuffer *Buffer) { delete Buffer; }
468+
441469
extern "C" PjRtClient *DeviceToClient(PjRtDevice *Device) {
442470
return Device->client();
443471
}
444472

445-
extern "C" void PjRtBufferFree(PjRtBuffer *Buffer) { delete Buffer; }
473+
extern "C" PjRtClient *PjRtLoadedExecutableGetClient(PjRtLoadedExecutable *exec) {
474+
return exec->client();
475+
}
446476

447477
// https://openxla.org/xla/shapes
448478
// This minor-to-major dimension order of 0 up to N-1 is akin to column-major
@@ -593,33 +623,60 @@ extern "C" MlirModule ConvertLLVMStrToMLIR(const char *lmod, MlirContext cctx) {
593623
return wrap(res);
594624
}
595625

596-
/* Note that this */
597626
extern "C" xla::PjRtLoadedExecutable *ClientCompile(PjRtClient *client,
598627
MlirModule cmod,
599-
int *global_ordinals,
600-
int num_global_ordinals,
628+
int64_t device_id,
629+
bool is_sharded,
630+
// const int64_t *mesh_shape,
631+
// int64_t num_mesh_shape,
632+
const int64_t *mesh_ids,
633+
int64_t num_mesh_ids,
601634
const char* xla_gpu_cuda_data_dir) {
602635
auto program =
603636
std::make_unique<xla::ifrt::HloProgram>(cast<ModuleOp>(*unwrap(cmod)));
604637

605638
CompileOptions options;
639+
options.executable_build_options.mutable_debug_options()->set_xla_gpu_cuda_data_dir(xla_gpu_cuda_data_dir);
606640

607-
// https://github.com/pytorch/xla/blob/8b2414094578e829b99a8383877c86d357eeb682/torch_xla/csrc/runtime/pjrt_computation_client.cc#L601
608-
int device_count = client->addressable_device_count();
641+
auto cmodop = cast<ModuleOp>(*unwrap(cmod));
609642

610-
options.executable_build_options.set_num_replicas(device_count);
611-
options.executable_build_options.set_num_partitions(1);
612-
options.executable_build_options.mutable_debug_options()->set_xla_gpu_cuda_data_dir(xla_gpu_cuda_data_dir);
643+
if (is_sharded) {
644+
assert(device_id < 0);
613645

614-
xla::DeviceAssignment device_assignment(device_count, 1);
615-
for (int64_t device_id = 0; device_id < num_global_ordinals; ++device_id) {
616-
int ordinal = global_ordinals[device_id];
617-
if (ordinal < 0) {
618-
continue;
646+
options.executable_build_options.set_num_replicas(1);
647+
options.executable_build_options.set_num_partitions(num_mesh_ids);
648+
649+
options.executable_build_options.set_use_spmd_partitioning(true);
650+
options.executable_build_options.set_use_shardy_partitioner(true);
651+
652+
// auto partitioning for GPUs is not available in open source version of XLA
653+
// options.executable_build_options.set_use_auto_spmd_partitioning(true);
654+
// std::vector<int64_t> mesh_shape_vec(mesh_shape, mesh_shape + num_mesh_shape);
655+
// options.executable_build_options.set_auto_spmd_partitioning_mesh_shape(mesh_shape_vec);
656+
// std::vector<int64_t> mesh_ids_vec(mesh_ids, mesh_ids + num_mesh_ids);
657+
// options.executable_build_options.set_auto_spmd_partitioning_mesh_ids(mesh_ids_vec);
658+
659+
xla::DeviceAssignment device_assignment(1, num_mesh_ids);
660+
for (int64_t i = 0; i < num_mesh_ids; ++i) {
661+
int64_t mesh_id = mesh_ids[i];
662+
assert(mesh_id >= 0);
663+
device_assignment(0, mesh_id) = i;
619664
}
620-
device_assignment(ordinal, 0) = device_id;
665+
options.executable_build_options.set_device_assignment(device_assignment);
666+
667+
// https://github.com/openxla/xla/blob/b3c641b05692f3712fb3c272e38665fdfa28bdf8/xla/python/py_client.cc#L460
668+
xla::ExportShardyForHloRoundTrip(cmodop);
669+
} else {
670+
assert(device_id >= 0);
671+
672+
options.executable_build_options.set_num_replicas(1);
673+
options.executable_build_options.set_num_partitions(1);
674+
options.executable_build_options.set_device_ordinal(device_id);
675+
676+
xla::DeviceAssignment device_assignment(1, 1);
677+
device_assignment(0, 0) = device_id;
678+
options.executable_build_options.set_device_assignment(device_assignment);
621679
}
622-
options.executable_build_options.set_device_assignment(device_assignment);
623680

624681
auto addressable_devices = client->addressable_devices();
625682
if (!addressable_devices.empty()) {
@@ -633,8 +690,7 @@ extern "C" xla::PjRtLoadedExecutable *ClientCompile(PjRtClient *client,
633690
options.executable_build_options.set_device_memory_size(*stats->bytes_limit);
634691
}
635692
}
636-
auto exec =
637-
MyValueOrThrow(client->Compile(cast<ModuleOp>(*unwrap(cmod)), options));
693+
auto exec = MyValueOrThrow(client->Compile(cmodop, options));
638694
return exec.release();
639695
}
640696

@@ -694,23 +750,33 @@ extern "C" void XLAExecuteSharded(xla::PjRtLoadedExecutable *exec, int num_args,
694750
}
695751
}
696752

697-
extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int num_args,
698-
PjRtBuffer **op_args, uint8_t *is_arg_donatable,
753+
extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int op_args_len,
754+
PjRtBuffer **op_args,
755+
const int64_t *mesh_ids, int64_t num_mesh_ids,
756+
uint8_t *is_arg_donatable,
699757
int num_results, PjRtBuffer **op_results,
700758
uint8_t *futures, FutureType **future_results) {
701759
auto client = exec->client();
702-
int num_devices = client->addressable_device_count();
703760

704-
// Ensure argument_handles is structured as num_devices x num_args
705-
std::vector<std::vector<PjRtBuffer *>> argument_handles(num_devices);
761+
// Ensure argument_handles is structured as num_mesh_ids x num_args
762+
std::vector<std::vector<PjRtBuffer *>> argument_handles(num_mesh_ids);
763+
int num_args = op_args_len / num_mesh_ids;
706764

707765
// Distribute arguments across devices
708-
for (int device_idx = 0; device_idx < num_devices; ++device_idx) {
709-
argument_handles[device_idx].reserve(num_args);
766+
for (int device_idx = 0; device_idx < num_mesh_ids; ++device_idx) {
767+
int64_t mesh_id = mesh_ids[device_idx];
768+
769+
// Validate mesh_id
770+
if (mesh_id < 0 || mesh_id >= num_mesh_ids) {
771+
ReactantThrowError(("Invalid mesh_id " + std::to_string(mesh_id) + " at device_idx " +
772+
std::to_string(device_idx)).c_str());
773+
}
774+
775+
argument_handles[mesh_id].reserve(num_args);
710776
for (int arg_idx = 0; arg_idx < num_args; ++arg_idx) {
711777
// Assuming op_args is a flat array of size num_devices * num_args
712778
// where arguments for each device are contiguous
713-
argument_handles[device_idx].push_back(op_args[device_idx * num_args + arg_idx]);
779+
argument_handles[mesh_id].push_back(op_args[mesh_id * num_args + arg_idx]);
714780
}
715781
}
716782

@@ -722,41 +788,40 @@ extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int num_args,
722788
}
723789
options.untuple_result = true;
724790

725-
std::optional<std::vector<FutureType>> returned_futures;
791+
std::optional<std::vector<FutureType>> returned_futures = std::vector<FutureType>();
726792
auto results = MyValueOrThrow(
727793
exec->Execute(static_cast<absl::Span<const std::vector<PjRtBuffer *>>>(
728794
argument_handles),
729795
options, returned_futures));
730796

731-
assert(results.size() == num_devices);
797+
assert(results.size() == num_mesh_ids);
732798

733-
for (int device_idx = 0; device_idx < num_devices; ++device_idx) {
734-
if (results[device_idx].size() != num_results) {
735-
llvm::errs() << " results[" << device_idx << "].size()=" << results[device_idx].size()
799+
for (int device_idx = 0; device_idx < num_mesh_ids; ++device_idx) {
800+
int64_t mesh_id = mesh_ids[device_idx];
801+
if (results[mesh_id].size() != num_results) {
802+
llvm::errs() << " results[" << mesh_id << "].size()=" << results[mesh_id].size()
736803
<< " num_results=" << num_results << "\n";
737804
}
738-
assert(results[device_idx].size() == num_results);
805+
assert(results[mesh_id].size() == num_results);
739806
}
740807

741808
// Handle returned futures
742-
if (returned_futures) {
809+
if (returned_futures.has_value()) {
743810
*futures = true;
744-
assert(returned_futures->size() == num_devices * num_results);
745-
for (int device_idx = 0; device_idx < num_devices; ++device_idx) {
746-
for (int result_idx = 0; result_idx < num_results; ++result_idx) {
747-
int flat_index = device_idx * num_results + result_idx;
748-
future_results[flat_index] = new FutureType((*returned_futures)[flat_index]);
749-
}
750-
}
811+
assert(returned_futures->size() == num_mesh_ids);
751812
} else {
752813
*futures = false;
753814
}
754815

755816
// Copy results into the output buffers
756-
for (int device_idx = 0; device_idx < num_devices; ++device_idx) {
817+
for (int device_idx = 0; device_idx < num_mesh_ids; ++device_idx) {
818+
int64_t mesh_id = mesh_ids[device_idx];
757819
for (int result_idx = 0; result_idx < num_results; ++result_idx) {
758-
int flat_index = device_idx * num_results + result_idx;
759-
op_results[flat_index] = results[device_idx][result_idx].release();
820+
int flat_index = mesh_id * num_results + result_idx;
821+
op_results[flat_index] = results[mesh_id][result_idx].release();
822+
if (returned_futures.has_value()) {
823+
future_results[flat_index] = new FutureType((*returned_futures)[mesh_id]);
824+
}
760825
}
761826
}
762827
}
@@ -784,10 +849,16 @@ extern "C" void RegisterDialects(MlirContext cctx) {
784849
#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
785850
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h"
786851
#include "mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h"
852+
#include "xla/service/spmd/shardy/sdy_round_trip/pipelines.h"
853+
787854
extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
788855
mlir::DialectRegistry &registry = *unwrap(creg);
789856
prepareRegistry(registry);
790857

858+
mlir::registerLLVMDialectImport(registry);
859+
mlir::registerNVVMDialectImport(registry);
860+
mlir::LLVM::registerInlinerInterface(registry);
861+
791862
mlir::registerenzymePasses();
792863
enzyme::registerenzymexlaPasses();
793864

@@ -803,10 +874,6 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
803874
mlir::affine::registerAffinePasses();
804875
mlir::registerReconcileUnrealizedCasts();
805876

806-
mlir::registerLLVMDialectImport(registry);
807-
mlir::registerNVVMDialectImport(registry);
808-
mlir::LLVM::registerInlinerInterface(registry);
809-
810877
/*
811878
registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) {
812879
LLVM::LLVMFunctionType::attachInterface<MemRefInsider>(*ctx);
@@ -827,6 +894,10 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
827894
mlir::transform::registerInterpreterPass();
828895
mlir::enzyme::registerGenerateApplyPatternsPass();
829896
mlir::enzyme::registerRemoveTransformPass();
897+
898+
// xla + shardy specific passes
899+
xla::sdy::registerSdyRoundTripExportPipeline();
900+
xla::sdy::registerSdyRoundTripImportPipeline();
830901
}
831902

832903
/// Returns an unused symbol in `module` for `oldSymbolName` by trying numeric
@@ -881,12 +952,6 @@ static mlir::LogicalResult updateSymbolAndAllUses(mlir::SymbolOpInterface op,
881952
return success();
882953
}
883954

884-
extern "C" void ReactantFuncSetArgAttr(MlirOperation op, intptr_t pos,
885-
MlirStringRef name, MlirAttribute attr) {
886-
llvm::cast<mlir::FunctionOpInterface>(unwrap(op))
887-
.setArgAttr(pos, unwrap(name), unwrap(attr));
888-
}
889-
890955
extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC,
891956
const char *entryfn) {
892957
auto prevMod = cast<ModuleOp>(*unwrap(prevModC));

deps/ReactantExtra/BUILD

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,11 @@ cc_library(
459459
"-Wl,-exported_symbol,_XLAExecuteSharded",
460460
"-Wl,-exported_symbol,_ClientGetPlatformName",
461461
"-Wl,-exported_symbol,_RegisterEnzymeXLACPUHandler",
462+
"-Wl,-exported_symbol,_PjRtLoadedExecutableGetClient",
463+
"-Wl,-exported_symbol,_ReactantFuncSetResultAttr",
464+
"-Wl,-exported_symbol,_BufferShape",
465+
"-Wl,-exported_symbol,_BufferNDimensions",
466+
"-Wl,-exported_symbol,_BufferPrimitiveType",
462467
]}),
463468
deps = [
464469
"@enzyme//:EnzymeMLIR",

deps/ReactantExtra/WORKSPACE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ http_archive(
99
urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
1010
)
1111

12-
ENZYMEXLA_COMMIT = "b8b5037d0d3c108eb374218961631740daa10e05"
12+
ENZYMEXLA_COMMIT = "8d3ed1d53a499841d21b0e90f3201674acfee18a"
1313
ENZYMEXLA_SHA256 = ""
1414

1515
http_archive(

0 commit comments

Comments
 (0)