Skip to content

Commit 927987d

Browse files
committed
Remove jaxtyping logic from sphinx config
`
1 parent 7511909 commit 927987d

File tree

1 file changed

+0
-29
lines changed

1 file changed

+0
-29
lines changed

docs/source/conf.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@
2020
import sphinx_rtd_theme # noqa
2121
import warnings
2222

23-
import jaxtyping
24-
2523

2624
def read(*names, **kwargs):
2725
with io.open(
@@ -262,27 +260,6 @@ def _convert_internal_and_external_class_to_strings(annotation):
262260
return res
263261

264262

265-
# Convert jaxtyping dimensions into strings
266-
def _dim_to_str(dim):
267-
if isinstance(dim, jaxtyping._array_types._NamedVariadicDim):
268-
return "..."
269-
elif isinstance(dim, jaxtyping._array_types._FixedDim):
270-
res = str(dim.size)
271-
if dim.broadcastable:
272-
res = "#" + res
273-
return res
274-
elif isinstance(dim, jaxtyping._array_types._SymbolicDim):
275-
expr = dim.elem
276-
return f"({expr})"
277-
elif "jaxtyping" not in str(dim.__class__): # Probably the case that we have an ellipsis
278-
return "..."
279-
else:
280-
res = str(dim.name)
281-
if dim.broadcastable:
282-
res = "#" + res
283-
return res
284-
285-
286263
# Function to format type hints
287264
def _process(annotation, config):
288265
"""
@@ -295,12 +272,6 @@ def _process(annotation, config):
295272
if isinstance(annotation, str):
296273
return annotation
297274

298-
# Jaxtyping: shaped tensors or linear operator
299-
elif hasattr(annotation, "__module__") and "jaxtyping" == annotation.__module__:
300-
cls_annotation = _convert_internal_and_external_class_to_strings(annotation.array_type)
301-
shape = " x ".join([_dim_to_str(dim) for dim in annotation.dims])
302-
return f"{cls_annotation} ({shape})"
303-
304275
# Convert Ellipsis into "..."
305276
elif annotation == Ellipsis:
306277
return "..."

0 commit comments

Comments
 (0)