Skip to content

Commit 2582a33

Browse files
junwhanahnGoogle-ML-Automation
authored andcommitted
Explicitly raise an error if more than 65535 channels are created
`xla::HostCallbackArgInfo` uses `uint16_t` for channel ids, so we should warn users explicitly when the channel ids exceed the UINT16_MAX instead of silently wrapping around. PiperOrigin-RevId: 695682871
1 parent 15f30a9 commit 2582a33

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

jax/_src/interpreters/mlir.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -770,7 +770,13 @@ def backend(self) -> xb.XlaBackend:
770770
return self.backend_or_name
771771

772772
def new_channel(self) -> int:
773-
return next(self.channel_iterator)
773+
channel = next(self.channel_iterator)
774+
# `xla::HostCallback` requires a 16-bit channel ID.
775+
if channel >= (1 << 16):
776+
raise RuntimeError(
777+
"Host callback lowering created too many channels. PjRt does not"
778+
" support more than 65535 channels")
779+
return channel
774780

775781
# Adds an IFRT host callback object to the context. A reference to these
776782
# callbacks will be provided to IFRT during compilation so it can do things

0 commit comments

Comments
 (0)