1919import sys
2020import sphinx_rtd_theme # noqa
2121import warnings
22- from typing import ForwardRef
22+
23+ import jaxtyping
24+ from uncompyle6 .semantics .fragments import code_deparse
2325
2426
2527def read (* names , ** kwargs ):
@@ -112,7 +114,8 @@ def find_version(*file_paths):
112114intersphinx_mapping = {
113115 "python" : ("https://docs.python.org/3/" , None ),
114116 "torch" : ("https://pytorch.org/docs/stable/" , None ),
115- "linear_operator" : ("https://linear-operator.readthedocs.io/en/stable/" , None ),
117+ "linear_operator" : ("https://linear-operator.readthedocs.io/en/stable/" , "linear_operator_objects.inv" ),
118+ # The local mapping here is temporary until we get a new release of linear_operator
116119}
117120
118121# Disable docstring inheritance
@@ -237,41 +240,79 @@ def find_version(*file_paths):
237240]
238241
239242
240- # -- Function to format typehints ----------------------------------------------
243+ # -- Functions to format typehints ----------------------------------------------
241244# Adapted from
242245# https://github.com/cornellius-gp/linear_operator/blob/2b33b9f83b45f0cb8cb3490fc5f254cc59393c25/docs/source/conf.py
246+
247+
248+ # Helper function
249+ # Convert any class (i.e. torch.Tensor, LinearOperator, etc.) into appropriate strings
250+ # For external classes, the format will be e.g. "torch.Tensor"
251+ # For any internal class, the format will be e.g. "~linear_operator.operators.TriangularLinearOperator"
252+ def _convert_internal_and_external_class_to_strings (annotation ):
253+ module = annotation .__module__ + "."
254+ if module .split ("." )[0 ] == "gpytorch" :
255+ module = "~" + module
256+ elif module == "linear_operator.operators._linear_operator." :
257+ module = "~linear_operator."
258+ elif module == "builtins." :
259+ module = ""
260+ res = f"{ module } { annotation .__name__ } "
261+ return res
262+
263+
264+ # Convert jaxtyping dimensions into strings
265+ def _dim_to_str (dim ):
266+ if isinstance (dim , jaxtyping .array_types ._NamedVariadicDim ):
267+ return "..."
268+ elif isinstance (dim , jaxtyping .array_types ._FixedDim ):
269+ res = str (dim .size )
270+ if dim .broadcastable :
271+ res = "#" + res
272+ return res
273+ elif isinstance (dim , jaxtyping .array_types ._SymbolicDim ):
274+ expr = code_deparse (dim .expr ).text .strip ().split ("return " )[1 ]
275+ return f"({ expr } )"
276+ elif "jaxtyping" not in str (dim .__class__ ): # Probably the case that we have an ellipsis
277+ return "..."
278+ else :
279+ res = str (dim .name )
280+ if dim .broadcastable :
281+ res = "#" + res
282+ return res
283+
284+
285+ # Function to format type hints
243286def _process (annotation , config ):
244287 """
245288 A function to convert a type/rtype typehint annotation into a :type:/:rtype: string.
246289 This function is a bit hacky, and specific to the type annotations we use most frequently.
290+
247291 This function is recursive.
248292 """
249293 # Simple/base case: any string annotation is ready to go
250294 if type (annotation ) == str :
251295 return annotation
252296
297+ # Jaxtyping: shaped tensors or linear operator
298+ elif hasattr (annotation , "__module__" ) and "jaxtyping" == annotation .__module__ :
299+ cls_annotation = _convert_internal_and_external_class_to_strings (annotation .array_type )
300+ shape = " x " .join ([_dim_to_str (dim ) for dim in annotation .dims ])
301+ return f"{ cls_annotation } ({ shape } )"
302+
253303 # Convert Ellipsis into "..."
254304 elif annotation == Ellipsis :
255305 return "..."
256306
257- # Convert any class (i.e. torch.Tensor, LinearOperator, gpytorch, etc.) into appropriate strings
258- # For external classes, the format will be e.g. "torch.Tensor"
259- # For any linear_operator class, the format will be e.g. "~linear_operator.operators.TriangularLinearOperator"
260- # For any internal class, the format will be e.g. "~gpytorch.kernels.RBFKernel"
307+ # Convert any class (i.e. torch.Tensor, LinearOperator, etc.) into appropriate strings
261308 elif hasattr (annotation , "__name__" ):
262- module = annotation .__module__ + "."
263- if module .split ("." )[0 ] == "linear_operator" :
264- if annotation .__name__ .endswith ("LinearOperator" ):
265- module = "~linear_operator."
266- elif annotation .__name__ .endswith ("LinearOperator" ):
267- module = "~linear_operator.operators."
268- else :
269- module = "~" + module
270- elif module .split ("." )[0 ] == "gpytorch" :
271- module = "~" + module
272- elif module == "builtins." :
273- module = ""
274- res = f"{ module } { annotation .__name__ } "
309+ res = _convert_internal_and_external_class_to_strings (annotation )
310+
311+ elif str (annotation ).startswith ("typing.Callable" ):
312+ if len (annotation .__args__ ) == 2 :
313+ res = f"Callable[{ _process (annotation .__args__ [0 ], config )} -> { _process (annotation .__args__ [1 ], config )} ]"
314+ else :
315+ res = "Callable"
275316
276317 # Convert any Union[*A*, *B*, *C*] into "*A* or *B* or *C*"
277318 # Also, convert any Optional[*A*] into "*A*, optional"
@@ -291,33 +332,14 @@ def _process(annotation, config):
291332 args = list (annotation .__args__ )
292333 res = "(" + ", " .join (_process (arg , config ) for arg in args ) + ")"
293334
294- # Convert any List[*A*] into "list(*A*)"
295- elif str (annotation ).startswith ("typing.List" ):
296- arg = annotation .__args__ [0 ]
297- res = "list(" + _process (arg , config ) + ")"
298-
299- # Convert any List[*A*] into "list(*A*)"
300- elif str (annotation ).startswith ("typing.Dict" ):
301- res = str (annotation )
302-
303- # Convert any Iterable[*A*] into "iterable(*A*)"
304- elif str (annotation ).startswith ("typing.Iterable" ):
305- arg = annotation .__args__ [0 ]
306- res = "iterable(" + _process (arg , config ) + ")"
307-
308- # Handle "Callable"
309- elif str (annotation ).startswith ("typing.Callable" ):
310- res = "callable"
311-
312- # Handle "Any"
313- elif str (annotation ).startswith ("typing.Any" ):
314- res = ""
335+ # Convert any List[*A*] or Iterable[*A*] into "[*A*, ...]"
336+ elif str (annotation ).startswith ("typing.Iterable" ) or str (annotation ).startswith ("typing.List" ):
337+ arg = list (annotation .__args__ )[0 ]
338+ res = f"[{ _process (arg , config )} , ...]"
315339
316- # Special cases for forward references.
317- # This is brittle, as it only contains case for a select few forward refs
318- # All others that aren't caught by this are handled by the default case
319- elif isinstance (annotation , ForwardRef ):
320- res = str (annotation .__forward_arg__ )
340+ # Callable typing annotation
341+ elif str (annotation ).startswith ("typing." ):
342+ return str (annotation )[7 :]
321343
322344 # For everything we didn't catch: use the simplist string representation
323345 else :
0 commit comments