Skip to content

Commit cad5862

Browse files
committed
Update scikit-learn unit test models for scikit-learn>=1.3.0
Signed-off-by: Beat Buesser <[email protected]>
1 parent e0ed2eb commit cad5862

22 files changed

+28
-14
lines changed

tests/utils.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import unittest
2929
import warnings
3030

31+
import sklearn
3132
import numpy as np
3233

3334
from art.estimators.classification.tensorflow import TensorFlowV2Classifier
@@ -1746,6 +1747,13 @@ def get_tabular_classifier_scikit_list(clipped=False, model_list_names=None):
17461747
ScikitlearnSVC,
17471748
)
17481749

1750+
sklearn_version = list(map(int, sklearn.__version__.split(".")))
1751+
sklearn_ge_1_3_0 = sklearn_version[0] == 1 and sklearn_version[1] >= 3
1752+
if sklearn_ge_1_3_0:
1753+
suffix = "-ge-1.3.0"
1754+
else:
1755+
suffix = ""
1756+
17491757
available_models = {
17501758
"decisionTreeClassifier": ScikitlearnDecisionTreeClassifier,
17511759
# "extraTreeClassifier": ScikitlearnExtraTreeClassifier,
@@ -1775,7 +1783,7 @@ def get_tabular_classifier_scikit_list(clipped=False, model_list_names=None):
17751783
os.path.join(
17761784
os.path.dirname(os.path.dirname(__file__)),
17771785
"utils/resources/models/scikit/",
1778-
"scikit-" + model_name + "-iris-clipped.pickle",
1786+
"scikit-" + model_name + "-iris-clipped" + suffix + ".pickle",
17791787
),
17801788
"rb",
17811789
)
@@ -1788,7 +1796,7 @@ def get_tabular_classifier_scikit_list(clipped=False, model_list_names=None):
17881796
os.path.join(
17891797
os.path.dirname(os.path.dirname(__file__)),
17901798
"utils/resources/models/scikit/",
1791-
"scikit-" + model_name + "-iris-unclipped.pickle",
1799+
"scikit-" + model_name + "-iris-unclipped" + suffix + ".pickle",
17921800
),
17931801
"rb",
17941802
)

utils/resources/create_model_weights.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@
1515
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
1616
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
1717
# SOFTWARE.
18-
import tensorflow as tf
1918
import os
2019
import pickle
20+
21+
import tensorflow as tf
2122
from tensorflow.keras.models import Sequential
2223
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D
24+
import sklearn
2325
from sklearn.linear_model import LogisticRegression
24-
from art.estimators.classification import SklearnClassifier
2526
from sklearn.svm import SVC, LinearSVC
2627
from sklearn.tree import DecisionTreeClassifier, ExtraTreeClassifier
2728
from sklearn.ensemble import AdaBoostClassifier, BaggingClassifier, ExtraTreesClassifier
@@ -140,42 +141,47 @@ def create_scikit_model_weights():
140141
}
141142

142143
clipped_models = {
143-
model_name: SklearnClassifier(model=model, clip_values=(0, 1)) for model_name, model in model_list.items()
144+
model_name: model for model_name, model in model_list.items()
144145
}
145-
unclipped_models = {model_name: SklearnClassifier(model=model) for model_name, model in model_list.items()}
146+
unclipped_models = {model_name: model for model_name, model in model_list.items()}
146147

147148
(x_train_iris, y_train_iris), (_, _), _, _ = load_dataset("iris")
148149

150+
y_train_iris = np.argmax(y_train_iris, axis=1)
151+
152+
print(sklearn.__version__)
153+
dfg
154+
149155
for model_name, model in clipped_models.items():
150-
model.fit(x=x_train_iris, y=y_train_iris)
156+
model.fit(X=x_train_iris, y=y_train_iris)
151157
pickle.dump(
152158
model,
153159
open(
154160
os.path.join(
155161
os.path.dirname(os.path.dirname(__file__)),
156-
"utils/resources/models/scikit/",
157-
model_name + "iris_clipped.sav",
162+
"resources/models/scikit/",
163+
"scikit-" + model_name + "-iris-clipped-ge-1.3.0.pickle",
158164
),
159165
"wb",
160166
),
161167
)
162168

163169
for model_name, model in unclipped_models.items():
164-
model.fit(x=x_train_iris, y=y_train_iris)
170+
model.fit(X=x_train_iris, y=y_train_iris)
165171
pickle.dump(
166172
model,
167173
open(
168174
os.path.join(
169175
os.path.dirname(os.path.dirname(__file__)),
170-
"utils/resources/models/scikit/",
171-
model_name + "iris_unclipped.sav",
176+
"resources/models/scikit/",
177+
"scikit-" + model_name + "-iris-unclipped-ge-1.3.0.pickle",
172178
),
173179
"wb",
174180
),
175181
)
176182

177183

178184
if __name__ == "__main__":
179-
main_mnist_binary()
185+
# main_mnist_binary()
180186
create_scikit_model_weights()
181-
main_diabetes()
187+
# main_diabetes()

0 commit comments

Comments
 (0)