Skip to content

Commit 602e859

Browse files
committed
add import of ml-test-functions
1 parent 182a255 commit 602e859

File tree

1 file changed

+31
-1
lines changed
  • src/surfaces/test_functions/machine_learning

1 file changed

+31
-1
lines changed

src/surfaces/test_functions/machine_learning/__init__.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def _check_sklearn():
3636
SVMRegressorFunction,
3737
)
3838

39-
# Time-series functions
39+
# Time-series functions (sklearn-based)
4040
from .timeseries import (
4141
# Forecasting
4242
GradientBoostingForecasterFunction,
@@ -100,6 +100,26 @@ def _check_sklearn():
100100
RandomForestImageClassifierFunction,
101101
]
102102

103+
# sktime-based time-series functions (require sktime)
104+
try:
105+
from .timeseries import ExpSmoothingForecasterFunction, TSForestClassifierFunction
106+
107+
__all__.extend(
108+
[
109+
"ExpSmoothingForecasterFunction",
110+
"TSForestClassifierFunction",
111+
]
112+
)
113+
machine_learning_functions.extend(
114+
[
115+
ExpSmoothingForecasterFunction,
116+
TSForestClassifierFunction,
117+
]
118+
)
119+
_HAS_SKTIME = True
120+
except ImportError:
121+
_HAS_SKTIME = False
122+
103123
# CNN image classifiers (require tensorflow)
104124
try:
105125
from .image import SimpleCNNClassifierFunction, DeepCNNClassifierFunction
@@ -120,6 +140,16 @@ def _check_sklearn():
120140
except ImportError:
121141
_HAS_TENSORFLOW = False
122142

143+
# XGBoost image classifier (requires xgboost)
144+
try:
145+
from .image import XGBoostImageClassifierFunction
146+
147+
__all__.append("XGBoostImageClassifierFunction")
148+
machine_learning_functions.append(XGBoostImageClassifierFunction)
149+
_HAS_XGBOOST = True
150+
except ImportError:
151+
_HAS_XGBOOST = False
152+
123153
else:
124154
__all__ = []
125155
machine_learning_functions = []

0 commit comments

Comments
 (0)