Skip to content

Commit 5765e39

Browse files
Now using wadler_lindig pprint library.
1 parent 90626ee commit 5765e39

File tree

10 files changed

+68
-289
lines changed

10 files changed

+68
-289
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,5 @@ repos:
2020
optax,
2121
pytest,
2222
typing_extensions,
23+
wadler_lindig,
2324
]

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ _Coming from [Flax](https://github.com/google/flax) or [Haiku](https://github.co
1919
pip install equinox
2020
```
2121

22-
Requires Python 3.9+ and JAX 0.4.13+.
22+
Requires Python 3.10+.
2323

2424
## Documentation
2525

docs/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ _Coming from [Flax](https://github.com/google/flax) or [Haiku](https://github.co
1919
pip install equinox
2020
```
2121

22-
Requires Python 3.9+ and JAX 0.4.13+.
22+
Requires Python 3.10+.
2323

2424
## Quick example
2525

equinox/_pretty_print.py

Lines changed: 31 additions & 236 deletions
Original file line numberDiff line numberDiff line change
@@ -1,183 +1,10 @@
11
import dataclasses
2-
import functools as ft
3-
import sys
4-
import types
5-
from collections.abc import Callable, Sequence
6-
from typing import Any, Optional, Union
2+
from collections.abc import Callable
3+
from typing import Any
74

85
import jax
9-
import jax._src.pretty_printer as pp
106
import jax.tree_util as jtu
11-
import numpy as np
12-
from jaxtyping import PyTree
13-
14-
from ._doc_utils import WithRepr
15-
16-
17-
Dataclass = Any
18-
NamedTuple = Any # workaround typeguard bug
19-
PrettyPrintAble = PyTree
20-
21-
# Re-export
22-
text = pp.text
23-
24-
25-
_comma_sep = pp.concat([pp.text(","), pp.brk()])
26-
27-
28-
def bracketed(
29-
name: Optional[pp.Doc],
30-
indent: int,
31-
objs: Sequence[pp.Doc],
32-
lbracket: str,
33-
rbracket: str,
34-
) -> pp.Doc:
35-
nested = pp.concat(
36-
[
37-
pp.nest(indent, pp.concat([pp.brk(""), pp.join(_comma_sep, objs)])),
38-
pp.brk(""),
39-
]
40-
)
41-
concated = []
42-
if name is not None:
43-
concated.append(name)
44-
concated.extend([pp.text(lbracket), nested, pp.text(rbracket)])
45-
return pp.group(pp.concat(concated))
46-
47-
48-
def named_objs(pairs, **kwargs):
49-
return [
50-
pp.concat([pp.text(key + "="), tree_pp(value, **kwargs)])
51-
for key, value in pairs
52-
]
53-
54-
55-
def _pformat_list(obj: list, **kwargs) -> pp.Doc:
56-
return bracketed(
57-
name=None,
58-
indent=kwargs["indent"],
59-
objs=[tree_pp(x, **kwargs) for x in obj],
60-
lbracket="[",
61-
rbracket="]",
62-
)
63-
64-
65-
def _pformat_tuple(obj: tuple, **kwargs) -> pp.Doc:
66-
if len(obj) == 1:
67-
objs = [pp.concat([tree_pp(obj[0], **kwargs), pp.text(",")])]
68-
else:
69-
objs = [tree_pp(x, **kwargs) for x in obj]
70-
return bracketed(
71-
name=None, indent=kwargs["indent"], objs=objs, lbracket="(", rbracket=")"
72-
)
73-
74-
75-
def _pformat_namedtuple(obj: NamedTuple, **kwargs) -> pp.Doc:
76-
objs = named_objs([(name, getattr(obj, name)) for name in obj._fields], **kwargs)
77-
return bracketed(
78-
name=pp.text(obj.__class__.__name__),
79-
indent=kwargs["indent"],
80-
objs=objs,
81-
lbracket="(",
82-
rbracket=")",
83-
)
84-
85-
86-
def _dict_entry(key: Any, value: Any, **kwargs) -> pp.Doc:
87-
return pp.concat(
88-
[tree_pp(key, **kwargs), pp.text(":"), pp.brk(), tree_pp(value, **kwargs)]
89-
)
90-
91-
92-
def _pformat_dict(obj: dict, **kwargs) -> pp.Doc:
93-
objs = [_dict_entry(key, value, **kwargs) for key, value in obj.items()]
94-
return bracketed(
95-
name=None,
96-
indent=kwargs["indent"],
97-
objs=objs,
98-
lbracket="{",
99-
rbracket="}",
100-
)
101-
102-
103-
def pformat_short_array_text(shape: tuple[int, ...], dtype: str) -> str:
104-
short_dtype = (
105-
dtype.replace("float", "f")
106-
.replace("uint", "u")
107-
.replace("int", "i")
108-
.replace("complex", "c")
109-
)
110-
short_shape = ",".join(map(str, shape))
111-
return f"{short_dtype}[{short_shape}]"
112-
113-
114-
def _pformat_short_array(
115-
shape: tuple[int, ...], dtype: str, kind: Optional[str]
116-
) -> pp.Doc:
117-
out = pformat_short_array_text(shape, dtype)
118-
if kind is not None:
119-
out = out + f"({kind})"
120-
return pp.text(out)
121-
122-
123-
def _pformat_array(
124-
obj: Union[
125-
jax.Array,
126-
jax.ShapeDtypeStruct,
127-
np.ndarray,
128-
"torch.Tensor", # pyright: ignore # noqa: F821
129-
],
130-
**kwargs,
131-
) -> pp.Doc:
132-
short_arrays = kwargs["short_arrays"]
133-
if short_arrays:
134-
# Support torch here for the sake of jaxtyping's pretty-printed error messages.
135-
if "torch" in sys.modules and isinstance(obj, sys.modules["torch"].Tensor):
136-
dtype = repr(obj.dtype).split(".")[1]
137-
kind = "torch"
138-
else:
139-
dtype = obj.dtype.name
140-
if isinstance(obj, (jax.Array, jax.ShapeDtypeStruct)):
141-
# Added in JAX 0.4.32 to `ShapeDtypeStruct`
142-
if getattr(obj, "weak_type", False):
143-
dtype = f"weak_{dtype}"
144-
kind = None
145-
elif isinstance(obj, np.ndarray):
146-
kind = "numpy"
147-
else:
148-
kind = "unknown"
149-
return _pformat_short_array(obj.shape, dtype, kind)
150-
else:
151-
return pp.text(repr(obj))
152-
153-
154-
def _pformat_function(obj: types.FunctionType, **kwargs) -> pp.Doc:
155-
if kwargs.get("wrapped", False):
156-
fn = "wrapped function"
157-
else:
158-
fn = "function"
159-
return pp.text(f"<{fn} {obj.__name__}>")
160-
161-
162-
def _pformat_dataclass(obj, **kwargs) -> pp.Doc:
163-
# <uninitialised> can happen when typechecking an `eqx.Module`'s `__init__` method
164-
# with beartype, and printing args using pytest. We haven't yet actually assigned
165-
# values to the module so the repr fails.
166-
objs = named_objs(
167-
[
168-
(field.name, getattr(obj, field.name, WithRepr(None, "<uninitialised>")))
169-
for field in dataclasses.fields(obj)
170-
if field.repr
171-
],
172-
**kwargs,
173-
)
174-
return bracketed(
175-
name=pp.text(obj.__class__.__name__),
176-
indent=kwargs["indent"],
177-
objs=objs,
178-
lbracket="(",
179-
rbracket=")",
180-
)
7+
import wadler_lindig as wl
1818

1829

18310
@dataclasses.dataclass
@@ -192,63 +19,18 @@ class _Partial:
19219
_Partial.__module__ = jtu.Partial.__module__
19320

19421

195-
def tree_pp(obj: PrettyPrintAble, **kwargs) -> pp.Doc:
196-
follow_wrapped = kwargs["follow_wrapped"]
197-
truncate_leaf = kwargs["truncate_leaf"]
198-
if truncate_leaf(obj):
199-
return pp.text(f"{type(obj).__name__}(...)")
200-
if hasattr(obj, "__tree_pp__"):
201-
custom_pp = obj.__tree_pp__(**kwargs)
202-
if custom_pp is not NotImplemented:
203-
return pp.group(custom_pp)
204-
if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
205-
return _pformat_dataclass(obj, **kwargs)
206-
elif isinstance(obj, list):
207-
return _pformat_list(obj, **kwargs)
208-
elif isinstance(obj, dict):
209-
return _pformat_dict(obj, **kwargs)
210-
elif isinstance(obj, tuple):
211-
if hasattr(obj, "_fields"):
212-
return _pformat_namedtuple(obj, **kwargs)
213-
else:
214-
return _pformat_tuple(obj, **kwargs)
215-
elif (
216-
isinstance(obj, (np.ndarray, jax.Array))
217-
or ("torch" in sys.modules and isinstance(obj, sys.modules["torch"].Tensor))
218-
or kwargs.get("struct_as_array", False)
219-
and isinstance(obj, jax.ShapeDtypeStruct)
220-
):
221-
return _pformat_array(obj, **kwargs)
222-
elif isinstance(obj, (jax.custom_jvp, jax.custom_vjp)):
223-
return tree_pp(obj.__wrapped__, **kwargs)
224-
elif hasattr(obj, "__wrapped__") and follow_wrapped:
225-
kwargs["wrapped"] = True
226-
return tree_pp(obj.__wrapped__, **kwargs) # pyright: ignore
227-
elif isinstance(obj, jtu.Partial) and follow_wrapped:
228-
obj = _Partial(obj.func, obj.args, obj.keywords)
229-
return _pformat_dataclass(obj, **kwargs)
230-
elif isinstance(obj, ft.partial) and follow_wrapped:
231-
kwargs["wrapped"] = True
232-
return tree_pp(obj.func, **kwargs)
233-
elif isinstance(obj, types.FunctionType):
234-
return _pformat_function(obj, **kwargs)
235-
else: # int, str, float, complex, bool, etc.
236-
return pp.text(repr(obj))
237-
238-
23922
def _false(_):
24023
return False
24124

24225

24326
def tree_pformat(
244-
pytree: PrettyPrintAble,
27+
pytree: Any,
24528
*,
24629
width: int = 80,
24730
indent: int = 2,
24831
short_arrays: bool = True,
24932
struct_as_array: bool = False,
250-
follow_wrapped: bool = True,
251-
truncate_leaf: Callable[[PrettyPrintAble], bool] = _false,
33+
truncate_leaf: Callable[[Any], bool] = _false,
25234
) -> str:
25335
"""Pretty-formats a PyTree as a string, whilst abbreviating JAX arrays.
25436
@@ -257,25 +39,40 @@ def tree_pformat(
25739
As [`equinox.tree_pprint`][], but returns the string instead of printing it.
25840
"""
25941

260-
return tree_pp(
261-
pytree,
262-
indent=indent,
263-
short_arrays=short_arrays,
264-
struct_as_array=struct_as_array,
265-
follow_wrapped=follow_wrapped,
266-
truncate_leaf=truncate_leaf,
267-
).format(width=width)
42+
def custom(obj):
43+
if truncate_leaf(obj):
44+
return wl.TextDoc(f"{type(obj).__name__}(...)")
45+
46+
if short_arrays:
47+
if isinstance(obj, jax.Array) or (
48+
struct_as_array and isinstance(obj, jax.ShapeDtypeStruct)
49+
):
50+
dtype = obj.dtype.name
51+
# Added in JAX 0.4.32 to `ShapeDtypeStruct`
52+
if getattr(obj, "weak_type", False):
53+
dtype = f"weak_{dtype}"
54+
return wl.array_summary(obj.shape, dtype, kind=None)
55+
56+
if isinstance(obj, (jax.custom_jvp, jax.custom_vjp)):
57+
return wl.pdoc(obj.__wrapped__)
58+
59+
if isinstance(obj, jtu.Partial):
60+
obj = _Partial(obj.func, obj.args, obj.keywords)
61+
return wl.pdoc(obj)
62+
63+
return wl.pformat(
64+
pytree, width=width, indent=indent, short_arrays=short_arrays, custom=custom
65+
)
26866

26967

27068
def tree_pprint(
271-
pytree: PrettyPrintAble,
69+
pytree: Any,
27270
*,
27371
width: int = 80,
27472
indent: int = 2,
27573
short_arrays: bool = True,
27674
struct_as_array: bool = False,
277-
follow_wrapped: bool = True,
278-
truncate_leaf: Callable[[PrettyPrintAble], bool] = _false,
75+
truncate_leaf: Callable[[Any], bool] = _false,
27976
) -> None:
28077
"""Pretty-prints a PyTree as a string, whilst abbreviating JAX arrays.
28178
@@ -295,7 +92,6 @@ def tree_pprint(
29592
- `indent`: The amount of indentation each nesting level.
29693
- `short_arrays`: Toggles the abbreviation of JAX arrays.
29794
- `struct_as_array`: Whether to treat `jax.ShapeDtypeStruct`s as arrays.
298-
- `follow_wrapped`: Whether to unwrap `functools.partial` and `functools.wraps`.
29995
- `truncate_leaf`: A function `Any -> bool`. Applied to all nodes in the PyTree;
30096
all truthy nodes will be truncated to just `f"{type(node).__name__}(...)"`.
30197
@@ -310,7 +106,6 @@ def tree_pprint(
310106
indent=indent,
311107
short_arrays=short_arrays,
312108
struct_as_array=struct_as_array,
313-
follow_wrapped=follow_wrapped,
314109
truncate_leaf=truncate_leaf,
315110
)
316111
)

equinox/debug/_dce.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
import jax.core
55
import jax.numpy as jnp
66
import jax.tree_util as jtu
7+
import wadler_lindig as wl
78
from jaxtyping import PyTree
89

910
from .._doc_utils import WithRepr
1011
from .._filters import combine, is_array, partition
11-
from .._pretty_print import pformat_short_array_text, tree_pprint
12+
from .._pretty_print import tree_pprint
1213

1314

1415
_dce_store = {}
@@ -105,7 +106,7 @@ def inspect_dce(name: Hashable = None):
105106
except KeyError:
106107
value = "<DCE'd>"
107108
else:
108-
value = pformat_short_array_text(shape, dtype)
109+
value = wl.array_summary(shape, dtype, kind=None).text
109110
new_leaves.append(WithRepr(None, value))
110111
tree = combine(jtu.tree_unflatten(treedef, new_leaves), static)
111112
print(f"Entry {i}:")

equinox/internal/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from .._eval_shape import cached_filter_eval_shape as cached_filter_eval_shape
2222
from .._misc import left_broadcast_to as left_broadcast_to
2323
from .._module import Static as Static
24-
from .._pretty_print import tree_pp as tree_pp
2524
from .._unvmap import (
2625
unvmap_all as unvmap_all,
2726
unvmap_all_p as unvmap_all_p,

equinox/internal/_str2jax.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
1+
import wadler_lindig as wl
2+
13
from .._module import Module
2-
from .._pretty_print import text as pp_text
34

45

56
def str2jax(msg: str):
67
"""Creates a JAXable object whose `str(...)` is the specified string."""
78

89
class M(Module):
9-
def __tree_pp__(self, **kwargs):
10-
return pp_text(msg)
10+
def __pdoc__(self, **kwargs):
11+
del kwargs
12+
return wl.TextDoc(msg)
1113

1214
def __repr__(self):
1315
return msg

0 commit comments

Comments
 (0)