Skip to content

Commit c3b3351

Browse files
committed
Fix jaxtyping compatibiilty; isort benchmarks
1 parent 9a64177 commit c3b3351

File tree

4 files changed

+20
-10
lines changed

4 files changed

+20
-10
lines changed

manify/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
import manify.manifolds
44
import manify.predictors
55

6+
# Define version
7+
__version__ = "0.0.2"
8+
69
# Dynamically check for utils dependencies
710
try:
811
import importlib.util

manify/manifolds.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
from __future__ import annotations
1212

1313
import warnings
14-
from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union
14+
from typing import (TYPE_CHECKING, Callable, List, Literal, Optional, Tuple,
15+
Union)
1516

1617
import geoopt
1718
import torch

manify/utils/benchmarks.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,19 @@
99
from sklearn.base import BaseEstimator
1010
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
1111
from sklearn.linear_model import SGDClassifier, SGDRegressor
12-
from sklearn.metrics import (
13-
accuracy_score,
14-
f1_score,
15-
mean_squared_error,
16-
root_mean_squared_error,
17-
)
12+
from sklearn.metrics import (accuracy_score, f1_score, mean_squared_error,
13+
root_mean_squared_error)
1814
from sklearn.model_selection import train_test_split
1915
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
2016
from sklearn.svm import SVC, SVR
2117
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
2218

2319
from ..manifolds import ProductManifold
20+
from ..predictors.decision_tree import ProductSpaceDT, ProductSpaceRF
2421
from ..predictors.kappa_gcn import KappaGCN, get_A_hat
2522
from ..predictors.perceptron import ProductSpacePerceptron
2623
from ..predictors.svm import ProductSpaceSVM
2724

28-
from ..predictors.decision_tree import ProductSpaceDT, ProductSpaceRF
29-
3025

3126
def _score(
3227
_X: Float[torch.Tensor, "n_samples n_dims"],

pyproject.toml

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ dependencies = [
2323
[project.optional-dependencies]
2424
dev = [
2525
"jaxtyping",
26+
"beartype",
2627
"pytest",
2728
"mypy",
2829
"pytest-cov",
@@ -70,4 +71,14 @@ generated-members = "torch.*,nn.*"
7071
[tool.pytest.ini_options]
7172
testpaths = ["tests"]
7273
addopts = "--jaxtyping-packages=beartype.beartype"
73-
python_files = "test_*.py"
74+
python_files = "test_*.py"
75+
76+
[tool.mypy]
77+
# Specify the packages to check
78+
packages = ["manify"]
79+
80+
# Ignore missing imports for external libraries
81+
ignore_missing_imports = true
82+
83+
# Use jaxtyping-specific plugins
84+
plugins = ["jaxtyping.mypy_plugin"]

0 commit comments

Comments
 (0)