|
15 | 15 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, |
16 | 16 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
17 | 17 | # SOFTWARE. |
18 | | -import tensorflow as tf |
19 | 18 | import os |
20 | 19 | import pickle |
| 20 | + |
| 21 | +import tensorflow as tf |
21 | 22 | from tensorflow.keras.models import Sequential |
22 | 23 | from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D |
23 | 24 | from sklearn.linear_model import LogisticRegression |
24 | | -from art.estimators.classification import SklearnClassifier |
25 | 25 | from sklearn.svm import SVC, LinearSVC |
26 | 26 | from sklearn.tree import DecisionTreeClassifier, ExtraTreeClassifier |
27 | 27 | from sklearn.ensemble import AdaBoostClassifier, BaggingClassifier, ExtraTreesClassifier |
@@ -139,43 +139,43 @@ def create_scikit_model_weights(): |
139 | 139 | "linearSVC": LinearSVC(), |
140 | 140 | } |
141 | 141 |
|
142 | | - clipped_models = { |
143 | | - model_name: SklearnClassifier(model=model, clip_values=(0, 1)) for model_name, model in model_list.items() |
144 | | - } |
145 | | - unclipped_models = {model_name: SklearnClassifier(model=model) for model_name, model in model_list.items()} |
| 142 | + clipped_models = {model_name: model for model_name, model in model_list.items()} |
| 143 | + unclipped_models = {model_name: model for model_name, model in model_list.items()} |
146 | 144 |
|
147 | 145 | (x_train_iris, y_train_iris), (_, _), _, _ = load_dataset("iris") |
148 | 146 |
|
| 147 | + y_train_iris = np.argmax(y_train_iris, axis=1) |
| 148 | + |
149 | 149 | for model_name, model in clipped_models.items(): |
150 | | - model.fit(x=x_train_iris, y=y_train_iris) |
| 150 | + model.fit(X=x_train_iris, y=y_train_iris) |
151 | 151 | pickle.dump( |
152 | 152 | model, |
153 | 153 | open( |
154 | 154 | os.path.join( |
155 | 155 | os.path.dirname(os.path.dirname(__file__)), |
156 | | - "utils/resources/models/scikit/", |
157 | | - model_name + "iris_clipped.sav", |
| 156 | + "resources/models/scikit/", |
| 157 | + "scikit-" + model_name + "-iris-clipped-ge-1.3.0.pickle", |
158 | 158 | ), |
159 | 159 | "wb", |
160 | 160 | ), |
161 | 161 | ) |
162 | 162 |
|
163 | 163 | for model_name, model in unclipped_models.items(): |
164 | | - model.fit(x=x_train_iris, y=y_train_iris) |
| 164 | + model.fit(X=x_train_iris, y=y_train_iris) |
165 | 165 | pickle.dump( |
166 | 166 | model, |
167 | 167 | open( |
168 | 168 | os.path.join( |
169 | 169 | os.path.dirname(os.path.dirname(__file__)), |
170 | | - "utils/resources/models/scikit/", |
171 | | - model_name + "iris_unclipped.sav", |
| 170 | + "resources/models/scikit/", |
| 171 | + "scikit-" + model_name + "-iris-unclipped-ge-1.3.0.pickle", |
172 | 172 | ), |
173 | 173 | "wb", |
174 | 174 | ), |
175 | 175 | ) |
176 | 176 |
|
177 | 177 |
|
178 | 178 | if __name__ == "__main__": |
179 | | - main_mnist_binary() |
| 179 | + # main_mnist_binary() |
180 | 180 | create_scikit_model_weights() |
181 | | - main_diabetes() |
| 181 | + # main_diabetes() |
0 commit comments