Skip to content

Commit 88bfc1f

Browse files
committed
add new test-functions to init
1 parent 8c46c58 commit 88bfc1f

File tree

1 file changed

+60
-4
lines changed
  • src/surfaces/test_functions/machine_learning

1 file changed

+60
-4
lines changed

src/surfaces/test_functions/machine_learning/__init__.py

Lines changed: 60 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def _check_sklearn():
2020

2121

2222
if _HAS_SKLEARN:
23+
# Tabular functions
2324
from .tabular import (
2425
# Classification
2526
DecisionTreeClassifierFunction,
@@ -35,35 +36,90 @@ def _check_sklearn():
3536
SVMRegressorFunction,
3637
)
3738

38-
__all__ = [
39+
# Time-series functions
40+
from .timeseries import (
41+
# Forecasting
42+
GradientBoostingForecasterFunction,
43+
RandomForestForecasterFunction,
3944
# Classification
45+
RandomForestTSClassifierFunction,
46+
KNNTSClassifierFunction,
47+
)
48+
49+
# Image functions (sklearn-based)
50+
from .image import (
51+
SVMImageClassifierFunction,
52+
RandomForestImageClassifierFunction,
53+
)
54+
55+
__all__ = [
56+
# Tabular - Classification
4057
"DecisionTreeClassifierFunction",
4158
"GradientBoostingClassifierFunction",
4259
"KNeighborsClassifierFunction",
4360
"RandomForestClassifierFunction",
4461
"SVMClassifierFunction",
45-
# Regression
62+
# Tabular - Regression
4663
"DecisionTreeRegressorFunction",
4764
"GradientBoostingRegressorFunction",
4865
"KNeighborsRegressorFunction",
4966
"RandomForestRegressorFunction",
5067
"SVMRegressorFunction",
68+
# Time-series - Forecasting
69+
"GradientBoostingForecasterFunction",
70+
"RandomForestForecasterFunction",
71+
# Time-series - Classification
72+
"RandomForestTSClassifierFunction",
73+
"KNNTSClassifierFunction",
74+
# Image - Classification (sklearn)
75+
"SVMImageClassifierFunction",
76+
"RandomForestImageClassifierFunction",
5177
]
5278

5379
machine_learning_functions = [
54-
# Classification
80+
# Tabular - Classification
5581
DecisionTreeClassifierFunction,
5682
GradientBoostingClassifierFunction,
5783
KNeighborsClassifierFunction,
5884
RandomForestClassifierFunction,
5985
SVMClassifierFunction,
60-
# Regression
86+
# Tabular - Regression
6187
DecisionTreeRegressorFunction,
6288
GradientBoostingRegressorFunction,
6389
KNeighborsRegressorFunction,
6490
RandomForestRegressorFunction,
6591
SVMRegressorFunction,
92+
# Time-series - Forecasting
93+
GradientBoostingForecasterFunction,
94+
RandomForestForecasterFunction,
95+
# Time-series - Classification
96+
RandomForestTSClassifierFunction,
97+
KNNTSClassifierFunction,
98+
# Image - Classification (sklearn)
99+
SVMImageClassifierFunction,
100+
RandomForestImageClassifierFunction,
66101
]
102+
103+
# CNN image classifiers (require tensorflow)
104+
try:
105+
from .image import SimpleCNNClassifierFunction, DeepCNNClassifierFunction
106+
107+
__all__.extend(
108+
[
109+
"SimpleCNNClassifierFunction",
110+
"DeepCNNClassifierFunction",
111+
]
112+
)
113+
machine_learning_functions.extend(
114+
[
115+
SimpleCNNClassifierFunction,
116+
DeepCNNClassifierFunction,
117+
]
118+
)
119+
_HAS_TENSORFLOW = True
120+
except ImportError:
121+
_HAS_TENSORFLOW = False
122+
67123
else:
68124
__all__ = []
69125
machine_learning_functions = []

0 commit comments

Comments
 (0)