Skip to content

Commit 256e37a

Browse files
hawkinspGoogle-ML-Automation
authored andcommitted
Port many uses of contextlib.contextdecorator to explicit context manager classes.
contextdecorator turns out to be slower than just writing a decorator class explicitly. Since we use many decorators per-equation, this causes a measurable speed difference in certain benchmarks. PiperOrigin-RevId: 730939406
1 parent 2ce88c9 commit 256e37a

File tree

6 files changed

+222
-145
lines changed

6 files changed

+222
-145
lines changed

jax/_src/api.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
import atexit
2626
import collections
27-
from collections.abc import Callable, Generator, Hashable, Iterable, Sequence
27+
from collections.abc import Callable, Hashable, Iterable, Sequence
2828
from functools import partial, lru_cache
2929
import inspect
3030
import math
@@ -2736,10 +2736,9 @@ def named_call(
27362736
return source_info_util.extend_name_stack(name)(fun)
27372737

27382738

2739-
@contextmanager
27402739
def named_scope(
27412740
name: str,
2742-
) -> Generator[None, None, None]:
2741+
) -> source_info_util.ExtendNameStackContextManager:
27432742
"""A context manager that adds a user specified name to the JAX name stack.
27442743
27452744
When staging out computations for just-in-time compilation to XLA (or other
@@ -2786,8 +2785,7 @@ def named_scope(
27862785
"""
27872786
if not isinstance(name, str):
27882787
raise TypeError("named_scope name argument must be a string.")
2789-
with source_info_util.extend_name_stack(name):
2790-
yield
2788+
return source_info_util.extend_name_stack(name)
27912789

27922790
def effects_barrier():
27932791
"""Waits until existing functions have completed any side-effects."""

jax/_src/config.py

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -304,31 +304,8 @@ def _set(self, value: _T) -> None:
304304
if self._update_global_hook:
305305
self._update_global_hook(value)
306306

307-
@contextlib.contextmanager
308307
def __call__(self, new_val: Any = no_default):
309-
if new_val is no_default:
310-
if self._default_context_manager_value is not no_default:
311-
new_val = self._default_context_manager_value # default_context_manager_value provided to constructor
312-
else:
313-
# no default_value provided to constructor and no value provided as an
314-
# argument, so we raise an error
315-
raise TypeError(f"Context manager for {self.__name__} config option "
316-
"requires an argument representing the new value for "
317-
"the config option.")
318-
if self._validator:
319-
self._validator(new_val)
320-
prev_val = self.swap_local(new_val)
321-
if self._update_thread_local_hook:
322-
self._update_thread_local_hook(new_val)
323-
try:
324-
yield
325-
finally:
326-
self.set_local(prev_val)
327-
if self._update_thread_local_hook:
328-
if prev_val is config_ext.unset:
329-
self._update_thread_local_hook(None)
330-
else:
331-
self._update_thread_local_hook(cast(Optional[Any], prev_val))
308+
return StateContextManager(self, new_val)
332309

