88from typing import TYPE_CHECKING , List , Optional , Tuple
99
1010import torch
11- from sklearn .base import BaseEstimator , ClassifierMixin
1211
1312if TYPE_CHECKING :
1413 from beartype .typing import Any , Literal
1514 from jaxtyping import Bool , Float , Int , Real
1615
1716from ..manifolds import ProductManifold
17+ from ._base import BasePredictor
1818from ._midpoint import midpoint
1919
2020
@@ -259,7 +259,7 @@ def __init__(
259259 self .right = right
260260
261261
262- class ProductSpaceDT (BaseEstimator , ClassifierMixin ):
262+ class ProductSpaceDT (BasePredictor ):
263263 """Decision tree in the product space to handle hyperbolic, euclidean, and hyperspherical data."""
264264
265265 def __init__ (
@@ -274,7 +274,12 @@ def __init__(
274274 batch_size : int | None = None ,
275275 n_features : Literal ["d" , "d_choose_2" ] = "d" ,
276276 ablate_midpoints : bool = False ,
277+ random_state : int | None = None ,
278+ device : str | None = None ,
277279 ):
280+ # Initialize the base class
281+ super ().__init__ (pm = pm , task = task , random_state = random_state , device = device )
282+
278283 # Raise error if manifold is stereographic
279284 if pm .is_stereographic :
280285 raise ValueError ("Stereographic manifolds are not supported. Use a different representation." )
@@ -637,7 +642,7 @@ def score(
637642 return ((self .predict (X ) - y ) ** 2 * sample_weight ).mean ()
638643
639644
640- class ProductSpaceRF (BaseEstimator , ClassifierMixin ):
645+ class ProductSpaceRF (BasePredictor ):
641646 """Random Forest in the product space."""
642647
643648 def __init__ (
@@ -657,7 +662,19 @@ def __init__(
657662 batch_size : int | None = None ,
658663 random_state : int | None = None ,
659664 n_jobs : int = - 1 ,
665+ device : str | None = None ,
660666 ):
667+ # Initialize the base class
668+ super ().__init__ (pm = pm , task = task , random_state = random_state , device = device )
669+
670+ # Raise error if manifold is stereographic
671+ if pm .is_stereographic :
672+ raise ValueError ("Stereographic manifolds are not supported. Use a different representation." )
673+ if task == "link_prediction" :
674+ raise ValueError (
675+ "Link prediction is not supported for decision trees. Please use utils.link_prediction to reframe as classification"
676+ )
677+
661678 # Tree hyperparameters
662679 tree_kwargs : Dict [str , Any ] = {}
663680 self .pm = tree_kwargs ["pm" ] = pm
0 commit comments