Skip to content

Commit 1a780b6

Browse files
committed
Basic refactor to make predictors inherit from BasePredictor class
1 parent a275456 commit 1a780b6

File tree

5 files changed

+51
-10
lines changed

5 files changed

+51
-10
lines changed

manify/predictors/_base.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import TYPE_CHECKING
77

88
import torch
9-
from sklearn.base import BaseEstimator
9+
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
1010

1111
if TYPE_CHECKING:
1212
from beartype.typing import Literal
@@ -34,17 +34,28 @@ class BasePredictor(BaseEstimator, ABC):
3434
def __init__(
3535
self,
3636
pm: ProductManifold,
37-
task: Literal["classification", "regression"],
37+
task: Literal["classification", "regression", "link_prediction"],
3838
random_state: int | None = None,
3939
device: str | None = None,
4040
) -> None:
4141
self.pm = pm
4242
self.task = task
4343
self.random_state = random_state
4444
self.device = pm.device if device is None else device
45-
self.loss_history_: Dict[str, List[float]] = {}
45+
self.loss_history_: dict[str, list[float]] = {}
4646
self.is_fitted_: bool = False
4747

48+
# Initialize appropriate base class depending on task
49+
if task == "classification":
50+
ClassifierMixin.__init__(self)
51+
elif task == "regression":
52+
RegressorMixin.__init__(self)
53+
elif task == "link_prediction":
54+
# For link prediction, we also use ClassifierMixin, as we think of it as binary classificaiton.
55+
ClassifierMixin.__init__(self)
56+
else:
57+
raise ValueError(f"Unknown task type: {task}")
58+
4859
@abstractmethod
4960
def fit(
5061
self, X: Float[torch.Tensor, "n_points n_features"], y: Float[torch.Tensor, "n_points n_classes"]

manify/predictors/decision_tree.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88
from typing import TYPE_CHECKING, List, Optional, Tuple
99

1010
import torch
11-
from sklearn.base import BaseEstimator, ClassifierMixin
1211

1312
if TYPE_CHECKING:
1413
from beartype.typing import Any, Literal
1514
from jaxtyping import Bool, Float, Int, Real
1615

1716
from ..manifolds import ProductManifold
17+
from ._base import BasePredictor
1818
from ._midpoint import midpoint
1919

2020

@@ -259,7 +259,7 @@ def __init__(
259259
self.right = right
260260

261261

262-
class ProductSpaceDT(BaseEstimator, ClassifierMixin):
262+
class ProductSpaceDT(BasePredictor):
263263
"""Decision tree in the product space to handle hyperbolic, euclidean, and hyperspherical data."""
264264

265265
def __init__(
@@ -274,7 +274,12 @@ def __init__(
274274
batch_size: int | None = None,
275275
n_features: Literal["d", "d_choose_2"] = "d",
276276
ablate_midpoints: bool = False,
277+
random_state: int | None = None,
278+
device: str | None = None,
277279
):
280+
# Initialize the base class
281+
super().__init__(pm=pm, task=task, random_state=random_state, device=device)
282+
278283
# Raise error if manifold is stereographic
279284
if pm.is_stereographic:
280285
raise ValueError("Stereographic manifolds are not supported. Use a different representation.")
@@ -637,7 +642,7 @@ def score(
637642
return ((self.predict(X) - y) ** 2 * sample_weight).mean()
638643

639644

640-
class ProductSpaceRF(BaseEstimator, ClassifierMixin):
645+
class ProductSpaceRF(BasePredictor):
641646
"""Random Forest in the product space."""
642647

643648
def __init__(
@@ -657,7 +662,19 @@ def __init__(
657662
batch_size: int | None = None,
658663
random_state: int | None = None,
659664
n_jobs: int = -1,
665+
device: str | None = None,
660666
):
667+
# Initialize the base class
668+
super().__init__(pm=pm, task=task, random_state=random_state, device=device)
669+
670+
# Raise error if manifold is stereographic
671+
if pm.is_stereographic:
672+
raise ValueError("Stereographic manifolds are not supported. Use a different representation.")
673+
if task == "link_prediction":
674+
raise ValueError(
675+
"Link prediction is not supported for decision trees. Please use utils.link_prediction to reframe as classification"
676+
)
677+
661678
# Tree hyperparameters
662679
tree_kwargs: Dict[str, Any] = {}
663680
self.pm = tree_kwargs["pm"] = pm

manify/predictors/kappa_gcn.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from beartype.typing import Callable, Literal
1313
from jaxtyping import Float, Real
1414

15+
from ._base import BasePredictor
1516
from ..manifolds import Manifold, ProductManifold
1617

1718
# TQDM: notebook or regular
@@ -142,7 +143,7 @@ def forward(
142143
return AXW
143144

144145

145-
class KappaGCN(torch.nn.Module):
146+
class KappaGCN(torch.nn.Module, BasePredictor):
146147
"""Implementation for the Kappa GCN.
147148
148149
Parameters
@@ -159,8 +160,11 @@ def __init__(
159160
hidden_dims: list[int] | None = None,
160161
nonlinearity: Callable = torch.relu,
161162
task: Literal["classification", "regression", "link_prediction"] = "classification",
163+
random_state: int | None = None,
164+
device: str | None = None,
162165
):
163-
super().__init__()
166+
torch.nn.Module.__init__(self)
167+
BasePredictor.__init__(self, pm=pm, task=task, random_state=random_state, device=device)
164168
self.pm = pm
165169
self.task = task
166170

manify/predictors/perceptron.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,16 @@
55
from typing import TYPE_CHECKING
66

77
import torch
8-
from sklearn.base import BaseEstimator, ClassifierMixin
98

109
if TYPE_CHECKING:
1110
from jaxtyping import Float, Int
1211

1312
from ..manifolds import ProductManifold
13+
from ._base import BasePredictor
1414
from ._kernel import product_kernel
1515

1616

17-
class ProductSpacePerceptron(BaseEstimator, ClassifierMixin):
17+
class ProductSpacePerceptron(BasePredictor):
1818
"""A product-space perceptron model for multiclass classification in the product manifold space."""
1919

2020
def __init__(
@@ -23,7 +23,12 @@ def __init__(
2323
max_epochs: int = 1_000,
2424
patience: int = 5,
2525
weights: Float[torch.Tensor, "n_manifolds"] | None = None,
26+
task: str = "classification",
27+
random_state: int | None = None,
28+
device: str | None = None,
2629
):
30+
# Initialize base class
31+
super().__init__(pm, task=task, random_state=random_state, device=device)
2732
self.pm = pm # ProductManifold instance
2833
self.max_epochs = max_epochs
2934
self.patience = patience # Number of consecutive epochs without improvement to consider convergence

manify/predictors/svm.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from jaxtyping import Float, Int
1515

1616
from ..manifolds import ProductManifold
17+
from ._base import BasePredictor
1718
from ._kernel import product_kernel
1819

1920

@@ -30,6 +31,9 @@ def __init__(
3031
task: Literal["classification", "regression"] = "classification",
3132
epsilon: float = 1e-5,
3233
):
34+
# Initialize base class
35+
super().__init__(pm, task=task, random_state=random_state, device=device)
36+
3337
self.pm = pm
3438
self.h_constraints = h_constraints
3539
self.s_constraints = s_constraints

0 commit comments

Comments
 (0)