333310
def _add_hooks(self, update_global_hook, update_thread_local_hook):
334311
"""Private method that adds hooks to an existing context-manager.
@@ -339,6 +316,40 @@ def _add_hooks(self, update_global_hook, update_thread_local_hook):
339316
update_global_hook(self.get_global())
340317

341318

319+
class StateContextManager(contextlib.ContextDecorator):
320+
__slots__ = ['state', 'new_val', 'prev']
321+
322+
def __init__(self, state, new_val):
323+
self.state = state
324+
self.new_val = new_val
325+
326+
if new_val is no_default:
327+
if state._default_context_manager_value is not no_default:
328+
new_val = state._default_context_manager_value # default_context_manager_value provided to constructor
329+
else:
330+
# no default_value provided to constructor and no value provided as an
331+
# argument, so we raise an error
332+
raise TypeError(f"Context manager for {state.__name__} config option "
333+
"requires an argument representing the new value for "
334+
"the config option.")
335+
if state._validator:
336+
state._validator(new_val)
337+
338+
339+
def __enter__(self):
340+
self.prev = self.state.swap_local(self.new_val)
341+
if self.state._update_thread_local_hook:
342+
self.state._update_thread_local_hook(self.new_val)
343+
344+
def __exit__(self, exc_type, exc_value, traceback):
345+
self.state.set_local(self.prev)
346+
if self.state._update_thread_local_hook:
347+
if self.prev is config_ext.unset:
348+
self.state._update_thread_local_hook(None)
349+
else:
350+
self.state._update_thread_local_hook(cast(Optional[Any], self.prev))
351+
352+
342353
UPGRADE_BOOL_HELP = (
343354
" This will be enabled by default in future versions of JAX, at which "
344355
"point all uses of the flag will be considered deprecated (following "

jax/_src/core.py

Lines changed: 63 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1136,51 +1136,77 @@ def update_thread_local_jit_state(self):
11361136
trace_ctx = TracingContext()
11371137

11381138

1139-
@contextmanager
1140-
def take_current_trace():
1141-
prev = trace_ctx.trace
1142-
try:
1139+
class TakeCurrentTraceContextManager:
1140+
__slots__ = ['prev']
1141+
1142+
def __enter__(self):
1143+
self.prev = trace_ctx.trace
11431144
trace_ctx.set_trace(eval_trace)
1144-
yield prev
1145-
finally:
1146-
trace_ctx.set_trace(prev)
1145+
return self.prev
11471146

1148-
@contextmanager
1149-
def set_current_trace(trace, check_leaks=False):
1150-
prev = trace_ctx.trace
1151-
try:
1152-
trace_ctx.set_trace(trace)
1153-
yield
1154-
finally:
1155-
trace_ctx.set_trace(prev)
1156-
if check_leaks and config.check_tracer_leaks.value:
1157-
trace.invalidate()
1158-
trace_ref = ref(trace)
1159-
del trace
1147+
def __exit__(self, exc_type, exc_value, traceback):
1148+
trace_ctx.set_trace(self.prev)
1149+
1150+
take_current_trace = TakeCurrentTraceContextManager
1151+
1152+
1153+
class SetCurrentTraceContextManager:
1154+
__slots__ = ['trace', 'check_leaks', 'prev']
1155+
1156+
def __init__(self, trace, check_leaks=False):
1157+
self.trace = trace
1158+
self.check_leaks = check_leaks
1159+
1160+
def __enter__(self):
1161+
self.prev = trace_ctx.trace
1162+
trace_ctx.set_trace(self.trace)
1163+
1164+
def __exit__(self, exc_type, exc_value, traceback):
1165+
trace_ctx.set_trace(self.prev)
1166+
if self.check_leaks and config.check_tracer_leaks.value:
1167+
self.trace.invalidate()
1168+
trace_ref = ref(self.trace)
1169+
del self.trace
11601170
live_trace = trace_ref()
11611171
if live_trace is not None:
11621172
leaked_tracers = maybe_find_leaked_tracers(live_trace)
11631173
if leaked_tracers:
11641174
raise leaked_tracer_error("trace", live_trace, leaked_tracers)
11651175

1166-
@contextmanager
1167-
def extend_axis_env_nd(name_size_pairs : Iterable[tuple[AxisName, int]]):
1168-
prev = trace_ctx.axis_env
1169-
try:
1170-
trace_ctx.set_axis_env(prev.extend_pure(name_size_pairs))
1171-
yield
1172-
finally:
1173-
trace_ctx.set_axis_env(prev)
1176+
set_current_trace = SetCurrentTraceContextManager
1177+
1178+
class ExtendAxisEnvNdContextManager:
1179+
__slots__ = ['prev', 'name_size_pairs']
1180+
1181+
def __init__(self, name_size_pairs: Iterable[tuple[AxisName, int]]):
1182+
self.name_size_pairs = name_size_pairs
1183+
1184+
def __enter__(self):
1185+
self.prev = trace_ctx.axis_env
1186+
trace_ctx.set_axis_env(self.prev.extend_pure(self.name_size_pairs))
1187+
1188+
def __exit__(self, exc_type, exc_value, traceback):
1189+
trace_ctx.set_axis_env(self.prev)
1190+
1191+
extend_axis_env_nd = ExtendAxisEnvNdContextManager
1192+
1193+
1194+
class AddSpmdAxisNamesContextManager:
1195+
__slots__ = ['prev', 'axis_names']
1196+
1197+
def __init__(self, axis_names: AxisName | None):
1198+
self.axis_names = axis_names
1199+
1200+
def __enter__(self):
1201+
self.prev = trace_ctx.axis_env
1202+
if self.axis_names is not None:
1203+
trace_ctx.set_axis_env(self.prev.add_spmd_axis_names(self.axis_names))
1204+
1205+
def __exit__(self, exc_type, exc_value, traceback):
1206+
trace_ctx.set_axis_env(self.prev)
1207+
1208+
add_spmd_axis_names = AddSpmdAxisNamesContextManager
11741209

1175-
@contextmanager
1176-
def add_spmd_axis_names(axis_names: AxisName | None):
1177-
prev = trace_ctx.axis_env
1178-
try:
1179-
if axis_names is not None:
1180-
trace_ctx.set_axis_env(prev.add_spmd_axis_names(axis_names))
1181-
yield
1182-
finally:
1183-
trace_ctx.set_axis_env(prev)
11841210

11851211
def get_axis_env():
11861212
return trace_ctx.axis_env
@@ -1233,6 +1259,7 @@ def ensure_no_leaks(trace:Trace):
12331259
if leaked_tracers:
12341260
raise leaked_tracer_error("trace", live_trace, leaked_tracers)
12351261

1262+
12361263
def maybe_find_leaked_tracers(trace: Trace) -> list[Tracer]:
12371264
"""Find the leaked tracers holding a reference to the Trace
12381265
"""

jax/_src/dispatch.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
import atexit
1919
from collections.abc import Sequence
20-
import contextlib
2120
import dataclasses
2221
import enum
2322
from functools import partial
@@ -170,22 +169,32 @@ def wait_for_tokens():
170169
runtime_tokens.block_until_ready()
171170

172171

173-
@contextlib.contextmanager
174-
def log_elapsed_time(fmt: str, fun_name: str, event: str | None = None):
175-
if _on_exit:
176-
yield
177-
else:
178-
log_priority = logging.WARNING if config.log_compiles.value else logging.DEBUG
179-
start_time = time.time()
180-
yield
172+
class LogElapsedTimeContextManager:
173+
__slots__ = ['fmt', 'fun_name', 'event', 'start_time']
174+
175+
def __init__(self, fmt: str, fun_name: str, event: str | None = None):
176+
self.fmt = fmt
177+
self.fun_name = fun_name
178+
self.event = event
179+
180+
def __enter__(self):
181+
self.start_time = time.time()
182+
183+
def __exit__(self, exc_type, exc_value, traceback):
184+
if _on_exit:
185+
return
186+
181187
end_time = time.time()
182-
elapsed_time = end_time - start_time
188+
elapsed_time = end_time - self.start_time
189+
log_priority = logging.WARNING if config.log_compiles.value else logging.DEBUG
183190
if logger.isEnabledFor(log_priority):
184-
logger.log(log_priority, fmt.format(
185-
fun_name=fun_name, elapsed_time=elapsed_time))
186-
if event is not None:
187-
record_event_duration_secs(event, elapsed_time)
188-
record_event_time_span(event, start_time, end_time)
191+
logger.log(log_priority, self.fmt.format(
192+
fun_name=self.fun_name, elapsed_time=elapsed_time))
193+
if self.event is not None:
194+
record_event_duration_secs(self.event, elapsed_time)
195+
record_event_time_span(self.event, self.start_time, end_time)
196+
197+
log_elapsed_time = LogElapsedTimeContextManager
189198

190199

191200
def should_tuple_args(num_args: int, platform: str) -> bool:

jax/_src/mesh.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -549,14 +549,20 @@ def _extremely_unsafe_enter_tracing_context(mesh: AbstractMesh):
549549
def _raise_value_error(name):
550550
raise ValueError(f"AbstractMesh does not implement {name}")
551551

552+
class SetAbstractMeshContextManager:
553+
__slots__ = ['mesh', 'prev']
554+
555+
def __init__(self, mesh: AbstractMesh):
556+
self.mesh = mesh
557+
558+
def __enter__(self):
559+
self.prev = jax_config.abstract_mesh_context_manager.swap_local(self.mesh)
560+
561+
def __exit__(self, exc_type, exc_value, traceback):
562+
jax_config.abstract_mesh_context_manager.set_local(self.prev)
563+
564+
set_abstract_mesh = SetAbstractMeshContextManager
552565

553-
@contextlib.contextmanager
554-
def set_abstract_mesh(mesh: AbstractMesh):
555-
prev_val = jax_config.abstract_mesh_context_manager.swap_local(mesh)
556-
try:
557-
yield
558-
finally:
559-
jax_config.abstract_mesh_context_manager.set_local(prev_val)
560566

561567
empty_abstract_mesh = AbstractMesh(())
562568

0 commit comments

Comments
 (0)