Skip to content

Commit 9ff4640

Browse files
authored
Include jaxtyping to allow for Tensor/LinearOperator typehints with sizes. (#2543)
Using the same trick in LinearOperator, sized Tensor/LinearOperator typehints are automatically included in the documentation.
1 parent 07fa68e commit 9ff4640

File tree

5 files changed

+72
-47
lines changed

5 files changed

+72
-47
lines changed

.conda/meta.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ requirements:
1919
- python>=3.8
2020
- pytorch>=1.11
2121
- scikit-learn
22+
- jaxtyping>=0.2.9
2223
- linear_operator>=0.5.2
2324

2425
test:

docs/source/conf.py

Lines changed: 68 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
import sys
2020
import sphinx_rtd_theme # noqa
2121
import warnings
22-
from typing import ForwardRef
22+
23+
import jaxtyping
24+
from uncompyle6.semantics.fragments import code_deparse
2325

2426

2527
def read(*names, **kwargs):
@@ -112,7 +114,8 @@ def find_version(*file_paths):
112114
intersphinx_mapping = {
113115
"python": ("https://docs.python.org/3/", None),
114116
"torch": ("https://pytorch.org/docs/stable/", None),
115-
"linear_operator": ("https://linear-operator.readthedocs.io/en/stable/", None),
117+
"linear_operator": ("https://linear-operator.readthedocs.io/en/stable/", "linear_operator_objects.inv"),
118+
# The local mapping here is temporary until we get a new release of linear_operator
116119
}
117120

118121
# Disable docstring inheritance
@@ -237,41 +240,79 @@ def find_version(*file_paths):
237240
]
238241

239242

240-
# -- Function to format typehints ----------------------------------------------
243+
# -- Functions to format typehints ----------------------------------------------
241244
# Adapted from
242245
# https://github.com/cornellius-gp/linear_operator/blob/2b33b9f83b45f0cb8cb3490fc5f254cc59393c25/docs/source/conf.py
246+
247+
248+
# Helper function
249+
# Convert any class (i.e. torch.Tensor, LinearOperator, etc.) into appropriate strings
250+
# For external classes, the format will be e.g. "torch.Tensor"
251+
# For any internal class, the format will be e.g. "~linear_operator.operators.TriangularLinearOperator"
252+
def _convert_internal_and_external_class_to_strings(annotation):
253+
module = annotation.__module__ + "."
254+
if module.split(".")[0] == "gpytorch":
255+
module = "~" + module
256+
elif module == "linear_operator.operators._linear_operator.":
257+
module = "~linear_operator."
258+
elif module == "builtins.":
259+
module = ""
260+
res = f"{module}{annotation.__name__}"
261+
return res
262+
263+
264+
# Convert jaxtyping dimensions into strings
265+
def _dim_to_str(dim):
266+
if isinstance(dim, jaxtyping.array_types._NamedVariadicDim):
267+
return "..."
268+
elif isinstance(dim, jaxtyping.array_types._FixedDim):
269+
res = str(dim.size)
270+
if dim.broadcastable:
271+
res = "#" + res
272+
return res
273+
elif isinstance(dim, jaxtyping.array_types._SymbolicDim):
274+
expr = code_deparse(dim.expr).text.strip().split("return ")[1]
275+
return f"({expr})"
276+
elif "jaxtyping" not in str(dim.__class__): # Probably the case that we have an ellipsis
277+
return "..."
278+
else:
279+
res = str(dim.name)
280+
if dim.broadcastable:
281+
res = "#" + res
282+
return res
283+
284+
285+
# Function to format type hints
243286
def _process(annotation, config):
244287
"""
245288
A function to convert a type/rtype typehint annotation into a :type:/:rtype: string.
246289
This function is a bit hacky, and specific to the type annotations we use most frequently.
290+
247291
This function is recursive.
248292
"""
249293
# Simple/base case: any string annotation is ready to go
250294
if type(annotation) == str:
251295
return annotation
252296

297+
# Jaxtyping: shaped tensors or linear operator
298+
elif hasattr(annotation, "__module__") and "jaxtyping" == annotation.__module__:
299+
cls_annotation = _convert_internal_and_external_class_to_strings(annotation.array_type)
300+
shape = " x ".join([_dim_to_str(dim) for dim in annotation.dims])
301+
return f"{cls_annotation} ({shape})"
302+
253303
# Convert Ellipsis into "..."
254304
elif annotation == Ellipsis:
255305
return "..."
256306

