@@ -300,8 +300,9 @@ def check_unused_args(self, used_args, args, kwargs):
300300
301301formatter = _DebugPrintFormatChecker ()
302302
303- def _format_print_callback (fmt : str , * args , ** kwargs ):
304- sys .stdout .write (fmt .format (* args , ** kwargs ) + "\n " )
303+ def _format_print_callback (fmt : str , np_printoptions , * args , ** kwargs ):
304+ with np .printoptions (** np_printoptions ):
305+ sys .stdout .write (fmt .format (* args , ** kwargs ) + "\n " )
305306
306307def debug_print (fmt : str , * args , ordered : bool = False , ** kwargs ) -> None :
307308 """Prints values and works in staged out JAX functions.
@@ -338,8 +339,8 @@ def debug_print(fmt: str, *args, **kwargs):
338339 # Check that we provide the correct arguments to be formatted.
339340 formatter .format (fmt , * args , ** kwargs )
340341
341- debug_callback (functools .partial (_format_print_callback , fmt ), * args ,
342- ** kwargs , ordered = ordered )
342+ debug_callback (functools .partial (_format_print_callback , fmt , np . get_printoptions ()) ,
343+ * args , * *kwargs , ordered = ordered )
343344
344345
345346# Sharding visualization
0 commit comments