Skip to content

Commit 15c66bc

Browse files
committed
Move type alias inside TYPE_CHECKING check
1 parent 365d06b commit 15c66bc

File tree

1 file changed

+26
-25
lines changed

1 file changed

+26
-25
lines changed

manify/utils/benchmarks.py

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,32 @@
1212
if TYPE_CHECKING:
1313
from beartype.typing import Literal, TypeAlias
1414
from jaxtyping import Float, Real
15+
16+
MODELTYPE: TypeAlias = Literal[
17+
"sklearn_dt",
18+
"sklearn_rf",
19+
"product_dt",
20+
"product_rf",
21+
"tangent_dt",
22+
"tangent_rf",
23+
"knn",
24+
"ps_perceptron",
25+
"svm",
26+
"ps_svm",
27+
"kappa_mlp",
28+
"tangent_mlp",
29+
"ambient_mlp",
30+
"tangent_gcn",
31+
"ambient_gcn",
32+
"kappa_gcn",
33+
"ambient_mlr",
34+
"tangent_mlr",
35+
"kappa_mlr",
36+
"single_manifold_rf",
37+
]
38+
SCORETYPE: TypeAlias = Literal["accuracy", "f1-micro", "f1-macro", "mse", "percent_rmse", "time"]
39+
TASKTYPE: TypeAlias = Literal["classification", "regression", "link_prediction"]
40+
1541
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
1642
from sklearn.linear_model import SGDClassifier, SGDRegressor
1743
from sklearn.metrics import accuracy_score, f1_score, mean_squared_error, root_mean_squared_error
@@ -26,31 +52,6 @@
2652
from ..predictors.perceptron import ProductSpacePerceptron
2753
from ..predictors.svm import ProductSpaceSVM
2854

29-
MODELTYPE: TypeAlias = Literal[
30-
"sklearn_dt",
31-
"sklearn_rf",
32-
"product_dt",
33-
"product_rf",
34-
"tangent_dt",
35-
"tangent_rf",
36-
"knn",
37-
"ps_perceptron",
38-
"svm",
39-
"ps_svm",
40-
"kappa_mlp",
41-
"tangent_mlp",
42-
"ambient_mlp",
43-
"tangent_gcn",
44-
"ambient_gcn",
45-
"kappa_gcn",
46-
"ambient_mlr",
47-
"tangent_mlr",
48-
"kappa_mlr",
49-
"single_manifold_rf",
50-
]
51-
SCORETYPE: TypeAlias = Literal["accuracy", "f1-micro", "f1-macro", "mse", "percent_rmse", "time"]
52-
TASKTYPE: TypeAlias = Literal["classification", "regression", "link_prediction"]
53-
5455

5556
def _score(
5657
_X: Float[torch.Tensor | np.ndarray, "batch dim"] | None,

0 commit comments

Comments
 (0)