diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 3edd7439644c..c416d576abf3 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -1273,6 +1273,9 @@ nanobind_pywrap_extension( "@shardy//shardy/dialect/mpmd/ir:fragment_execution_rules", "@shardy//shardy/dialect/mpmd/transforms/import:mesh_assignment_map", "@shardy//shardy/integrations/python/jax/mpmd/jaxlib:mpmd_program", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:nb_absl_flat_hash_map", + "@xla//xla/python/ifrt/ir/conversions/mpmd:lower_to_ifrt", ], ) diff --git a/jaxlib/sdy_mpmd.cc b/jaxlib/sdy_mpmd.cc index 65190386b6cc..b79ea291fb91 100644 --- a/jaxlib/sdy_mpmd.cc +++ b/jaxlib/sdy_mpmd.cc @@ -22,9 +22,13 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" #include "mlir-c/IR.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" // IWYU pragma: keep; Needed to allow MlirModule -> ModuleOp. #include "mlir/CAPI/IR.h" // IWYU pragma: keep; Needed to allow MlirModule -> ModuleOp. +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/OperationSupport.h" #include "nanobind/nanobind.h" // IWYU pragma: begin_keep; Nanobind conversions for std types. #include "nanobind/stl/map.h" @@ -38,6 +42,9 @@ limitations under the License. #include "shardy/dialect/mpmd/ir/fragment_execution_rules.h" #include "shardy/dialect/mpmd/ir/utils.h" #include "shardy/integrations/python/jax/mpmd/jaxlib/mpmd_program.h" +#include "xla/pjrt/status_casters.h" // IWYU pragma: keep; Needed for ValueOrThrow +#include "xla/python/ifrt/ir/conversions/mpmd/lower_to_ifrt.h" +#include "xla/python/nb_absl_flat_hash_map.h" // IWYU pragma: keep namespace nb = nanobind; @@ -61,6 +68,9 @@ using ::mlir::mpmd::PartitioningResult; using ::mlir::mpmd::SplitFragmentType; using ::mlir::mpmd::SpmdTensorPartitionSpec; using ::mlir::mpmd::UserAssignmentMap; +using ::xla::ifrt::mpmd::EnvOptionsOverride; +using ::xla::ifrt::mpmd::GetCompileOptions; +using ::xla::ifrt::mpmd::LowerToIfrt; // Wrapper of PartitioningResult, which stores MlirModules instead of ModuleOps. struct PartitioningResultWrapper { @@ -234,6 +244,28 @@ NB_MODULE(_sdy_mpmd, m) { }, nb::arg("c_module"), nb::arg("unit_attributes") = std::vector()); + + m.def( + "lower_to_ifrt", + [](MlirModule module) -> void { + return xla::ThrowIfError(LowerToIfrt(unwrap(module))); + }, + nb::arg("module")); + + m.def("get_compile_options", + [](MlirModule c_module, + const absl::flat_hash_map& + compile_options_overrides) -> absl::StatusOr { + auto module = unwrap(c_module); + auto compile_options_map = ValueOrThrow( + GetCompileOptions(module, compile_options_overrides)); + nb::dict out; + for (const auto& [name, options] : compile_options_map) { + out[nb::cast(name)] = + nb::steal(nanobind::cast(options).release().ptr()); + } + return out; + }); } } // namespace