File tree Expand file tree Collapse file tree 2 files changed +2
-33
lines changed
Expand file tree Collapse file tree 2 files changed +2
-33
lines changed Original file line number Diff line number Diff line change 2020import sphinx_rtd_theme # noqa
2121import warnings
2222
23- import jaxtyping
24-
2523
2624def read (* names , ** kwargs ):
2725 with io .open (
@@ -262,27 +260,6 @@ def _convert_internal_and_external_class_to_strings(annotation):
262260 return res
263261
264262
265- # Convert jaxtyping dimensions into strings
266- def _dim_to_str (dim ):
267- if isinstance (dim , jaxtyping ._array_types ._NamedVariadicDim ):
268- return "..."
269- elif isinstance (dim , jaxtyping ._array_types ._FixedDim ):
270- res = str (dim .size )
271- if dim .broadcastable :
272- res = "#" + res
273- return res
274- elif isinstance (dim , jaxtyping ._array_types ._SymbolicDim ):
275- expr = dim .elem
276- return f"({ expr } )"
277- elif "jaxtyping" not in str (dim .__class__ ): # Probably the case that we have an ellipsis
278- return "..."
279- else :
280- res = str (dim .name )
281- if dim .broadcastable :
282- res = "#" + res
283- return res
284-
285-
286263# Function to format type hints
287264def _process (annotation , config ):
288265 """
@@ -295,12 +272,6 @@ def _process(annotation, config):
295272 if isinstance (annotation , str ):
296273 return annotation
297274
298- # Jaxtyping: shaped tensors or linear operator
299- elif hasattr (annotation , "__module__" ) and "jaxtyping" == annotation .__module__ :
300- cls_annotation = _convert_internal_and_external_class_to_strings (annotation .array_type )
301- shape = " x " .join ([_dim_to_str (dim ) for dim in annotation .dims ])
302- return f"{ cls_annotation } ({ shape } )"
303-
304275 # Convert Ellipsis into "..."
305276 elif annotation == Ellipsis :
306277 return "..."
Original file line number Diff line number Diff line change 11#!/usr/bin/env python3
22
3- import io
43import os
54import re
65import sys
2625# Get version
2726def find_version (* file_paths ):
2827 try :
29- with io . open (os .path .join (os .path .dirname (__file__ ), * file_paths ), encoding = "utf8" ) as fp :
28+ with open (os .path .join (os .path .dirname (__file__ ), * file_paths ), encoding = "utf8" ) as fp :
3029 version_file = fp .read ()
3130 version_match = re .search (r"^__version__ = version = ['\"]([^'\"]*)['\"]" , version_file , re .M )
3231 return version_match .group (1 )
@@ -39,11 +38,10 @@ def find_version(*file_paths):
3938
4039torch_min = "2.0"
4140install_requires = [
42- "jaxtyping" ,
4341 "mpmath>=0.19,<=1.3" , # avoid incompatibiltiy with torch+sympy with mpmath 1.4
4442 "scikit-learn" ,
4543 "scipy>=1.6.0" ,
46- "linear_operator>=0.6" ,
44+ "linear_operator>=0.6.1 " ,
4745]
4846# if recent dev version of PyTorch is installed, no need to install stable
4947try :
You can’t perform that action at this time.
0 commit comments