171 changes: 118 additions & 53 deletions deps/ReactantExtra/API.cpp
@@ -71,6 +71,7 @@
// shardy
#include "shardy/dialect/sdy/ir/dialect.h"
#include "shardy/integrations/c/attributes.h"
#include "xla/pjrt/mlir_to_hlo.h"

// IFRT
#include "xla/python/ifrt/array.h"
@@ -160,6 +161,19 @@ extern "C" MlirAttribute mlirComplexAttrDoubleGetChecked(MlirLocation loc,
// TODO mlirComplexAttrGetnValue
// TODO extern "C" MlirTypeID mlirComplexAttrGetTypeID(void) { return
// wrap(complex::NumberAttr::getTypeID()); }

extern "C" void ReactantFuncSetResultAttr(MlirOperation op, intptr_t pos,
MlirStringRef name, MlirAttribute attr) {
llvm::cast<mlir::FunctionOpInterface>(unwrap(op))
.setResultAttr(pos, unwrap(name), unwrap(attr));
}

extern "C" void ReactantFuncSetArgAttr(MlirOperation op, intptr_t pos,
MlirStringRef name, MlirAttribute attr) {
llvm::cast<mlir::FunctionOpInterface>(unwrap(op))
.setArgAttr(pos, unwrap(name), unwrap(attr));
}
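
// Minimal usage sketch (illustrative, not part of this diff): tag the first
// argument and first result of a function-like op via these shims. The
// attribute name "example.attr" is a hypothetical placeholder; mlirUnitAttrGet
// and mlirStringRefCreateFromCString are upstream MLIR C API helpers.
static void ExampleTagFuncLikeOp(MlirContext ctx, MlirOperation funcLikeOp) {
  MlirAttribute unit = mlirUnitAttrGet(ctx);
  MlirStringRef name = mlirStringRefCreateFromCString("example.attr");
  ReactantFuncSetArgAttr(funcLikeOp, /*pos=*/0, name, unit);
  ReactantFuncSetResultAttr(funcLikeOp, /*pos=*/0, name, unit);
}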

#pragma endregion

// auxiliary functions
@@ -438,11 +452,27 @@ extern "C" PjRtClient *BufferToClient(PjRtBuffer *Buffer) {
return Buffer->client();
}

extern "C" absl::Span<const int64_t> BufferShape(PjRtBuffer *Buffer) {
return Buffer->dimensions();
}

extern "C" int64_t BufferNDimensions(PjRtBuffer *Buffer) {
return Buffer->dimensions().length();
}

extern "C" xla::PrimitiveType BufferPrimitiveType(PjRtBuffer *Buffer) {
return Buffer->element_type();
}

extern "C" void PjRtBufferFree(PjRtBuffer *Buffer) { delete Buffer; }

extern "C" PjRtClient *DeviceToClient(PjRtDevice *Device) {
return Device->client();
}

extern "C" void PjRtBufferFree(PjRtBuffer *Buffer) { delete Buffer; }
extern "C" PjRtClient *PjRtLoadedExecutableGetClient(PjRtLoadedExecutable *exec) {
return exec->client();
}
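
// Minimal usage sketch (illustrative, not part of this diff): reading buffer
// metadata through the new accessors, given a valid PjRtBuffer* obtained from
// an earlier transfer.
static void ExamplePrintBufferInfo(PjRtBuffer *buffer) {
  llvm::errs() << "rank=" << BufferNDimensions(buffer) << " element_type="
               << static_cast<int>(BufferPrimitiveType(buffer)) << " dims=[";
  for (int64_t dim : BufferShape(buffer))
    llvm::errs() << " " << dim;
  llvm::errs() << " ]\n";
}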

// https://openxla.org/xla/shapes
// This minor-to-major dimension order of 0 up to N-1 is akin to column-major
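
// Illustration (not part of this diff): with minor-to-major layout {0, 1},
// dimension 0 varies fastest, so an array of shape [num_rows, num_cols]
// linearizes in Fortran/column-major order: offset(i, j) = i + j * num_rows.
static int64_t ExampleMinorToMajorOffset(int64_t i, int64_t j,
                                         int64_t num_rows) {
  return i + j * num_rows;  // dimension 0 (i) is minor and varies fastest
}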
@@ -593,33 +623,60 @@ extern "C" MlirModule ConvertLLVMStrToMLIR(const char *lmod, MlirContext cctx) {
return wrap(res);
}

/* Note that this entry point handles both the single-device case
   (is_sharded == false, device_id >= 0) and the SPMD-sharded case
   (is_sharded == true, device_id < 0); see the asserts below. */
extern "C" xla::PjRtLoadedExecutable *ClientCompile(PjRtClient *client,
MlirModule cmod,
int *global_ordinals,
int num_global_ordinals,
int64_t device_id,
bool is_sharded,
// const int64_t *mesh_shape,
// int64_t num_mesh_shape,
const int64_t *mesh_ids,
int64_t num_mesh_ids,
const char* xla_gpu_cuda_data_dir) {
auto program =
std::make_unique<xla::ifrt::HloProgram>(cast<ModuleOp>(*unwrap(cmod)));

CompileOptions options;
options.executable_build_options.mutable_debug_options()->set_xla_gpu_cuda_data_dir(xla_gpu_cuda_data_dir);

// https://github.com/pytorch/xla/blob/8b2414094578e829b99a8383877c86d357eeb682/torch_xla/csrc/runtime/pjrt_computation_client.cc#L601
int device_count = client->addressable_device_count();
auto cmodop = cast<ModuleOp>(*unwrap(cmod));

options.executable_build_options.set_num_replicas(device_count);
options.executable_build_options.set_num_partitions(1);
options.executable_build_options.mutable_debug_options()->set_xla_gpu_cuda_data_dir(xla_gpu_cuda_data_dir);
if (is_sharded) {
assert(device_id < 0);

xla::DeviceAssignment device_assignment(device_count, 1);
for (int64_t device_id = 0; device_id < num_global_ordinals; ++device_id) {
int ordinal = global_ordinals[device_id];
if (ordinal < 0) {
continue;
options.executable_build_options.set_num_replicas(1);
options.executable_build_options.set_num_partitions(num_mesh_ids);

options.executable_build_options.set_use_spmd_partitioning(true);
options.executable_build_options.set_use_shardy_partitioner(true);

// Auto-partitioning for GPUs is not available in the open-source version of XLA.
// options.executable_build_options.set_use_auto_spmd_partitioning(true);
// std::vector<int64_t> mesh_shape_vec(mesh_shape, mesh_shape + num_mesh_shape);
// options.executable_build_options.set_auto_spmd_partitioning_mesh_shape(mesh_shape_vec);
// std::vector<int64_t> mesh_ids_vec(mesh_ids, mesh_ids + num_mesh_ids);
// options.executable_build_options.set_auto_spmd_partitioning_mesh_ids(mesh_ids_vec);

xla::DeviceAssignment device_assignment(1, num_mesh_ids);
for (int64_t i = 0; i < num_mesh_ids; ++i) {
int64_t mesh_id = mesh_ids[i];
assert(mesh_id >= 0);
device_assignment(0, mesh_id) = i;
}
device_assignment(ordinal, 0) = device_id;
options.executable_build_options.set_device_assignment(device_assignment);

// https://github.com/openxla/xla/blob/b3c641b05692f3712fb3c272e38665fdfa28bdf8/xla/python/py_client.cc#L460
xla::ExportShardyForHloRoundTrip(cmodop);
} else {
assert(device_id >= 0);

options.executable_build_options.set_num_replicas(1);
options.executable_build_options.set_num_partitions(1);
options.executable_build_options.set_device_ordinal(device_id);

xla::DeviceAssignment device_assignment(1, 1);
device_assignment(0, 0) = device_id;
options.executable_build_options.set_device_assignment(device_assignment);
}
options.executable_build_options.set_device_assignment(device_assignment);

auto addressable_devices = client->addressable_devices();
if (!addressable_devices.empty()) {
@@ -633,8 +690,7 @@ extern "C" xla::PjRtLoadedExecutable *ClientCompile(PjRtClient *client,
options.executable_build_options.set_device_memory_size(*stats->bytes_limit);
}
}
auto exec =
MyValueOrThrow(client->Compile(cast<ModuleOp>(*unwrap(cmod)), options));
auto exec = MyValueOrThrow(client->Compile(cmodop, options));
return exec.release();
}
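
// Minimal call-site sketch (illustrative, not part of this diff): `client`
// and `mod` are assumed to come from earlier setup; the 4-device mesh below
// is a hypothetical example.
static xla::PjRtLoadedExecutable *ExampleCompileSingleDevice(PjRtClient *client,
                                                             MlirModule mod) {
  // Unsharded: device_id must be >= 0 and the mesh arguments are unused.
  return ClientCompile(client, mod, /*global_ordinals=*/nullptr,
                       /*num_global_ordinals=*/0, /*device_id=*/0,
                       /*is_sharded=*/false, /*mesh_ids=*/nullptr,
                       /*num_mesh_ids=*/0, /*xla_gpu_cuda_data_dir=*/"");
}

static xla::PjRtLoadedExecutable *ExampleCompileSpmd4(PjRtClient *client,
                                                      MlirModule mod) {
  // Sharded: device_id must be negative; one partition per mesh id.
  static const int64_t mesh_ids[] = {0, 1, 2, 3};
  return ClientCompile(client, mod, /*global_ordinals=*/nullptr,
                       /*num_global_ordinals=*/0, /*device_id=*/-1,
                       /*is_sharded=*/true, mesh_ids, /*num_mesh_ids=*/4,
                       /*xla_gpu_cuda_data_dir=*/"");
}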

@@ -694,23 +750,33 @@ extern "C" void XLAExecuteSharded(xla::PjRtLoadedExecutable *exec, int num_args,
}
}

extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int num_args,
PjRtBuffer **op_args, uint8_t *is_arg_donatable,
extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int op_args_len,
PjRtBuffer **op_args,
const int64_t *mesh_ids, int64_t num_mesh_ids,
uint8_t *is_arg_donatable,
int num_results, PjRtBuffer **op_results,
uint8_t *futures, FutureType **future_results) {
auto client = exec->client();
int num_devices = client->addressable_device_count();

// Ensure argument_handles is structured as num_devices x num_args
std::vector<std::vector<PjRtBuffer *>> argument_handles(num_devices);
// Ensure argument_handles is structured as num_mesh_ids x num_args
std::vector<std::vector<PjRtBuffer *>> argument_handles(num_mesh_ids);
int num_args = op_args_len / num_mesh_ids;

// Distribute arguments across devices
for (int device_idx = 0; device_idx < num_devices; ++device_idx) {
argument_handles[device_idx].reserve(num_args);
for (int device_idx = 0; device_idx < num_mesh_ids; ++device_idx) {
int64_t mesh_id = mesh_ids[device_idx];

// Validate mesh_id
if (mesh_id < 0 || mesh_id >= num_mesh_ids) {
ReactantThrowError(("Invalid mesh_id " + std::to_string(mesh_id) + " at device_idx " +
std::to_string(device_idx)).c_str());
}

argument_handles[mesh_id].reserve(num_args);
for (int arg_idx = 0; arg_idx < num_args; ++arg_idx) {
// op_args is a flat array of size num_mesh_ids * num_args,
// where each device's arguments are contiguous and indexed by mesh id
argument_handles[device_idx].push_back(op_args[device_idx * num_args + arg_idx]);
argument_handles[mesh_id].push_back(op_args[mesh_id * num_args + arg_idx]);
}
}

@@ -722,41 +788,40 @@ extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int num_args,
}
options.untuple_result = true;

std::optional<std::vector<FutureType>> returned_futures;
std::optional<std::vector<FutureType>> returned_futures = std::vector<FutureType>();
auto results = MyValueOrThrow(
exec->Execute(static_cast<absl::Span<const std::vector<PjRtBuffer *>>>(
argument_handles),
options, returned_futures));

assert(results.size() == num_devices);
assert(results.size() == num_mesh_ids);

for (int device_idx = 0; device_idx < num_devices; ++device_idx) {
if (results[device_idx].size() != num_results) {
llvm::errs() << " results[" << device_idx << "].size()=" << results[device_idx].size()
for (int device_idx = 0; device_idx < num_mesh_ids; ++device_idx) {
int64_t mesh_id = mesh_ids[device_idx];
if (results[mesh_id].size() != num_results) {
llvm::errs() << " results[" << mesh_id << "].size()=" << results[mesh_id].size()
<< " num_results=" << num_results << "\n";
}
assert(results[device_idx].size() == num_results);
assert(results[mesh_id].size() == num_results);
}

// Handle returned futures
if (returned_futures) {
if (returned_futures.has_value()) {
*futures = true;
assert(returned_futures->size() == num_devices * num_results);
for (int device_idx = 0; device_idx < num_devices; ++device_idx) {
for (int result_idx = 0; result_idx < num_results; ++result_idx) {
int flat_index = device_idx * num_results + result_idx;
future_results[flat_index] = new FutureType((*returned_futures)[flat_index]);
}
}
assert(returned_futures->size() == num_mesh_ids);
} else {
*futures = false;
}

// Copy results into the output buffers
for (int device_idx = 0; device_idx < num_devices; ++device_idx) {
for (int device_idx = 0; device_idx < num_mesh_ids; ++device_idx) {
int64_t mesh_id = mesh_ids[device_idx];
for (int result_idx = 0; result_idx < num_results; ++result_idx) {
int flat_index = device_idx * num_results + result_idx;
op_results[flat_index] = results[device_idx][result_idx].release();
int flat_index = mesh_id * num_results + result_idx;
op_results[flat_index] = results[mesh_id][result_idx].release();
if (returned_futures.has_value()) {
future_results[flat_index] = new FutureType((*returned_futures)[mesh_id]);
}
}
}
}
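
// Caller-side sketch (illustrative, not part of this diff): building the flat
// op_args layout XLAExecute expects. op_args_len must equal
// num_mesh_ids * num_args, with each device's arguments stored contiguously
// and indexed by mesh id. The names below are illustrative assumptions.
static std::vector<PjRtBuffer *> ExamplePackArguments(
    const std::vector<std::vector<PjRtBuffer *>> &per_device_args,
    const int64_t *mesh_ids, int64_t num_mesh_ids) {
  size_t num_args = per_device_args.empty() ? 0 : per_device_args[0].size();
  std::vector<PjRtBuffer *> flat(num_mesh_ids * num_args, nullptr);
  for (int64_t i = 0; i < num_mesh_ids; ++i)
    for (size_t a = 0; a < num_args; ++a)
      flat[mesh_ids[i] * num_args + a] = per_device_args[i][a];
  return flat;
}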
@@ -784,10 +849,16 @@ extern "C" void RegisterDialects(MlirContext cctx) {
#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h"
#include "xla/service/spmd/shardy/sdy_round_trip/pipelines.h"

extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
mlir::DialectRegistry &registry = *unwrap(creg);
prepareRegistry(registry);

mlir::registerLLVMDialectImport(registry);
mlir::registerNVVMDialectImport(registry);
mlir::LLVM::registerInlinerInterface(registry);

mlir::registerenzymePasses();
enzyme::registerenzymexlaPasses();

@@ -803,10 +874,6 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
mlir::affine::registerAffinePasses();
mlir::registerReconcileUnrealizedCasts();

mlir::registerLLVMDialectImport(registry);
mlir::registerNVVMDialectImport(registry);
mlir::LLVM::registerInlinerInterface(registry);

/*
registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) {
LLVM::LLVMFunctionType::attachInterface<MemRefInsider>(*ctx);
@@ -827,6 +894,10 @@
mlir::transform::registerInterpreterPass();
mlir::enzyme::registerGenerateApplyPatternsPass();
mlir::enzyme::registerRemoveTransformPass();

// xla + shardy specific passes
xla::sdy::registerSdyRoundTripExportPipeline();
xla::sdy::registerSdyRoundTripImportPipeline();
}

/// Returns an unused symbol in `module` for `oldSymbolName` by trying numeric
Expand Down Expand Up @@ -881,12 +952,6 @@ static mlir::LogicalResult updateSymbolAndAllUses(mlir::SymbolOpInterface op,
return success();
}

extern "C" void ReactantFuncSetArgAttr(MlirOperation op, intptr_t pos,
MlirStringRef name, MlirAttribute attr) {
llvm::cast<mlir::FunctionOpInterface>(unwrap(op))
.setArgAttr(pos, unwrap(name), unwrap(attr));
}

extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC,
const char *entryfn) {
auto prevMod = cast<ModuleOp>(*unwrap(prevModC));
5 changes: 5 additions & 0 deletions deps/ReactantExtra/BUILD
@@ -459,6 +459,11 @@ cc_library(
"-Wl,-exported_symbol,_XLAExecuteSharded",
"-Wl,-exported_symbol,_ClientGetPlatformName",
"-Wl,-exported_symbol,_RegisterEnzymeXLACPUHandler",
"-Wl,-exported_symbol,_PjRtLoadedExecutableGetClient",
"-Wl,-exported_symbol,_ReactantFuncSetResultAttr",
"-Wl,-exported_symbol,_BufferShape",
"-Wl,-exported_symbol,_BufferNDimensions",
"-Wl,-exported_symbol,_BufferPrimitiveType",
]}),
deps = [
"@enzyme//:EnzymeMLIR",
2 changes: 1 addition & 1 deletion deps/ReactantExtra/WORKSPACE
@@ -9,7 +9,7 @@ http_archive(
urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
)

ENZYMEXLA_COMMIT = "b8b5037d0d3c108eb374218961631740daa10e05"
ENZYMEXLA_COMMIT = "8d3ed1d53a499841d21b0e90f3201674acfee18a"
ENZYMEXLA_SHA256 = ""

http_archive(