7171// shardy
7272#include " shardy/dialect/sdy/ir/dialect.h"
7373#include " shardy/integrations/c/attributes.h"
74+ #include " xla/pjrt/mlir_to_hlo.h"
7475
7576// IFRT
7677#include " xla/python/ifrt/array.h"
@@ -160,6 +161,19 @@ extern "C" MlirAttribute mlirComplexAttrDoubleGetChecked(MlirLocation loc,
160161// TODO mlirComplexAttrGetnValue
161162// TODO extern "C" MlirTypeID mlirComplexAttrGetTypeID(void) { return
162163// wrap(complex::NumberAttr::getTypeID()); }
164+
// C-API shim: sets the attribute `name` to `attr` on result number `pos` of
// the function-like operation `op`.
//
// `op` is unwrapped and cast to mlir::FunctionOpInterface with llvm::cast,
// which asserts rather than returning null when the operation does not
// implement that interface, so callers must pass a function-like op.
165+ extern " C" void ReactantFuncSetResultAttr (MlirOperation op, intptr_t pos,
166+ MlirStringRef name, MlirAttribute attr) {
167+ llvm::cast<mlir::FunctionOpInterface>(unwrap (op))
168+ .setResultAttr (pos, unwrap (name), unwrap (attr));
169+ }
170+
// C-API shim: sets the attribute `name` to `attr` on argument number `pos` of
// the function-like operation `op`.
//
// `op` is unwrapped and cast to mlir::FunctionOpInterface with llvm::cast,
// which asserts rather than returning null when the operation does not
// implement that interface, so callers must pass a function-like op.
171+ extern " C" void ReactantFuncSetArgAttr (MlirOperation op, intptr_t pos,
172+ MlirStringRef name, MlirAttribute attr) {
173+ llvm::cast<mlir::FunctionOpInterface>(unwrap (op))
174+ .setArgAttr (pos, unwrap (name), unwrap (attr));
175+ }
176+
163177#pragma endregion
164178
165179// auxiliary functions
@@ -438,11 +452,27 @@ extern "C" PjRtClient *BufferToClient(PjRtBuffer *Buffer) {
438452 return Buffer->client ();
439453}
440454
// Returns the dimensions of `Buffer` exactly as reported by
// Buffer->dimensions().
//
// The result is an absl::Span, i.e. a non-owning view.
// NOTE(review): presumably the view is only valid while `Buffer` is alive —
// confirm against the PjRtBuffer contract.
// NOTE(review): absl::Span is a C++ class type returned through an
// extern "C" function; non-C++ callers must agree on its layout — confirm
// this is intended.
455+ extern " C" absl::Span<const int64_t > BufferShape (PjRtBuffer *Buffer) {
456+ return Buffer->dimensions ();
457+ }
458+
// Returns the number of dimensions of `Buffer`, i.e. the length of the span
// returned by Buffer->dimensions().
459+ extern " C" int64_t BufferNDimensions (PjRtBuffer *Buffer) {
460+ return Buffer->dimensions ().length ();
461+ }
462+
// Returns the element type of `Buffer` (the value of Buffer->element_type())
// as an xla::PrimitiveType.
463+ extern " C" xla::PrimitiveType BufferPrimitiveType (PjRtBuffer *Buffer) {
464+ return Buffer->element_type ();
465+ }
466+
// Destroys `Buffer` with `delete`; the pointer must not be used afterwards.
// NOTE(review): assumes the buffer was heap-allocated and that the caller
// owns it (e.g. obtained via unique_ptr::release()) — confirm at call sites.
467+ extern " C" void PjRtBufferFree (PjRtBuffer *Buffer) { delete Buffer; }
468+
// Returns the PjRtClient that `Device` reports via Device->client().
441469extern " C" PjRtClient *DeviceToClient (PjRtDevice *Device) {
442470 return Device->client ();
443471}
444472
445- extern " C" void PjRtBufferFree (PjRtBuffer *Buffer) { delete Buffer; }
// Returns the PjRtClient that the loaded executable `exec` reports via
// exec->client().
473+ extern " C" PjRtClient *PjRtLoadedExecutableGetClient (PjRtLoadedExecutable *exec) {
474+ return exec->client ();
475+ }
446476
447477// https://openxla.org/xla/shapes
448478// This minor-to-major dimension order of 0 up to N-1 is akin to column-major
@@ -593,33 +623,60 @@ extern "C" MlirModule ConvertLLVMStrToMLIR(const char *lmod, MlirContext cctx) {
593623 return wrap (res);
594624}
595625
596- /* Note that this */
// Compiles the MLIR module `cmod` with `client` and returns the resulting
// loaded executable. The executable is handed out with exec.release(), so
// the caller owns the returned pointer.
//
// Parameters, as used by the added (`+`) lines below:
//   device_id             - device for the unsharded path; asserted >= 0 when
//                           !is_sharded and < 0 when is_sharded.
//   is_sharded            - true: 1 replica x num_mesh_ids partitions with
//                           SPMD partitioning and the Shardy partitioner;
//                           false: 1 replica x 1 partition on device_id.
//   mesh_ids/num_mesh_ids - read only on the sharded path, to fill the
//                           1 x num_mesh_ids DeviceAssignment.
//   xla_gpu_cuda_data_dir - forwarded to the debug options.
//
// NOTE(review): the device_id / mesh_id checks are assert()s and disappear
// under NDEBUG. mesh_id is also not checked against num_mesh_ids before it is
// used as the partition index — confirm callers guarantee a valid range.
// NOTE(review): `program` is constructed but not referenced in the lines
// visible in this diff.
597626extern " C" xla::PjRtLoadedExecutable *ClientCompile (PjRtClient *client,
598627 MlirModule cmod,
599- int *global_ordinals,
600- int num_global_ordinals,
628+ int64_t device_id,
629+ bool is_sharded,
630+ // const int64_t *mesh_shape,
631+ // int64_t num_mesh_shape,
632+ const int64_t *mesh_ids,
633+ int64_t num_mesh_ids,
601634 const char * xla_gpu_cuda_data_dir) {
602635 auto program =
603636 std::make_unique<xla::ifrt::HloProgram>(cast<ModuleOp>(*unwrap (cmod)));
604637
605638 CompileOptions options;
639+ options.executable_build_options .mutable_debug_options ()->set_xla_gpu_cuda_data_dir (xla_gpu_cuda_data_dir);
606640
607- // https://github.com/pytorch/xla/blob/8b2414094578e829b99a8383877c86d357eeb682/torch_xla/csrc/runtime/pjrt_computation_client.cc#L601
608- int device_count = client->addressable_device_count ();
// Unwrap the module once; the same ModuleOp is exported (sharded path) and
// later passed to client->Compile.
641+ auto cmodop = cast<ModuleOp>(*unwrap (cmod));
609642
610- options.executable_build_options .set_num_replicas (device_count);
611- options.executable_build_options .set_num_partitions (1 );
612- options.executable_build_options .mutable_debug_options ()->set_xla_gpu_cuda_data_dir (xla_gpu_cuda_data_dir);
// Sharded path: one replica, one partition per mesh id.
643+ if (is_sharded) {
644+ assert (device_id < 0 );
613645
614- xla::DeviceAssignment device_assignment (device_count, 1 );
615- for (int64_t device_id = 0 ; device_id < num_global_ordinals; ++device_id) {
616- int ordinal = global_ordinals[device_id];
617- if (ordinal < 0 ) {
618- continue ;
646+ options.executable_build_options .set_num_replicas (1 );
647+ options.executable_build_options .set_num_partitions (num_mesh_ids);
648+
649+ options.executable_build_options .set_use_spmd_partitioning (true );
650+ options.executable_build_options .set_use_shardy_partitioner (true );
651+
652+ // auto partitioning for GPUs is not available in open source version of XLA
653+ // options.executable_build_options.set_use_auto_spmd_partitioning(true);
654+ // std::vector<int64_t> mesh_shape_vec(mesh_shape, mesh_shape + num_mesh_shape);
655+ // options.executable_build_options.set_auto_spmd_partitioning_mesh_shape(mesh_shape_vec);
656+ // std::vector<int64_t> mesh_ids_vec(mesh_ids, mesh_ids + num_mesh_ids);
657+ // options.executable_build_options.set_auto_spmd_partitioning_mesh_ids(mesh_ids_vec);
658+
// Entry (0, mesh_ids[i]) of the 1 x num_mesh_ids assignment is set to i.
659+ xla::DeviceAssignment device_assignment (1 , num_mesh_ids);
660+ for (int64_t i = 0 ; i < num_mesh_ids; ++i) {
661+ int64_t mesh_id = mesh_ids[i];
662+ assert (mesh_id >= 0 );
663+ device_assignment (0 , mesh_id) = i;
619664 }
620- device_assignment (ordinal, 0 ) = device_id;
665+ options.executable_build_options .set_device_assignment (device_assignment);
666+
// NOTE(review): the result of ExportShardyForHloRoundTrip is discarded; if it
// returns a status, confirm that ignoring a failure here is intended.
667+ // https://github.com/openxla/xla/blob/b3c641b05692f3712fb3c272e38665fdfa28bdf8/xla/python/py_client.cc#L460
668+ xla::ExportShardyForHloRoundTrip (cmodop);
669+ } else {
// Unsharded path: a single replica/partition pinned to device_id.
670+ assert (device_id >= 0 );
671+
672+ options.executable_build_options .set_num_replicas (1 );
673+ options.executable_build_options .set_num_partitions (1 );
674+ options.executable_build_options .set_device_ordinal (device_id);
675+
676+ xla::DeviceAssignment device_assignment (1 , 1 );
677+ device_assignment (0 , 0 ) = device_id;
678+ options.executable_build_options .set_device_assignment (device_assignment);
621679 }
622- options.executable_build_options .set_device_assignment (device_assignment);
623680
// The lines between the `if` below and set_device_memory_size are elided by
// the diff hunk boundary; only the memory-size assignment is visible.
624681 auto addressable_devices = client->addressable_devices ();
625682 if (!addressable_devices.empty ()) {
@@ -633,8 +690,7 @@ extern "C" xla::PjRtLoadedExecutable *ClientCompile(PjRtClient *client,
633690 options.executable_build_options .set_device_memory_size (*stats->bytes_limit );
634691 }
635692 }
636- auto exec =
637- MyValueOrThrow (client->Compile (cast<ModuleOp>(*unwrap (cmod)), options));
693+ auto exec = MyValueOrThrow (client->Compile (cmodop, options));
638694 return exec.release ();
639695}
640696
@@ -694,23 +750,33 @@ extern "C" void XLAExecuteSharded(xla::PjRtLoadedExecutable *exec, int num_args,
694750 }
695751}
696752
697- extern " C" void XLAExecute (xla::PjRtLoadedExecutable *exec, int num_args,
698- PjRtBuffer **op_args, uint8_t *is_arg_donatable,
// Executes `exec` across num_mesh_ids devices and writes the outputs back
// through the caller-provided arrays.
//
// Parameters, as used by the visible lines:
//   op_args / op_args_len - flat array of input buffers; the per-device
//                           argument count is op_args_len / num_mesh_ids and
//                           device m's arguments start at m * num_args.
//   mesh_ids/num_mesh_ids - each entry must lie in [0, num_mesh_ids);
//                           otherwise ReactantThrowError is called.
//   num_results           - expected number of results per device (asserted).
//   op_results            - receives released result buffers at index
//                           mesh_id * num_results + result_idx.
//   futures               - set to true/false depending on whether
//                           returned_futures holds a value after Execute.
//   future_results        - when futures are present, receives a
//                           heap-allocated copy of the per-device future at
//                           the same flat index as the matching op_results
//                           slot.
// is_arg_donatable is consumed in lines elided by the diff hunk boundary.
//
// NOTE(review): num_mesh_ids == 0 makes `op_args_len / num_mesh_ids` divide
// by zero — confirm callers always pass a positive count.
// NOTE(review): mesh_ids is range-checked but not checked to be a
// permutation; a duplicate id would append a second set of arguments to the
// same argument_handles row and leave another row empty.
// NOTE(review): `client` is not referenced in the visible lines.
753+ extern " C" void XLAExecute (xla::PjRtLoadedExecutable *exec, int op_args_len,
754+ PjRtBuffer **op_args,
755+ const int64_t *mesh_ids, int64_t num_mesh_ids,
756+ uint8_t *is_arg_donatable,
699757 int num_results, PjRtBuffer **op_results,
700758 uint8_t *futures, FutureType **future_results) {
701759 auto client = exec->client ();
702- int num_devices = client->addressable_device_count ();
703760
704- // Ensure argument_handles is structured as num_devices x num_args
705- std::vector<std::vector<PjRtBuffer *>> argument_handles (num_devices);
761+ // Ensure argument_handles is structured as num_mesh_ids x num_args
762+ std::vector<std::vector<PjRtBuffer *>> argument_handles (num_mesh_ids);
763+ int num_args = op_args_len / num_mesh_ids;
706764
707765 // Distribute arguments across devices
708- for (int device_idx = 0 ; device_idx < num_devices; ++device_idx) {
709- argument_handles[device_idx].reserve (num_args);
766+ for (int device_idx = 0 ; device_idx < num_mesh_ids; ++device_idx) {
767+ int64_t mesh_id = mesh_ids[device_idx];
768+
769+ // Validate mesh_id
770+ if (mesh_id < 0 || mesh_id >= num_mesh_ids) {
771+ ReactantThrowError ((" Invalid mesh_id " + std::to_string (mesh_id) + " at device_idx " +
772+ std::to_string (device_idx)).c_str ());
773+ }
774+
775+ argument_handles[mesh_id].reserve (num_args);
710776 for (int arg_idx = 0 ; arg_idx < num_args; ++arg_idx) {
711777 // op_args is read as a flat array of num_mesh_ids * num_args entries
712778 // where arguments for each device are contiguous
713- argument_handles[device_idx ].push_back (op_args[device_idx * num_args + arg_idx]);
779+ argument_handles[mesh_id ].push_back (op_args[mesh_id * num_args + arg_idx]);
714780 }
715781 }
716782
@@ -722,41 +788,40 @@ extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int num_args,
722788 }
723789 options.untuple_result = true ;
724790
// returned_futures starts out engaged (holding an empty vector), so unless
// Execute resets it the has_value() checks below always succeed.
725- std::optional<std::vector<FutureType>> returned_futures;
791+ std::optional<std::vector<FutureType>> returned_futures = std::vector<FutureType>() ;
726792 auto results = MyValueOrThrow (
727793 exec->Execute (static_cast <absl::Span<const std::vector<PjRtBuffer *>>>(
728794 argument_handles),
729795 options, returned_futures));
730796
731- assert (results.size () == num_devices );
797+ assert (results.size () == num_mesh_ids );
732798
// Sanity-check the per-device result count; the mismatch is logged to
// llvm::errs() before the assert so it is visible in the output.
733- for (int device_idx = 0 ; device_idx < num_devices; ++device_idx) {
734- if (results[device_idx].size () != num_results) {
735- llvm::errs () << " results[" << device_idx << " ].size()=" << results[device_idx].size ()
799+ for (int device_idx = 0 ; device_idx < num_mesh_ids; ++device_idx) {
800+ int64_t mesh_id = mesh_ids[device_idx];
801+ if (results[mesh_id].size () != num_results) {
802+ llvm::errs () << " results[" << mesh_id << " ].size()=" << results[mesh_id].size ()
736803 << " num_results=" << num_results << " \n " ;
737804 }
738- assert (results[device_idx ].size () == num_results);
805+ assert (results[mesh_id ].size () == num_results);
739806 }
740807
741808 // Handle returned futures
742- if (returned_futures) {
809+ if (returned_futures. has_value () ) {
743810 *futures = true ;
744- assert (returned_futures->size () == num_devices * num_results);
745- for (int device_idx = 0 ; device_idx < num_devices; ++device_idx) {
746- for (int result_idx = 0 ; result_idx < num_results; ++result_idx) {
747- int flat_index = device_idx * num_results + result_idx;
748- future_results[flat_index] = new FutureType ((*returned_futures)[flat_index]);
749- }
750- }
811+ assert (returned_futures->size () == num_mesh_ids);
751812 } else {
752813 *futures = false ;
753814 }
754815
755816 // Copy results into the output buffers
756- for (int device_idx = 0 ; device_idx < num_devices; ++device_idx) {
817+ for (int device_idx = 0 ; device_idx < num_mesh_ids; ++device_idx) {
818+ int64_t mesh_id = mesh_ids[device_idx];
757819 for (int result_idx = 0 ; result_idx < num_results; ++result_idx) {
758- int flat_index = device_idx * num_results + result_idx;
759- op_results[flat_index] = results[device_idx][result_idx].release ();
// Ownership of each result buffer moves to the caller via release(). The one
// per-device future is copied into every result slot of that device, and each
// copy is allocated with `new`.
820+ int flat_index = mesh_id * num_results + result_idx;
821+ op_results[flat_index] = results[mesh_id][result_idx].release ();
822+ if (returned_futures.has_value ()) {
823+ future_results[flat_index] = new FutureType ((*returned_futures)[mesh_id]);
824+ }
760825 }
761826 }
762827}
@@ -784,10 +849,16 @@ extern "C" void RegisterDialects(MlirContext cctx) {
784849#include " mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
785850#include " mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h"
786851#include " mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h"
852+ #include " xla/service/spmd/shardy/sdy_round_trip/pipelines.h"
853+
// Prepares the dialect registry wrapped by `creg` and registers passes.
//
// In the visible lines it: calls prepareRegistry on the unwrapped registry;
// registers the LLVM and NVVM dialect import translations and the LLVM
// inliner interface on it (moved by this diff to right after
// prepareRegistry); registers the enzyme, affine, reconcile-unrealized-casts
// and transform-related passes; and finally registers the SDY round-trip
// export and import pipelines. Further registrations sit in lines elided by
// the diff hunk boundaries.
// NOTE(review): `®istry` on the declaration below is an HTML-entity mangling
// of `&registry` introduced by the page extraction, not part of the real
// source.
787854extern " C" void InitializeRegistryAndPasses (MlirDialectRegistry creg) {
788855 mlir::DialectRegistry ®istry = *unwrap (creg);
789856 prepareRegistry (registry);
790857
858+ mlir::registerLLVMDialectImport (registry);
859+ mlir::registerNVVMDialectImport (registry);
860+ mlir::LLVM::registerInlinerInterface (registry);
861+
791862 mlir::registerenzymePasses ();
792863 enzyme::registerenzymexlaPasses ();
793864
@@ -803,10 +874,6 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
803874 mlir::affine::registerAffinePasses ();
804875 mlir::registerReconcileUnrealizedCasts ();
805876
806- mlir::registerLLVMDialectImport (registry);
807- mlir::registerNVVMDialectImport (registry);
808- mlir::LLVM::registerInlinerInterface (registry);
809-
810877 /*
811878 registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) {
812879 LLVM::LLVMFunctionType::attachInterface<MemRefInsider>(*ctx);
@@ -827,6 +894,10 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
827894 mlir::transform::registerInterpreterPass ();
828895 mlir::enzyme::registerGenerateApplyPatternsPass ();
829896 mlir::enzyme::registerRemoveTransformPass ();
897+
898+ // xla + shardy specific passes
899+ xla::sdy::registerSdyRoundTripExportPipeline ();
900+ xla::sdy::registerSdyRoundTripImportPipeline ();
830901}
831902
832903// / Returns an unused symbol in `module` for `oldSymbolName` by trying numeric
@@ -881,12 +952,6 @@ static mlir::LogicalResult updateSymbolAndAllUses(mlir::SymbolOpInterface op,
881952 return success ();
882953}
883954
884- extern " C" void ReactantFuncSetArgAttr (MlirOperation op, intptr_t pos,
885- MlirStringRef name, MlirAttribute attr) {
886- llvm::cast<mlir::FunctionOpInterface>(unwrap (op))
887- .setArgAttr (pos, unwrap (name), unwrap (attr));
888- }
889-
890955extern " C" MlirOperation LinkInModule (MlirModule prevModC, MlirModule newModC,
891956 const char *entryfn) {
892957 auto prevMod = cast<ModuleOp>(*unwrap (prevModC));
0 commit comments