Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1224,6 +1224,27 @@ def _validate_jax_pjrt_client_create_options(new_val):
checking_leaks = functools.partial(check_tracer_leaks, True)


def _check_ifrt_user_context_update_global_hook(value: bool):
if jaxlib_extension_version >= 381:
xla_client._xla.set_ifrt_user_context_required_global(value)

def _check_ifrt_user_context_update_thread_local_hook(value: bool | None):
if jaxlib_extension_version >= 381:
xla_client._xla.set_ifrt_user_context_required_thread_local(value)

check_ifrt_user_context = bool_state(
name='jax_check_ifrt_user_context',
default=False,
help=(
'Turn on checking for IFRT user contexts for IFRT values and'
' executables that are wrapped as JAX objects so that they are'
' associated with some traceback. Normally, only JAX tests set this to'
' True for invariant checking.'
),
update_global_hook=_check_ifrt_user_context_update_global_hook,
update_thread_local_hook=_check_ifrt_user_context_update_thread_local_hook,
)

captured_constants_warn_bytes = int_state(
name='jax_captured_constants_warn_bytes',
default=2 * 10 ** 9,
Expand Down
1 change: 1 addition & 0 deletions jax/_src/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1227,6 +1227,7 @@ class JaxTestCase(parameterized.TestCase):
"""Base class for JAX tests including numerical checks and boilerplate."""
_default_global_config: dict[str, Any] = {}
_default_thread_local_config = {
'jax_check_ifrt_user_context': True,
'jax_enable_checks': True,
'jax_numpy_dtype_promotion': 'strict',
'jax_numpy_rank_promotion': 'raise',
Expand Down
4 changes: 4 additions & 0 deletions jaxlib/_jax/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,10 @@ class Traceback:

def tracebacks_enabled() -> bool: ...
def set_tracebacks_enabled(enabled: bool) -> None: ...
def set_ifrt_user_context_required_global(required: bool) -> None: ...
def set_ifrt_user_context_required_thread_local(
required: bool | None
) -> None: ...

# === END py_traceback.cc

Expand Down
11 changes: 11 additions & 0 deletions jaxlib/py_array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,11 @@ PyArray::PyArray(nb::object aval, bool weak_type, xla::nb_dtype dtype,
nb_class_ptr<PyClient> py_client, ifrt::ArrayRef ifrt_array,
bool committed, bool skip_checks,
xla::PjRtFuture<> result_status) {
if (ifrt_array->user_context() == nullptr && IsIfrtUserContextRequired()) {
throw nb::value_error(
"Expecting an IFRT Array to have a user context, but got a null "
"user context.");
}
auto* self =
PyArray_tp_new(reinterpret_cast<PyTypeObject*>(type_), nullptr, nullptr);
m_ptr = self;
Expand Down Expand Up @@ -655,6 +660,12 @@ nb::object PyArray::CheckAndRearrange(const absl::Span<const PyArray> py_arrays,
}

void PyArray::SetIfrtArray(ifrt::ArrayRef ifrt_array) {
if (ifrt_array != nullptr && ifrt_array->user_context() == nullptr &&
IsIfrtUserContextRequired()) {
throw nb::value_error(
"Expecting an IFRT Array to have a user context, but got a null "
"user context.");
}
GetStorage().ifrt_array = std::move(ifrt_array);
}

Expand Down
7 changes: 7 additions & 0 deletions jaxlib/py_executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ limitations under the License.
#include "jaxlib/py_client.h"
#include "jaxlib/py_device.h"
#include "jaxlib/py_user_context.h"
#include "jaxlib/traceback.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/pjrt/pjrt_future.h"
#include "xla/pjrt/pjrt_layout.h"
Expand Down Expand Up @@ -132,6 +133,12 @@ PyLoadedExecutable::PyLoadedExecutable(
fingerprint_(std::move(fingerprint)),
next_launch_id_(GetBaseLaunchId(fingerprint_, ifrt_loaded_executable_)) {
CHECK(PyGILState_Check());
if (ifrt_loaded_executable_->user_context() == nullptr &&
IsIfrtUserContextRequired()) {
throw nb::value_error(
"Expecting an IFRT LoadedExecutable to have a user context, but got a "
"null user context.");
}
if (fingerprint_) {
VLOG(1) << "Fingerprint for executable " << ifrt_loaded_executable_->name()
<< ": " << *fingerprint_;
Expand Down
23 changes: 23 additions & 0 deletions jaxlib/traceback.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ namespace jax {
namespace {

std::atomic<bool> traceback_enabled_ = true;
std::atomic<bool> ifrt_user_context_required_global_ = false;
thread_local std::optional<bool> ifrt_user_context_required_thread_local_;

static constexpr int kMaxFrames = 512;

Expand Down Expand Up @@ -277,6 +279,17 @@ absl::Span<const TracebackEntry> Traceback::RawFrames() const {
return traceback;
}

bool IsIfrtUserContextRequired() {
// If tracebacks are disabled, it is expected that user contexts are nullptr.
if (!traceback_enabled_.load()) {
return false;
}
if (ifrt_user_context_required_thread_local_.has_value()) {
return *ifrt_user_context_required_thread_local_;
}
return ifrt_user_context_required_global_.load();
}

void Traceback::RegisterType(nb::module_& m) {
nb::class_<Traceback::Frame>(m, "Frame")
.def(nb::init<const nb::str&, const nb::str&, int, int>())
Expand Down Expand Up @@ -314,6 +327,16 @@ void Traceback::RegisterType(nb::module_& m) {
m.def("set_tracebacks_enabled",
[](bool value) { traceback_enabled_.store(value); });

m.def("set_ifrt_user_context_required_global", [](bool required) {
return ifrt_user_context_required_global_.store(required);
});
m.def(
"set_ifrt_user_context_required_thread_local",
[](std::optional<bool> required) {
return ifrt_user_context_required_thread_local_ = required;
},
nb::arg("required").none());

type.attr("get_traceback") = nb::cpp_function(Traceback::Get,
R"doc(
Returns a :class:`Traceback` for the current thread.
Expand Down
4 changes: 4 additions & 0 deletions jaxlib/traceback.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ class Traceback : public nanobind::object {
static bool Check(PyObject* o);
};

// Whether an IFRT user context must be present in IFRT values and executables
// that are being wrapped as JAX objects.
bool IsIfrtUserContextRequired();

} // namespace jax

#endif // JAXLIB_TRACEBACK_H_
2 changes: 1 addition & 1 deletion jaxlib/xla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
# Please suffix the version number with a brief description of your change
# in a comment. The goal here is to force a merge conflict if two changes
# attempt to grab the same version number.
_version = 380 # Fixed thread safety issue in profiler.
_version = 381 # Added Traceback.set_ifrt_user_context_required_*()

# An internal increasing version number for protecting jaxlib code against
# ifrt changes.
Expand Down
Loading