Skip to content

Commit 2259a13

Browse files
apaszkeGoogle-ML-Automation
authored andcommitted
Use capture_stderr instead of packing sys.stderr
This method is more reliable as it can capture stderr writes from the whole process and not only those coming from Python. I noticed this test was failing sporadically. I'm not sure if this will fix it, but at least it might tell us why (thanks to e.g. the weakref check). PiperOrigin-RevId: 707557909
1 parent 6edfe9e commit 2259a13

File tree

1 file changed

+13
-20
lines changed

1 file changed

+13
-20
lines changed

tests/garbage_collection_guard_test.py

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414
"""Tests for garbage allocation guard."""
1515

1616
import gc
17-
import io
18-
from unittest import mock
17+
import weakref
1918

2019
from absl.testing import absltest
2120
import jax
@@ -40,42 +39,36 @@ def _create_array_cycle():
4039
n2 = GarbageCollectionGuardTestNodeHelper(jax.jit(lambda: jnp.zeros((2, 2)))())
4140
n1.next = n2
4241
n2.next = n1
42+
return weakref.ref(n1)
4343

4444

4545
class GarbageCollectionGuardTest(jtu.JaxTestCase):
4646

4747
def test_gced_array_is_not_logged_by_default(self):
4848
# Create a reference cycle of two jax.Arrays.
49-
_create_array_cycle()
50-
51-
# Use mock_stderr to be able to inspect stderr.
52-
mock_stderr = io.StringIO()
53-
with mock.patch("sys.stderr", mock_stderr):
54-
# Trigger a garbage collection, which will garbage collect the arrays
55-
# in the cycle.
49+
ref = _create_array_cycle()
50+
with jtu.capture_stderr() as stderr:
51+
self.assertIsNotNone(ref()) # Cycle still alive.
5652
gc.collect()
53+
self.assertIsNone(ref()) # Cycle collected.
5754
# Check that no error message is logged because
5855
# `array_garbage_collection_guard` defaults to `allow`.
5956
self.assertNotIn(
60-
"`jax.Array` was deleted by the Python garbage collector",
61-
mock_stderr.getvalue(),
57+
"`jax.Array` was deleted by the Python garbage collector", stderr(),
6258
)
6359

6460
def test_gced_array_is_logged(self):
65-
# Use mock_stderr to be able to inspect stderr.
66-
mock_stderr = io.StringIO()
67-
6861
with config.array_garbage_collection_guard("log"):
69-
# Create a reference cycle of two jax.Arrays.
70-
_create_array_cycle()
71-
with mock.patch("sys.stderr", mock_stderr):
62+
with jtu.capture_stderr() as stderr:
63+
# Create a reference cycle of two jax.Arrays.
64+
ref = _create_array_cycle()
65+
self.assertIsNotNone(ref()) # Cycle still alive.
7266
gc.collect()
73-
67+
self.assertIsNone(ref()) # Cycle collected.
7468
# Verify that an error message is logged because two jax.Arrays were garbage
7569
# collected.
7670
self.assertIn(
77-
"`jax.Array` was deleted by the Python garbage collector",
78-
mock_stderr.getvalue(),
71+
"`jax.Array` was deleted by the Python garbage collector", stderr()
7972
)
8073

8174

0 commit comments

Comments
 (0)