From 409eb8ff82afbdfbf5efcafafce5dff7a2e6e158 Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Wed, 8 Oct 2025 13:48:59 -0700 Subject: [PATCH] [JAX] Add `jax_check_ifrt_user_context` config and enable it for JAX tests When JAX config `jax_check_ifrt_user_context` is true, JAX will require some `xla::ifrt::UserContext` to be set up for IFRT values and executables when wrapping them as JAX objects. This replaces `-DIFRT_REQUIRE_USER_CONTEXT` for checking if IFRT user context setups is done correctly. PiperOrigin-RevId: 816845220 --- jax/_src/config.py | 21 +++++++++++++++++++++ jax/_src/test_util.py | 1 + jaxlib/_jax/__init__.pyi | 4 ++++ jaxlib/py_array.cc | 11 +++++++++++ jaxlib/py_executable.cc | 7 +++++++ jaxlib/traceback.cc | 23 +++++++++++++++++++++++ jaxlib/traceback.h | 4 ++++ jaxlib/xla_client.py | 2 +- 8 files changed, 72 insertions(+), 1 deletion(-) diff --git a/jax/_src/config.py b/jax/_src/config.py index 93b211d53c5d..253d400b95c0 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1224,6 +1224,27 @@ def _validate_jax_pjrt_client_create_options(new_val): checking_leaks = functools.partial(check_tracer_leaks, True) +def _check_ifrt_user_context_update_global_hook(value: bool): + if jaxlib_extension_version >= 381: + xla_client._xla.set_ifrt_user_context_required_global(value) + +def _check_ifrt_user_context_update_thread_local_hook(value: bool | None): + if jaxlib_extension_version >= 381: + xla_client._xla.set_ifrt_user_context_required_thread_local(value) + +check_ifrt_user_context = bool_state( + name='jax_check_ifrt_user_context', + default=False, + help=( + 'Turn on checking for IFRT user contexts for IFRT values and' + ' executables that are wrapped as JAX objects so that they are' + ' associated with some traceback. Normally, only JAX tests set this to' + ' True for invariant checking.' + ), + update_global_hook=_check_ifrt_user_context_update_global_hook, + update_thread_local_hook=_check_ifrt_user_context_update_thread_local_hook, +) + captured_constants_warn_bytes = int_state( name='jax_captured_constants_warn_bytes', default=2 * 10 ** 9, diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 392d8ae3a8ad..222214e056a5 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1227,6 +1227,7 @@ class JaxTestCase(parameterized.TestCase): """Base class for JAX tests including numerical checks and boilerplate.""" _default_global_config: dict[str, Any] = {} _default_thread_local_config = { + 'jax_check_ifrt_user_context': True, 'jax_enable_checks': True, 'jax_numpy_dtype_promotion': 'strict', 'jax_numpy_rank_promotion': 'raise', diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi index 1681ab62a7bd..0547b58bdd10 100644 --- a/jaxlib/_jax/__init__.pyi +++ b/jaxlib/_jax/__init__.pyi @@ -793,6 +793,10 @@ class Traceback: def tracebacks_enabled() -> bool: ... def set_tracebacks_enabled(enabled: bool) -> None: ... +def set_ifrt_user_context_required_global(required: bool) -> None: ... +def set_ifrt_user_context_required_thread_local( + required: bool | None +) -> None: ... # === END py_traceback.cc diff --git a/jaxlib/py_array.cc b/jaxlib/py_array.cc index 0926b306a509..4d33cc4dfe84 100644 --- a/jaxlib/py_array.cc +++ b/jaxlib/py_array.cc @@ -626,6 +626,11 @@ PyArray::PyArray(nb::object aval, bool weak_type, xla::nb_dtype dtype, nb_class_ptr py_client, ifrt::ArrayRef ifrt_array, bool committed, bool skip_checks, xla::PjRtFuture<> result_status) { + if (ifrt_array->user_context() == nullptr && IsIfrtUserContextRequired()) { + throw nb::value_error( + "Expecting an IFRT Array to have a user context, but got a null " + "user context."); + } auto* self = PyArray_tp_new(reinterpret_cast(type_), nullptr, nullptr); m_ptr = self; @@ -655,6 +660,12 @@ nb::object PyArray::CheckAndRearrange(const absl::Span py_arrays, } void PyArray::SetIfrtArray(ifrt::ArrayRef ifrt_array) { + if (ifrt_array != nullptr && ifrt_array->user_context() == nullptr && + IsIfrtUserContextRequired()) { + throw nb::value_error( + "Expecting an IFRT Array to have a user context, but got a null " + "user context."); + } GetStorage().ifrt_array = std::move(ifrt_array); } diff --git a/jaxlib/py_executable.cc b/jaxlib/py_executable.cc index 7bd50da104cd..62949a454cef 100644 --- a/jaxlib/py_executable.cc +++ b/jaxlib/py_executable.cc @@ -42,6 +42,7 @@ limitations under the License. #include "jaxlib/py_client.h" #include "jaxlib/py_device.h" #include "jaxlib/py_user_context.h" +#include "jaxlib/traceback.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/pjrt/pjrt_future.h" #include "xla/pjrt/pjrt_layout.h" @@ -132,6 +133,12 @@ PyLoadedExecutable::PyLoadedExecutable( fingerprint_(std::move(fingerprint)), next_launch_id_(GetBaseLaunchId(fingerprint_, ifrt_loaded_executable_)) { CHECK(PyGILState_Check()); + if (ifrt_loaded_executable_->user_context() == nullptr && + IsIfrtUserContextRequired()) { + throw nb::value_error( + "Expecting an IFRT LoadedExecutable to have a user context, but got a " + "null user context."); + } if (fingerprint_) { VLOG(1) << "Fingerprint for executable " << ifrt_loaded_executable_->name() << ": " << *fingerprint_; diff --git a/jaxlib/traceback.cc b/jaxlib/traceback.cc index 9dbea8fcc5fe..45cf95ab1a85 100644 --- a/jaxlib/traceback.cc +++ b/jaxlib/traceback.cc @@ -57,6 +57,8 @@ namespace jax { namespace { std::atomic traceback_enabled_ = true; +std::atomic ifrt_user_context_required_global_ = false; +thread_local std::optional ifrt_user_context_required_thread_local_; static constexpr int kMaxFrames = 512; @@ -277,6 +279,17 @@ absl::Span Traceback::RawFrames() const { return traceback; } +bool IsIfrtUserContextRequired() { + // If tracebacks are disabled, it is expected that user contexts are nullptr. + if (!traceback_enabled_.load()) { + return false; + } + if (ifrt_user_context_required_thread_local_.has_value()) { + return *ifrt_user_context_required_thread_local_; + } + return ifrt_user_context_required_global_.load(); +} + void Traceback::RegisterType(nb::module_& m) { nb::class_(m, "Frame") .def(nb::init()) @@ -314,6 +327,16 @@ void Traceback::RegisterType(nb::module_& m) { m.def("set_tracebacks_enabled", [](bool value) { traceback_enabled_.store(value); }); + m.def("set_ifrt_user_context_required_global", [](bool required) { + return ifrt_user_context_required_global_.store(required); + }); + m.def( + "set_ifrt_user_context_required_thread_local", + [](std::optional required) { + return ifrt_user_context_required_thread_local_ = required; + }, + nb::arg("required").none()); + type.attr("get_traceback") = nb::cpp_function(Traceback::Get, R"doc( Returns a :class:`Traceback` for the current thread. diff --git a/jaxlib/traceback.h b/jaxlib/traceback.h index 5a44e853884b..8ed499467d66 100644 --- a/jaxlib/traceback.h +++ b/jaxlib/traceback.h @@ -82,6 +82,10 @@ class Traceback : public nanobind::object { static bool Check(PyObject* o); }; +// Whether an IFRT user context must be present in IFRT values and executables +// that are being wrapped as JAX objects. +bool IsIfrtUserContextRequired(); + } // namespace jax #endif // JAXLIB_TRACEBACK_H_ diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py index 0a59fa9010fe..5ea236f30200 100644 --- a/jaxlib/xla_client.py +++ b/jaxlib/xla_client.py @@ -47,7 +47,7 @@ # Please suffix the version number with a brief description of your change # in a comment. The goal here is to force a merge conflict if two changes # attempt to grab the same version number. -_version = 380 # Fixed thread safety issue in profiler. +_version = 381 # Added Traceback.set_ifrt_user_context_required_*() # An internal increasing version number for protecting jaxlib code against # ifrt changes.