Skip to content

Commit 365d06b

Browse files
committed
Make typing imports optional with TYPE_CHECKING flag
1 parent 7ed12d2 commit 365d06b

25 files changed

+120
-36
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ jobs:
4343
run: black --check manify/ --line-length 120
4444

4545
- name: Check import ordering with isort
46-
run: isort --check-only --profile black manify/ --line-width 120
46+
run: isort --check-only --diff manify/
4747

4848
- name: Run pylint
4949
run: pylint manify/

manify/clustering/fuzzy_kmeans.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,18 @@
2020

2121
from __future__ import annotations
2222

23+
from typing import TYPE_CHECKING
24+
2325
import numpy as np
2426
import torch
25-
from beartype.typing import Literal
2627
from geoopt import ManifoldParameter
2728
from geoopt.optim import RiemannianAdam
28-
from jaxtyping import Float, Int
2929
from sklearn.base import BaseEstimator, ClusterMixin
3030

31+
if TYPE_CHECKING:
32+
from beartype.typing import Literal
33+
from jaxtyping import Float, Int
34+
3135
from ..manifolds import Manifold, ProductManifold
3236
from ..optimizers.radan import RiemannianAdan
3337

manify/curvature_estimation/delta_hyperbolicity.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,12 @@
1111

1212
from __future__ import annotations
1313

14+
from typing import TYPE_CHECKING
15+
1416
import torch
15-
from jaxtyping import Float, Int
17+
18+
if TYPE_CHECKING:
19+
from jaxtyping import Float, Int
1620

1721

1822
def sampled_delta_hyperbolicity(

manify/curvature_estimation/greedy_method.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,12 @@
66

77
from __future__ import annotations
88

9+
from typing import TYPE_CHECKING
10+
911
import torch
10-
from jaxtyping import Float
12+
13+
if TYPE_CHECKING:
14+
from jaxtyping import Float
1115

1216
from ..manifolds import ProductManifold
1317

manify/embedders/_base.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@
33
from __future__ import annotations
44

55
from abc import ABC, abstractmethod
6+
from typing import TYPE_CHECKING
67

78
import torch
8-
from beartype.typing import Any
9-
from jaxtyping import Float
109
from sklearn.base import BaseEstimator, TransformerMixin
1110

11+
if TYPE_CHECKING:
12+
from beartype.typing import Any
13+
from jaxtyping import Float
14+
1215
from ..manifolds import ProductManifold
1316

1417

manify/embedders/_losses.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,13 @@
77

88
from __future__ import annotations
99

10+
from typing import TYPE_CHECKING
11+
1012
import networkx as nx
1113
import torch
12-
from jaxtyping import Float
14+
15+
if TYPE_CHECKING:
16+
from jaxtyping import Float
1317

1418
from ..manifolds import ProductManifold
1519

manify/embedders/coordinate_learning.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,15 @@
1111

1212
import sys
1313
import warnings
14+
from typing import TYPE_CHECKING
1415

1516
import geoopt
1617
import numpy as np
1718
import torch
18-
from beartype.typing import Any
19-
from jaxtyping import Float, Int
19+
20+
if TYPE_CHECKING:
21+
from beartype.typing import Any
22+
from jaxtyping import Float, Int
2023

2124
from ..manifolds import ProductManifold
2225
from ._base import BaseEmbedder

manify/embedders/siamese.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,13 @@
1111
from __future__ import annotations
1212

1313
import sys
14+
from typing import TYPE_CHECKING
1415

1516
import numpy as np
1617
import torch
17-
from jaxtyping import Float
18+
19+
if TYPE_CHECKING:
20+
from jaxtyping import Float
1821

1922
from ..manifolds import ProductManifold
2023
from ._base import BaseEmbedder

manify/embedders/vae.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,13 @@
1111
from __future__ import annotations
1212

1313
import sys
14+
from typing import TYPE_CHECKING
1415

1516
import numpy as np
1617
import torch
17-
from jaxtyping import Float
18+
19+
if TYPE_CHECKING:
20+
from jaxtyping import Float
1821

1922
from ..manifolds import ProductManifold
2023
from ._base import BaseEmbedder

manify/manifolds.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,14 @@
1010
from __future__ import annotations
1111

1212
import warnings
13+
from typing import TYPE_CHECKING
1314

1415
import geoopt
1516
import torch
16-
from beartype.typing import Callable, Literal
17-
from jaxtyping import Float, Real
17+
18+
if TYPE_CHECKING:
19+
from beartype.typing import Callable, Literal
20+
from jaxtyping import Float, Real
1821

1922
warnings.filterwarnings("ignore", category=UserWarning, module="torch.distributions") # Singular samples from Wishart
2023

0 commit comments

Comments
 (0)