257-
# Convert any class (i.e. torch.Tensor, LinearOperator, gpytorch, etc.) into appropriate strings
258-
# For external classes, the format will be e.g. "torch.Tensor"
259-
# For any linear_operator class, the format will be e.g. "~linear_operator.operators.TriangularLinearOperator"
260-
# For any internal class, the format will be e.g. "~gpytorch.kernels.RBFKernel"
307+
# Convert any class (i.e. torch.Tensor, LinearOperator, etc.) into appropriate strings
261308
elif hasattr(annotation, "__name__"):
262-
module = annotation.__module__ + "."
263-
if module.split(".")[0] == "linear_operator":
264-
if annotation.__name__.endswith("LinearOperator"):
265-
module = "~linear_operator."
266-
elif annotation.__name__.endswith("LinearOperator"):
267-
module = "~linear_operator.operators."
268-
else:
269-
module = "~" + module
270-
elif module.split(".")[0] == "gpytorch":
271-
module = "~" + module
272-
elif module == "builtins.":
273-
module = ""
274-
res = f"{module}{annotation.__name__}"
309+
res = _convert_internal_and_external_class_to_strings(annotation)
310+
311+
elif str(annotation).startswith("typing.Callable"):
312+
if len(annotation.__args__) == 2:
313+
res = f"Callable[{_process(annotation.__args__[0], config)} -> {_process(annotation.__args__[1], config)}]"
314+
else:
315+
res = "Callable"
275316

276317
# Convert any Union[*A*, *B*, *C*] into "*A* or *B* or *C*"
277318
# Also, convert any Optional[*A*] into "*A*, optional"
@@ -291,33 +332,14 @@ def _process(annotation, config):
291332
args = list(annotation.__args__)
292333
res = "(" + ", ".join(_process(arg, config) for arg in args) + ")"
293334

294-
# Convert any List[*A*] into "list(*A*)"
295-
elif str(annotation).startswith("typing.List"):
296-
arg = annotation.__args__[0]
297-
res = "list(" + _process(arg, config) + ")"
298-
299-
# Convert any List[*A*] into "list(*A*)"
300-
elif str(annotation).startswith("typing.Dict"):
301-
res = str(annotation)
302-
303-
# Convert any Iterable[*A*] into "iterable(*A*)"
304-
elif str(annotation).startswith("typing.Iterable"):
305-
arg = annotation.__args__[0]
306-
res = "iterable(" + _process(arg, config) + ")"
307-
308-
# Handle "Callable"
309-
elif str(annotation).startswith("typing.Callable"):
310-
res = "callable"
311-
312-
# Handle "Any"
313-
elif str(annotation).startswith("typing.Any"):
314-
res = ""
335+
# Convert any List[*A*] or Iterable[*A*] into "[*A*, ...]"
336+
elif str(annotation).startswith("typing.Iterable") or str(annotation).startswith("typing.List"):
337+
arg = list(annotation.__args__)[0]
338+
res = f"[{_process(arg, config)}, ...]"
315339

316-
# Special cases for forward references.
317-
# This is brittle, as it only contains case for a select few forward refs
318-
# All others that aren't caught by this are handled by the default case
319-
elif isinstance(annotation, ForwardRef):
320-
res = str(annotation.__forward_arg__)
340+
# Callable typing annotation
341+
elif str(annotation).startswith("typing."):
342+
return str(annotation)[7:]
321343

322344
# For everything we didn't catch: use the simplist string representation
323345
else:
1.96 KB
Binary file not shown.

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ max-line-length = 120
33

44
[flake8]
55
max-line-length = 120
6-
ignore = E203, F403, F405, E731, E741, W503, W605
6+
ignore = E203, E731, E741, F403, F405, F722, W503, W605
77
exclude =
88
build,examples
99

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def find_version(*file_paths):
3939

4040
torch_min = "1.11"
4141
install_requires = [
42+
"jaxtyping>=0.2.9",
4243
"mpmath>=0.19,<=1.3", # avoid incompatibiltiy with torch+sympy with mpmath 1.4
4344
"scikit-learn",
4445
"scipy",
@@ -88,6 +89,7 @@ def find_version(*file_paths):
8889
"sphinx<=6.2.1",
8990
"sphinx_autodoc_typehints<=1.23.0",
9091
"sphinx_rtd_theme<0.5",
92+
"uncompyle6<=3.9.0",
9193
],
9294
"examples": ["ipython", "jupyter", "matplotlib", "scipy", "torchvision", "tqdm"],
9395
"keops": ["pykeops>=1.1.1"],

0 commit comments

Comments
 (0)