Skip to content

Commit 7be8e3e

Browse files
authored
Merge pull request #2723 from Balandat/v1152_bump
Remove jaxtyping dep, bump linear_operator req to v0.6.1
2 parents fc9f8c1 + 927987d commit 7be8e3e

File tree

2 files changed

+2
-33
lines changed

2 files changed

+2
-33
lines changed

docs/source/conf.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@
2020
import sphinx_rtd_theme # noqa
2121
import warnings
2222

23-
import jaxtyping
24-
2523

2624
def read(*names, **kwargs):
2725
with io.open(
@@ -262,27 +260,6 @@ def _convert_internal_and_external_class_to_strings(annotation):
262260
return res
263261

264262

265-
# Convert jaxtyping dimensions into strings
266-
def _dim_to_str(dim):
267-
if isinstance(dim, jaxtyping._array_types._NamedVariadicDim):
268-
return "..."
269-
elif isinstance(dim, jaxtyping._array_types._FixedDim):
270-
res = str(dim.size)
271-
if dim.broadcastable:
272-
res = "#" + res
273-
return res
274-
elif isinstance(dim, jaxtyping._array_types._SymbolicDim):
275-
expr = dim.elem
276-
return f"({expr})"
277-
elif "jaxtyping" not in str(dim.__class__): # Probably the case that we have an ellipsis
278-
return "..."
279-
else:
280-
res = str(dim.name)
281-
if dim.broadcastable:
282-
res = "#" + res
283-
return res
284-
285-
286263
# Function to format type hints
287264
def _process(annotation, config):
288265
"""
@@ -295,12 +272,6 @@ def _process(annotation, config):
295272
if isinstance(annotation, str):
296273
return annotation
297274

298-
# Jaxtyping: shaped tensors or linear operator
299-
elif hasattr(annotation, "__module__") and "jaxtyping" == annotation.__module__:
300-
cls_annotation = _convert_internal_and_external_class_to_strings(annotation.array_type)
301-
shape = " x ".join([_dim_to_str(dim) for dim in annotation.dims])
302-
return f"{cls_annotation} ({shape})"
303-
304275
# Convert Ellipsis into "..."
305276
elif annotation == Ellipsis:
306277
return "..."

setup.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#!/usr/bin/env python3
22

3-
import io
43
import os
54
import re
65
import sys
@@ -26,7 +25,7 @@
2625
# Get version
2726
def find_version(*file_paths):
2827
try:
29-
with io.open(os.path.join(os.path.dirname(__file__), *file_paths), encoding="utf8") as fp:
28+
with open(os.path.join(os.path.dirname(__file__), *file_paths), encoding="utf8") as fp:
3029
version_file = fp.read()
3130
version_match = re.search(r"^__version__ = version = ['\"]([^'\"]*)['\"]", version_file, re.M)
3231
return version_match.group(1)
@@ -39,11 +38,10 @@ def find_version(*file_paths):
3938

4039
torch_min = "2.0"
4140
install_requires = [
42-
"jaxtyping",
4341
"mpmath>=0.19,<=1.3", # avoid incompatibiltiy with torch+sympy with mpmath 1.4
4442
"scikit-learn",
4543
"scipy>=1.6.0",
46-
"linear_operator>=0.6",
44+
"linear_operator>=0.6.1",
4745
]
4846
# if recent dev version of PyTorch is installed, no need to install stable
4947
try:

0 commit comments

Comments
 (0)