Skip to content

Commit 62e66b6

Browse files
committed
Don't monkey-patch functions in test_utils to count events for tests.
This has two problems: * it's not thread-safe, which will become problematic if we run tests with thread-parallelism. * it's not very maintainable. Instead, add a new util.test_event(...) function that can be called at points of interest in the program. test_utils registers a callback that is invoked when an event is received. This avoids the need to make thread-unsafe global monkey patches.
1 parent 3630756 commit 62e66b6

25 files changed

+213
-321
lines changed

examples/ffi/tests/cpu_examples_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@ def test_array_attr_jit_cache(self):
3737
jit_array_attr = jax.jit(cpu_examples.array_attr, static_argnums=(0,))
3838
with jtu.count_jit_and_pmap_lowerings() as count:
3939
jit_array_attr(5)
40-
self.assertEqual(count[0], 1) # compiles once the first time
40+
self.assertEqual(count(), 1) # compiles once the first time
4141
with jtu.count_jit_and_pmap_lowerings() as count:
4242
jit_array_attr(5)
43-
self.assertEqual(count[0], 0) # cache hit
43+
self.assertEqual(count(), 0) # cache hit
4444

4545
def test_array_attr_no_jit(self):
4646
with jax.disable_jit():

jax/_src/array.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from jax._src import dtypes
3434
from jax._src import errors
3535
from jax._src import profiler
36+
from jax._src import util
3637
from jax._src import xla_bridge
3738
from jax._src.interpreters import mlir
3839
from jax._src.interpreters import pxla
@@ -1131,6 +1132,7 @@ def _sharding_indices_and_eq(src_sharding, shape, dst_sharding):
11311132

11321133

11331134
def _array_shard_arg(xs, shardings, layouts, copy_semantics):
1135+
util.test_event("_array_shard_arg")
11341136
results = []
11351137
batch_xs, batch_devs, batch_shardings, batch_indices = [], [], [], []
11361138
batch_cs = []
@@ -1168,6 +1170,7 @@ def _array_shard_arg(xs, shardings, layouts, copy_semantics):
11681170
results.append(
11691171
shard_sharded_device_array_slow_path(x, devices, indices, sharding))
11701172

1173+
util.test_event("batched_copy_array")
11711174
copy_outs = xc.batched_copy_array_to_devices_with_sharding(
11721175
batch_xs, batch_devs, batch_shardings, batch_cs)
11731176
for i, copy_out in safe_zip(batch_indices, copy_outs):

jax/_src/dispatch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def apply_primitive(prim, *args, **params):
9494

9595
@util.cache()
9696
def xla_primitive_callable(prim: core.Primitive, **params):
97+
util.test_event("xla_primitive_callable_cache_miss")
9798
def prim_fun(*args):
9899
with config.eager_constant_folding(False):
99100
return prim.bind(*args, **params)

jax/_src/interpreters/mlir.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1085,6 +1085,7 @@ def lower_jaxpr_to_module(
10851085
Handles the quirks of the argument/return value passing conventions of the
10861086
runtime.
10871087
"""
1088+
util.test_event("lower_jaxpr_to_module")
10881089
platforms = tuple(map(xb.canonicalize_platform, platforms))
10891090

10901091
in_avals = (jaxpr.in_avals if arg_shardings is None else
@@ -1378,6 +1379,7 @@ def lower_jaxpr_to_fun(
13781379
Returns:
13791380
MLIR func op
13801381
"""
1382+
util.test_event("lower_jaxpr_to_fun", name)
13811383

13821384
# The first dimension variable may be the platform index
13831385
num_dim_vars = len(ctx.shape_poly_state.dim_vars)

jax/_src/interpreters/pxla.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -231,16 +231,20 @@ def _shard_mutable_array(xs, shardings, layouts, copy_semantics):
231231
def batched_device_put(aval: core.ShapedArray,
232232
sharding: JSharding, xs: Sequence[Any],
233233
devices: Sequence[jax.Device], committed: bool = True):
234-
from jax._src import array
235-
236-
bufs = [x for x, d in safe_zip(xs, devices)
237-
if (isinstance(x, array.ArrayImpl) and
238-
dispatch.is_single_device_sharding(x.sharding) and
239-
x.devices() == {d})]
240-
if len(bufs) == len(xs):
241-
return array.ArrayImpl(
242-
aval, sharding, bufs, committed=committed, _skip_checks=True)
243-
return xc.batched_device_put(aval, sharding, xs, list(devices), committed)
234+
util.test_event("batched_device_put_start")
235+
try:
236+
from jax._src import array
237+
238+
bufs = [x for x, d in safe_zip(xs, devices)
239+
if (isinstance(x, array.ArrayImpl) and
240+
dispatch.is_single_device_sharding(x.sharding) and
241+
x.devices() == {d})]
242+
if len(bufs) == len(xs):
243+
return array.ArrayImpl(
244+
aval, sharding, bufs, committed=committed, _skip_checks=True)
245+
return xc.batched_device_put(aval, sharding, xs, list(devices), committed)
246+
finally:
247+
util.test_event("batched_device_put_end")
244248

245249
def _shard_aval(size, axis: int, aval):
246250
try:
@@ -2850,6 +2854,7 @@ def from_hlo(name: str,
28502854
mesh = i.mesh
28512855
break
28522856

2857+
util.test_event("pxla_cached_compilation")
28532858
xla_executable = _cached_compilation(
28542859
hlo, name, mesh, spmd_lowering,
28552860
tuple_args, auto_spmd_lowering, allow_prop_to_inputs,

jax/_src/pjit.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,7 @@ def _infer_params_impl(
549549
kwargs: dict[str, Any],
550550
in_avals: tuple[core.AbstractValue, ...] | None,
551551
) -> tuple[PjitParams, list[Any]]:
552+
util.test_event("pjit._infer_params_impl", fun)
552553
have_kwargs = bool(kwargs)
553554
if have_kwargs and ji.user_specified_in_shardings:
554555
raise ValueError(
@@ -1297,6 +1298,7 @@ def _create_pjit_jaxpr(
12971298
ignored_inline: IgnoreKey
12981299
) -> tuple[core.ClosedJaxpr, list[Any], list[core.AbstractValue],
12991300
list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
1301+
util.test_event("create_pjit_jaxpr")
13001302
del ignored_inline # just for explain_cache_miss
13011303
if config.no_tracing.value:
13021304
raise RuntimeError(f"re-tracing function {fun.f} for `jit`, but "
@@ -1784,6 +1786,7 @@ def _pjit_lower(
17841786
lowering_platforms: tuple[str, ...] | None,
17851787
lowering_parameters: mlir.LoweringParameters,
17861788
pgle_profiler: profiler.PGLEProfiler | None):
1789+
util.test_event("pjit_lower")
17871790
if config.sharding_in_types.value:
17881791
mesh, api_name = mesh_lib.get_concrete_mesh(), 'jit'
17891792
else:

jax/_src/stages.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,7 @@ def output_layouts(self):
533533

534534
@staticmethod
535535
def call(*args, **kwargs):
536+
util.test_event("stages_compiled_call")
536537
# This is because `__call__` passes in `self._params` as the first argument.
537538
# Instead of making the call signature `call(params, *args, **kwargs)`
538539
# extract it from args because `params` can be passed as a kwarg by users

0 commit comments

Comments
 (0)