Skip to content

Commit 3ad523f

Browse files
author
Beat Buesser
committed
Move JaxClassifier to experimental
Signed-off-by: Beat Buesser <[email protected]>
1 parent 399f518 commit 3ad523f

File tree

8 files changed

+11
-5
lines changed

8 files changed

+11
-5
lines changed

art/estimators/classification/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from art.estimators.classification.ensemble import EnsembleClassifier
1515
from art.estimators.classification.GPy import GPyGaussianProcessClassifier
1616
from art.estimators.classification.keras import KerasClassifier
17-
from art.estimators.classification.jax import JaxClassifier
1817
from art.estimators.classification.lightgbm import LightGBMClassifier
1918
from art.estimators.classification.mxnet import MXClassifier
2019
from art.estimators.classification.pytorch import PyTorchClassifier

art/experimental/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
"""
2+
This module contains the experimental Estimator API.
3+
"""
4+
from art.experimental.estimators.jax import JaxEstimator

art/experimental/estimators/__init__.py

Whitespace-only changes.
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
"""
2+
Experimental classifiers.
3+
"""
4+
from art.experimental.estimators.classification.jax import JaxClassifier

art/estimators/classification/jax.py renamed to art/experimental/estimators/classification/jax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
ClassGradientsMixin,
3131
ClassifierMixin,
3232
)
33-
from art.estimators.jax import JaxEstimator
33+
from art.experimental.estimators.jax import JaxEstimator
3434

3535
if TYPE_CHECKING:
3636
from art.utils import CLIP_VALUES_TYPE, PREPROCESSING_TYPE
File renamed without changes.

art/utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@
7777
from art.estimators.classification.ensemble import EnsembleClassifier
7878
from art.estimators.classification.GPy import GPyGaussianProcessClassifier
7979
from art.estimators.classification.keras import KerasClassifier
80-
from art.estimators.classification.jax import JaxClassifier
80+
from art.experimental.estimators.classification.jax import JaxClassifier
8181
from art.estimators.classification.lightgbm import LightGBMClassifier
8282
from art.estimators.classification.mxnet import MXClassifier
8383
from art.estimators.classification.pytorch import PyTorchClassifier
@@ -982,7 +982,6 @@ def load_nursery(raw: bool = False, test_set: float = 0.2, transform_social: boo
982982
:return: Entire dataset and labels.
983983
"""
984984
import pandas as pd
985-
import sklearn.model_selection
986985
import sklearn.preprocessing
987986

988987
# Download data if needed

tests/estimators/classification/test_jax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import jax.numpy as jnp
88
from jax.scipy.special import logsumexp
99

10-
from art.estimators.classification.jax import JaxClassifier
10+
from art.experimental.estimators.classification.jax import JaxClassifier
1111
from tests.utils import ARTTestException
1212

1313

0 commit comments

Comments
 (0)