11import dataclasses
2- import functools as ft
3- import sys
4- import types
5- from collections .abc import Callable , Sequence
6- from typing import Any , Optional , Union
2+ from collections .abc import Callable
3+ from typing import Any
74
85import jax
9- import jax ._src .pretty_printer as pp
106import jax .tree_util as jtu
11- import numpy as np
12- from jaxtyping import PyTree
13-
14- from ._doc_utils import WithRepr
15-
16-
17- Dataclass = Any
18- NamedTuple = Any # workaround typeguard bug
19- PrettyPrintAble = PyTree
20-
21- # Re-export
22- text = pp .text
23-
24-
25- _comma_sep = pp .concat ([pp .text ("," ), pp .brk ()])
26-
27-
28- def bracketed (
29- name : Optional [pp .Doc ],
30- indent : int ,
31- objs : Sequence [pp .Doc ],
32- lbracket : str ,
33- rbracket : str ,
34- ) -> pp .Doc :
35- nested = pp .concat (
36- [
37- pp .nest (indent , pp .concat ([pp .brk ("" ), pp .join (_comma_sep , objs )])),
38- pp .brk ("" ),
39- ]
40- )
41- concated = []
42- if name is not None :
43- concated .append (name )
44- concated .extend ([pp .text (lbracket ), nested , pp .text (rbracket )])
45- return pp .group (pp .concat (concated ))
46-
47-
48- def named_objs (pairs , ** kwargs ):
49- return [
50- pp .concat ([pp .text (key + "=" ), tree_pp (value , ** kwargs )])
51- for key , value in pairs
52- ]
53-
54-
55- def _pformat_list (obj : list , ** kwargs ) -> pp .Doc :
56- return bracketed (
57- name = None ,
58- indent = kwargs ["indent" ],
59- objs = [tree_pp (x , ** kwargs ) for x in obj ],
60- lbracket = "[" ,
61- rbracket = "]" ,
62- )
63-
64-
65- def _pformat_tuple (obj : tuple , ** kwargs ) -> pp .Doc :
66- if len (obj ) == 1 :
67- objs = [pp .concat ([tree_pp (obj [0 ], ** kwargs ), pp .text ("," )])]
68- else :
69- objs = [tree_pp (x , ** kwargs ) for x in obj ]
70- return bracketed (
71- name = None , indent = kwargs ["indent" ], objs = objs , lbracket = "(" , rbracket = ")"
72- )
73-
74-
75- def _pformat_namedtuple (obj : NamedTuple , ** kwargs ) -> pp .Doc :
76- objs = named_objs ([(name , getattr (obj , name )) for name in obj ._fields ], ** kwargs )
77- return bracketed (
78- name = pp .text (obj .__class__ .__name__ ),
79- indent = kwargs ["indent" ],
80- objs = objs ,
81- lbracket = "(" ,
82- rbracket = ")" ,
83- )
84-
85-
86- def _dict_entry (key : Any , value : Any , ** kwargs ) -> pp .Doc :
87- return pp .concat (
88- [tree_pp (key , ** kwargs ), pp .text (":" ), pp .brk (), tree_pp (value , ** kwargs )]
89- )
90-
91-
92- def _pformat_dict (obj : dict , ** kwargs ) -> pp .Doc :
93- objs = [_dict_entry (key , value , ** kwargs ) for key , value in obj .items ()]
94- return bracketed (
95- name = None ,
96- indent = kwargs ["indent" ],
97- objs = objs ,
98- lbracket = "{" ,
99- rbracket = "}" ,
100- )
101-
102-
103- def pformat_short_array_text (shape : tuple [int , ...], dtype : str ) -> str :
104- short_dtype = (
105- dtype .replace ("float" , "f" )
106- .replace ("uint" , "u" )
107- .replace ("int" , "i" )
108- .replace ("complex" , "c" )
109- )
110- short_shape = "," .join (map (str , shape ))
111- return f"{ short_dtype } [{ short_shape } ]"
112-
113-
114- def _pformat_short_array (
115- shape : tuple [int , ...], dtype : str , kind : Optional [str ]
116- ) -> pp .Doc :
117- out = pformat_short_array_text (shape , dtype )
118- if kind is not None :
119- out = out + f"({ kind } )"
120- return pp .text (out )
121-
122-
123- def _pformat_array (
124- obj : Union [
125- jax .Array ,
126- jax .ShapeDtypeStruct ,
127- np .ndarray ,
128- "torch.Tensor" , # pyright: ignore # noqa: F821
129- ],
130- ** kwargs ,
131- ) -> pp .Doc :
132- short_arrays = kwargs ["short_arrays" ]
133- if short_arrays :
134- # Support torch here for the sake of jaxtyping's pretty-printed error messages.
135- if "torch" in sys .modules and isinstance (obj , sys .modules ["torch" ].Tensor ):
136- dtype = repr (obj .dtype ).split ("." )[1 ]
137- kind = "torch"
138- else :
139- dtype = obj .dtype .name
140- if isinstance (obj , (jax .Array , jax .ShapeDtypeStruct )):
141- # Added in JAX 0.4.32 to `ShapeDtypeStruct`
142- if getattr (obj , "weak_type" , False ):
143- dtype = f"weak_{ dtype } "
144- kind = None
145- elif isinstance (obj , np .ndarray ):
146- kind = "numpy"
147- else :
148- kind = "unknown"
149- return _pformat_short_array (obj .shape , dtype , kind )
150- else :
151- return pp .text (repr (obj ))
152-
153-
154- def _pformat_function (obj : types .FunctionType , ** kwargs ) -> pp .Doc :
155- if kwargs .get ("wrapped" , False ):
156- fn = "wrapped function"
157- else :
158- fn = "function"
159- return pp .text (f"<{ fn } { obj .__name__ } >" )
160-
161-
162- def _pformat_dataclass (obj , ** kwargs ) -> pp .Doc :
163- # <uninitialised> can happen when typechecking an `eqx.Module`'s `__init__` method
164- # with beartype, and printing args using pytest. We haven't yet actually assigned
165- # values to the module so the repr fails.
166- objs = named_objs (
167- [
168- (field .name , getattr (obj , field .name , WithRepr (None , "<uninitialised>" )))
169- for field in dataclasses .fields (obj )
170- if field .repr
171- ],
172- ** kwargs ,
173- )
174- return bracketed (
175- name = pp .text (obj .__class__ .__name__ ),
176- indent = kwargs ["indent" ],
177- objs = objs ,
178- lbracket = "(" ,
179- rbracket = ")" ,
180- )
7+ import wadler_lindig as wl
1818
1829
18310@dataclasses .dataclass
@@ -192,63 +19,18 @@ class _Partial:
19219_Partial .__module__ = jtu .Partial .__module__
19320
19421
195- def tree_pp (obj : PrettyPrintAble , ** kwargs ) -> pp .Doc :
196- follow_wrapped = kwargs ["follow_wrapped" ]
197- truncate_leaf = kwargs ["truncate_leaf" ]
198- if truncate_leaf (obj ):
199- return pp .text (f"{ type (obj ).__name__ } (...)" )
200- if hasattr (obj , "__tree_pp__" ):
201- custom_pp = obj .__tree_pp__ (** kwargs )
202- if custom_pp is not NotImplemented :
203- return pp .group (custom_pp )
204- if dataclasses .is_dataclass (obj ) and not isinstance (obj , type ):
205- return _pformat_dataclass (obj , ** kwargs )
206- elif isinstance (obj , list ):
207- return _pformat_list (obj , ** kwargs )
208- elif isinstance (obj , dict ):
209- return _pformat_dict (obj , ** kwargs )
210- elif isinstance (obj , tuple ):
211- if hasattr (obj , "_fields" ):
212- return _pformat_namedtuple (obj , ** kwargs )
213- else :
214- return _pformat_tuple (obj , ** kwargs )
215- elif (
216- isinstance (obj , (np .ndarray , jax .Array ))
217- or ("torch" in sys .modules and isinstance (obj , sys .modules ["torch" ].Tensor ))
218- or kwargs .get ("struct_as_array" , False )
219- and isinstance (obj , jax .ShapeDtypeStruct )
220- ):
221- return _pformat_array (obj , ** kwargs )
222- elif isinstance (obj , (jax .custom_jvp , jax .custom_vjp )):
223- return tree_pp (obj .__wrapped__ , ** kwargs )
224- elif hasattr (obj , "__wrapped__" ) and follow_wrapped :
225- kwargs ["wrapped" ] = True
226- return tree_pp (obj .__wrapped__ , ** kwargs ) # pyright: ignore
227- elif isinstance (obj , jtu .Partial ) and follow_wrapped :
228- obj = _Partial (obj .func , obj .args , obj .keywords )
229- return _pformat_dataclass (obj , ** kwargs )
230- elif isinstance (obj , ft .partial ) and follow_wrapped :
231- kwargs ["wrapped" ] = True
232- return tree_pp (obj .func , ** kwargs )
233- elif isinstance (obj , types .FunctionType ):
234- return _pformat_function (obj , ** kwargs )
235- else : # int, str, float, complex, bool, etc.
236- return pp .text (repr (obj ))
237-
238-
23922def _false (_ ):
24023 return False
24124
24225
24326def tree_pformat (
244- pytree : PrettyPrintAble ,
27+ pytree : Any ,
24528 * ,
24629 width : int = 80 ,
24730 indent : int = 2 ,
24831 short_arrays : bool = True ,
24932 struct_as_array : bool = False ,
250- follow_wrapped : bool = True ,
251- truncate_leaf : Callable [[PrettyPrintAble ], bool ] = _false ,
33+ truncate_leaf : Callable [[Any ], bool ] = _false ,
25234) -> str :
25335 """Pretty-formats a PyTree as a string, whilst abbreviating JAX arrays.
25436
@@ -257,25 +39,40 @@ def tree_pformat(
25739 As [`equinox.tree_pprint`][], but returns the string instead of printing it.
25840 """
25941
260- return tree_pp (
261- pytree ,
262- indent = indent ,
263- short_arrays = short_arrays ,
264- struct_as_array = struct_as_array ,
265- follow_wrapped = follow_wrapped ,
266- truncate_leaf = truncate_leaf ,
267- ).format (width = width )
42+ def custom (obj ):
43+ if truncate_leaf (obj ):
44+ return wl .TextDoc (f"{ type (obj ).__name__ } (...)" )
45+
46+ if short_arrays :
47+ if isinstance (obj , jax .Array ) or (
48+ struct_as_array and isinstance (obj , jax .ShapeDtypeStruct )
49+ ):
50+ dtype = obj .dtype .name
51+ # Added in JAX 0.4.32 to `ShapeDtypeStruct`
52+ if getattr (obj , "weak_type" , False ):
53+ dtype = f"weak_{ dtype } "
54+ return wl .array_summary (obj .shape , dtype , kind = None )
55+
56+ if isinstance (obj , (jax .custom_jvp , jax .custom_vjp )):
57+ return wl .pdoc (obj .__wrapped__ )
58+
59+ if isinstance (obj , jtu .Partial ):
60+ obj = _Partial (obj .func , obj .args , obj .keywords )
61+ return wl .pdoc (obj )
62+
63+ return wl .pformat (
64+ pytree , width = width , indent = indent , short_arrays = short_arrays , custom = custom
65+ )
26866
26967
27068def tree_pprint (
271- pytree : PrettyPrintAble ,
69+ pytree : Any ,
27270 * ,
27371 width : int = 80 ,
27472 indent : int = 2 ,
27573 short_arrays : bool = True ,
27674 struct_as_array : bool = False ,
277- follow_wrapped : bool = True ,
278- truncate_leaf : Callable [[PrettyPrintAble ], bool ] = _false ,
75+ truncate_leaf : Callable [[Any ], bool ] = _false ,
27976) -> None :
28077 """Pretty-prints a PyTree as a string, whilst abbreviating JAX arrays.
28178
@@ -295,7 +92,6 @@ def tree_pprint(
29592 - `indent`: The amount of indentation each nesting level.
29693 - `short_arrays`: Toggles the abbreviation of JAX arrays.
29794 - `struct_as_array`: Whether to treat `jax.ShapeDtypeStruct`s as arrays.
298- - `follow_wrapped`: Whether to unwrap `functools.partial` and `functools.wraps`.
29995 - `truncate_leaf`: A function `Any -> bool`. Applied to all nodes in the PyTree;
30096 all truthy nodes will be truncated to just `f"{type(node).__name__}(...)"`.
30197
@@ -310,7 +106,6 @@ def tree_pprint(
310106 indent = indent ,
311107 short_arrays = short_arrays ,
312108 struct_as_array = struct_as_array ,
313- follow_wrapped = follow_wrapped ,
314109 truncate_leaf = truncate_leaf ,
315110 )
316111 )
0 commit comments