Skip to content

Commit f7ce3be

Browse files
danielsuoGoogle-ML-Automation
authored andcommitted
[pmap] Add ignore warnings for tests failing on PmapSharding deprecation warnings.
See: https://github.com/jax-ml/jax/actions/runs/19429155535/job/55583665289 PiperOrigin-RevId: 833346834
1 parent 31de82b commit f7ce3be

File tree

2 files changed

+2
-0
lines changed

2 files changed

+2
-0
lines changed

tests/multiprocess/host_callback_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def callback_func(axis_index, x, sum_global):
111111
10 * local_device_idx + 11], dtype=np.int32),
112112
expected_sum_global)])
113113

114+
@jtu.ignore_warning(category=DeprecationWarning)
114115
def test_io_callback_pjit(self):
115116
devices = np.array(sorted_devices()).reshape(
116117
(NR_PROCESSES, NR_LOCAL_DEVICES))

tests/multiprocess/pjit_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def create_2d_non_contiguous_mesh2():
9090
# TODO(apaszke): Test with mesh that has host-tiled axes (especially nesting!)
9191
class PJitTestMultiHost(jt_multiprocess.MultiProcessTest):
9292

93+
@jtu.ignore_warning(category=DeprecationWarning)
9394
def testLocalInputsWithJaxArray(self):
9495
# Note that this is too small to shard over the global mesh, but fine for
9596
# the local mesh and so should be accepted.

0 commit comments

Comments
 (0)