diff --git a/.gitignore b/.gitignore index 145c021..3a31fff 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ build/ dist/ site/ .all_objects.cache +.venv diff --git a/jaxtyping/__init__.py b/jaxtyping/__init__.py index 3848c14..6582b39 100644 --- a/jaxtyping/__init__.py +++ b/jaxtyping/__init__.py @@ -198,9 +198,18 @@ class PRNGKeyArray: return Shaped[jax.Array, ""] elif item == "ScalarLike": - import jax.typing + if getattr(typing, "GENERATING_DOCUMENTATION", False): + + class ScalarLike: + pass + + ScalarLike.__module__ = "builtins" + ScalarLike.__qualname__ = "ScalarLike" + return ScalarLike + else: + import jax.typing - return Shaped[jax.typing.ArrayLike, ""] + return Shaped[jax.typing.ArrayLike, ""] elif item == "PyTree": from ._pytree_type import PyTree