Skip to content

Commit 0f4677b

Browse files
Merge pull request jax-ml#25713 from jakevdp:debug-printoptions
PiperOrigin-RevId: 711671926
2 parents 57b2154 + 3306063 commit 0f4677b

File tree

2 files changed

+28
-4
lines changed

2 files changed

+28
-4
lines changed

jax/_src/debugging.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -300,8 +300,9 @@ def check_unused_args(self, used_args, args, kwargs):
300300

301301
formatter = _DebugPrintFormatChecker()
302302

303-
def _format_print_callback(fmt: str, *args, **kwargs):
304-
sys.stdout.write(fmt.format(*args, **kwargs) + "\n")
303+
def _format_print_callback(fmt: str, np_printoptions, *args, **kwargs):
304+
with np.printoptions(**np_printoptions):
305+
sys.stdout.write(fmt.format(*args, **kwargs) + "\n")
305306

306307
def debug_print(fmt: str, *args, ordered: bool = False, **kwargs) -> None:
307308
"""Prints values and works in staged out JAX functions.
@@ -338,8 +339,8 @@ def debug_print(fmt: str, *args, **kwargs):
338339
# Check that we provide the correct arguments to be formatted.
339340
formatter.format(fmt, *args, **kwargs)
340341

341-
debug_callback(functools.partial(_format_print_callback, fmt), *args,
342-
**kwargs, ordered=ordered)
342+
debug_callback(functools.partial(_format_print_callback, fmt, np.get_printoptions()),
343+
*args, **kwargs, ordered=ordered)
343344

344345

345346
# Sharding visualization

tests/debugging_primitives_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,29 @@ def f(x):
219219
[ 1 2 3 4 5 6 7 8 9 10 12 13 14]]
220220
"""))
221221

222+
def test_debug_print_respects_numpy_printoptions(self):
223+
def f(x):
224+
with np.printoptions(precision=2, suppress=True):
225+
jax.debug.print("{}", x)
226+
x = np.array([1.2345, 2.3456, 1E-7])
227+
228+
# Default numpy print options:
229+
with jtu.capture_stdout() as output:
230+
jax.debug.print("{}", x)
231+
self.assertEqual(output(), "[1.2345e+00 2.3456e+00 1.0000e-07]\n")
232+
233+
# Modified print options without JIT:
234+
with jtu.capture_stdout() as output:
235+
f(x)
236+
jax.effects_barrier()
237+
self.assertEqual(output(), "[1.23 2.35 0. ]\n")
238+
239+
# Modified print options with JIT:
240+
with jtu.capture_stdout() as output:
241+
jax.jit(f)(x)
242+
jax.effects_barrier()
243+
self.assertEqual(output(), "[1.23 2.35 0. ]\n")
244+
222245

223246
class DebugPrintTransformationTest(jtu.JaxTestCase):
224247

0 commit comments

Comments
 (0)