Skip to content

Commit cad5862

Browse files
committed
Update scikit-learn unit test models for scikit-learn>=1.3.0
Signed-off-by: Beat Buesser <[email protected]>
1 parent e0ed2eb commit cad5862

22 files changed

+28
-14
lines changed

tests/utils.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import unittest
2929
import warnings
3030

31+
import sklearn
3132
import numpy as np
3233

3334
from art.estimators.classification.tensorflow import TensorFlowV2Classifier
@@ -1746,6 +1747,13 @@ def get_tabular_classifier_scikit_list(clipped=False, model_list_names=None):
17461747
ScikitlearnSVC,
17471748
)
17481749

1750+
sklearn_version = list(map(int, sklearn.__version__.split(".")))
1751+
sklearn_ge_1_3_0 = sklearn_version[0] == 1 and sklearn_version[1] >= 3
1752+
if sklearn_ge_1_3_0:
1753+
suffix = "-ge-1.3.0"
1754+
else:
1755+
suffix = ""
1756+
17491757
available_models = {
17501758
"decisionTreeClassifier": ScikitlearnDecisionTreeClassifier,
17511759
# "extraTreeClassifier": ScikitlearnExtraTreeClassifier,
@@ -1775,7 +1783,7 @@ def get_tabular_classifier_scikit_list(clipped=False, model_list_names=None):
17751783
os.path.join(
17761784
os.path.dirname(os.path.dirname(__file__)),
17771785
"utils/resources/models/scikit/",
1778-
"scikit-" + model_name + "-iris-clipped.pickle",
1786+
"scikit-" + model_name + "-iris-clipped" + suffix + ".pickle",
17791787
),
17801788
"rb",
17811789
)
@@ -1788,7 +1796,7 @@ def get_tabular_classifier_scikit_list(clipped=False, model_list_names=None):
17881796
os.path.join(
17891797
os.path.dirname(os.path.dirname(__file__)),
17901798
"utils/resources/models/scikit/",
1791-
"scikit-" + model_name + "-iris-unclipped.pickle",
1799+
"scikit-" + model_name + "-iris-unclipped" + suffix + ".pickle",
17921800
),
17931801
"rb",
17941802
)

utils/resources/create_model_weights.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@
1515
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
1616
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
1717
# SOFTWARE.
18-
import tensorflow as tf
1918
import os
2019
import pickle
20+
21+
import tensorflow as tf
2122
from tensorflow.keras.models import Sequential
2223
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D
24+
import sklearn
2325
from sklearn.linear_model import LogisticRegression
24-
from art.estimators.classification import SklearnClassifier
2526
from sklearn.svm import SVC, LinearSVC
2627
from sklearn.tree import DecisionTreeClassifier, ExtraTreeClassifier
2728
from sklearn.ensemble import AdaBoostClassifier, BaggingClassifier, ExtraTreesClassifier
@@ -140,42 +141,47 @@ def create_scikit_model_weights():
140141
}
141142

142143
clipped_models = {
143-
model_name: SklearnClassifier(model=model, clip_values=(0, 1)) for model_name, model in model_list.items()
144+
model_name: model for model_name, model in model_list.items()
144145
}
145-
unclipped_models = {model_name: SklearnClassifier(model=model) for model_name, model in model_list.items()}
146+
unclipped_models = {model_name: model for model_name, model in model_list.items()}
146147

147148
(x_train_iris, y_train_iris), (_, _), _, _ = load_dataset("iris")
148149

150+
y_train_iris = np.argmax(y_train_iris, axis=1)
151+
152+
print(sklearn.__version__)
153+
dfg
154+
149155
for model_name, model in clipped_models.items():
150-
model.fit(x=x_train_iris, y=y_train_iris)
156+
model.fit(X=x_train_iris, y=y_train_iris)
151157
pickle.dump(
152158
model,
153159
open(
154160
os.path.join(
155161
os.path.dirname(os.path.dirname(__file__)),
156-
"utils/resources/models/scikit/",
157-
model_name + "iris_clipped.sav",
162+
"resources/models/scikit/",
163+
"scikit-" + model_name + "-iris-clipped-ge-1.3.0.pickle",
158164
),
159165
"wb",
160166
),
161167
)
162168

163169
for model_name, model in unclipped_models.items():
164-
model.fit(x=x_train_iris, y=y_train_iris)
170+
model.fit(X=x_train_iris, y=y_train_iris)
165171
pickle.dump(
166172
model,
167173
open(
168174
os.path.join(
169175
os.path.dirname(os.path.dirname(__file__)),
170-
"utils/resources/models/scikit/",
171-
model_name + "iris_unclipped.sav",
176+
"resources/models/scikit/",
177+
"scikit-" + model_name + "-iris-unclipped-ge-1.3.0.pickle",
172178
),
173179
"wb",
174180
),
175181
)
176182

177183

178184
if __name__ == "__main__":
179-
main_mnist_binary()
185+
# main_mnist_binary()
180186
create_scikit_model_weights()
181-
main_diabetes()
187+
# main_diabetes()
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)