Skip to content

Commit 6297f52

Browse files
PyTrees now use Wadler-Lindig to handle their naming and pretty-printing. In particualr this is useful for better documentation.
1 parent dd2394e commit 6297f52

File tree

2 files changed

+93
-6
lines changed

2 files changed

+93
-6
lines changed

jaxtyping/_pytree_type.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from typing import Any, Generic, TypeVar
2323

2424
import jax.tree_util as jtu
25+
import wadler_lindig as wl
2526

2627
from ._errors import AnnotationError
2728
from ._storage import (
@@ -35,15 +36,25 @@
3536

3637

3738
_T = TypeVar("_T")
39+
_S = TypeVar("_S")
3840

3941

40-
class _FakePyTree(Generic[_T]):
42+
class _FakePyTree1(Generic[_T]):
4143
pass
4244

4345

44-
_FakePyTree.__name__ = "PyTree"
45-
_FakePyTree.__qualname__ = "PyTree"
46-
_FakePyTree.__module__ = "builtins"
46+
_FakePyTree1.__name__ = "PyTree"
47+
_FakePyTree1.__qualname__ = "PyTree"
48+
_FakePyTree1.__module__ = "builtins"
49+
50+
51+
class _FakePyTree2(Generic[_T, _S]):
52+
pass
53+
54+
55+
_FakePyTree2.__name__ = "PyTree"
56+
_FakePyTree2.__qualname__ = "PyTree"
57+
_FakePyTree2.__module__ = "builtins"
4758

4859

4960
class _MetaPyTree(type):
@@ -226,7 +237,15 @@ class X(PyTree):
226237
"regular Python, i.e. a valid variable name.)\n"
227238
f"Got piece '{piece}' in overall structure '{X.structure}'."
228239
)
229-
name = str(_FakePyTree[item[0]])[:-1] + ', "' + item[1].strip() + '"]'
240+
241+
class Y:
242+
pass
243+
244+
Y.__module__ = "builtins"
245+
Y.__name__ = repr(X.structure)
246+
Y.__qualname__ = repr(X.structure)
247+
name = wl.pformat(_FakePyTree2[X.leaftype, Y], width=9999)
248+
del Y
230249
else:
231250
raise ValueError(
232251
"The subscript `foo` in `jaxtyping.PyTree[foo]` must either be a "
@@ -235,7 +254,7 @@ class X(PyTree):
235254
f"{len(item)}."
236255
)
237256
else:
238-
name = str(_FakePyTree[item])
257+
name = wl.pformat(_FakePyTree1[item], width=9999)
239258

240259
class X(PyTree):
241260
leaftype = item
@@ -249,6 +268,22 @@ class X(PyTree):
249268
X.__module__ = "jaxtyping"
250269
return X
251270

271+
def __pdoc__(self, **kwargs):
272+
if self is PyTree:
273+
return wl.TextDoc("PyTree")
274+
else:
275+
indent = kwargs["indent"]
276+
docs = [wl.pdoc(self.leaftype, **kwargs)]
277+
if self.structure is not None:
278+
docs.append(wl.pdoc(self.structure, **kwargs))
279+
return wl.bracketed(
280+
begin=wl.TextDoc("PyTree["),
281+
docs=docs,
282+
sep=wl.comma,
283+
end=wl.TextDoc("]"),
284+
indent=indent,
285+
)
286+
252287

253288
# Can't do `class PyTree(Generic[_T]): ...` because we need to override the
254289
# instancecheck for PyTree[foo], but subclassing

test/test_pytree.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@
1717
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
1818
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
1919

20+
from collections.abc import Callable
2021
from typing import NamedTuple, Tuple, Union
2122

2223
import equinox as eqx
2324
import jax
2425
import jax.numpy as jnp
2526
import jax.random as jr
2627
import pytest
28+
import wadler_lindig as wl
2729

2830
import jaxtyping
2931
from jaxtyping import AnnotationError, Array, Float, PyTree
@@ -341,3 +343,53 @@ def f(x: PyTree[PyTree[Float[Array, "?foo"], " S"], " T"]):
341343
x1 = jr.normal(getkey(), (2,))
342344
with pytest.raises(AnnotationError, match="ambiguous which PyTree"):
343345
f(x1)
346+
347+
348+
def test_name():
349+
assert PyTree.__name__ == "PyTree"
350+
assert PyTree[int].__name__ == "PyTree[int]"
351+
assert PyTree[int, "foo"].__name__ == "PyTree[int, 'foo']"
352+
assert PyTree[PyTree[str], "foo"].__name__ == "PyTree[PyTree[str], 'foo']"
353+
assert (
354+
PyTree[PyTree[str, "bar"], "foo"].__name__
355+
== "PyTree[PyTree[str, 'bar'], 'foo']"
356+
)
357+
assert PyTree[PyTree[str, "bar"]].__name__ == "PyTree[PyTree[str, 'bar']]"
358+
assert (
359+
PyTree[None | Callable[[PyTree[int, " T"]], str]].__name__
360+
== "PyTree[None | Callable[[PyTree[int, 'T']], str]]"
361+
)
362+
363+
364+
def test_pdoc():
365+
assert wl.pformat(PyTree) == "PyTree"
366+
assert wl.pformat(PyTree[int]) == "PyTree[int]"
367+
assert wl.pformat(PyTree[int, "foo"]) == "PyTree[int, 'foo']"
368+
assert wl.pformat(PyTree[PyTree[str], "foo"]) == "PyTree[PyTree[str], 'foo']"
369+
assert (
370+
wl.pformat(PyTree[PyTree[str, "bar"], "foo"])
371+
== "PyTree[PyTree[str, 'bar'], 'foo']"
372+
)
373+
assert wl.pformat(PyTree[PyTree[str, "bar"]]) == "PyTree[PyTree[str, 'bar']]"
374+
assert (
375+
wl.pformat(PyTree[None | Callable[[PyTree[int, " T"]], str]])
376+
== "PyTree[None | Callable[[PyTree[int, 'T']], str]]"
377+
)
378+
expected = """
379+
PyTree[
380+
None
381+
| Callable[
382+
[
383+
PyTree[
384+
int,
385+
'T'
386+
]
387+
],
388+
str
389+
]
390+
]
391+
""".strip()
392+
assert (
393+
wl.pformat(PyTree[None | Callable[[PyTree[int, " T"]], str]], width=2).strip()
394+
== expected
395+
)

0 commit comments

Comments
 (0)