|
15 | 15 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, |
16 | 16 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
17 | 17 | # SOFTWARE. |
18 | | -import tensorflow as tf |
19 | 18 | import os |
20 | 19 | import pickle |
| 20 | + |
| 21 | +import tensorflow as tf |
21 | 22 | from tensorflow.keras.models import Sequential |
22 | 23 | from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D |
| 24 | +import sklearn |
23 | 25 | from sklearn.linear_model import LogisticRegression |
24 | | -from art.estimators.classification import SklearnClassifier |
25 | 26 | from sklearn.svm import SVC, LinearSVC |
26 | 27 | from sklearn.tree import DecisionTreeClassifier, ExtraTreeClassifier |
27 | 28 | from sklearn.ensemble import AdaBoostClassifier, BaggingClassifier, ExtraTreesClassifier |
@@ -140,42 +141,47 @@ def create_scikit_model_weights(): |
140 | 141 | } |
141 | 142 |
|
142 | 143 | clipped_models = { |
143 | | - model_name: SklearnClassifier(model=model, clip_values=(0, 1)) for model_name, model in model_list.items() |
| 144 | + model_name: model for model_name, model in model_list.items() |
144 | 145 | } |
145 | | - unclipped_models = {model_name: SklearnClassifier(model=model) for model_name, model in model_list.items()} |
| 146 | + unclipped_models = {model_name: model for model_name, model in model_list.items()} |
146 | 147 |
|
147 | 148 | (x_train_iris, y_train_iris), (_, _), _, _ = load_dataset("iris") |
148 | 149 |
|
| 150 | + y_train_iris = np.argmax(y_train_iris, axis=1) |
| 151 | + |
| 152 | + print(sklearn.__version__) |
| 153 | + dfg |
| 154 | + |
149 | 155 | for model_name, model in clipped_models.items(): |
150 | | - model.fit(x=x_train_iris, y=y_train_iris) |
| 156 | + model.fit(X=x_train_iris, y=y_train_iris) |
151 | 157 | pickle.dump( |
152 | 158 | model, |
153 | 159 | open( |
154 | 160 | os.path.join( |
155 | 161 | os.path.dirname(os.path.dirname(__file__)), |
156 | | - "utils/resources/models/scikit/", |
157 | | - model_name + "iris_clipped.sav", |
| 162 | + "resources/models/scikit/", |
| 163 | + "scikit-" + model_name + "-iris-clipped-ge-1.3.0.pickle", |
158 | 164 | ), |
159 | 165 | "wb", |
160 | 166 | ), |
161 | 167 | ) |
162 | 168 |
|
163 | 169 | for model_name, model in unclipped_models.items(): |
164 | | - model.fit(x=x_train_iris, y=y_train_iris) |
| 170 | + model.fit(X=x_train_iris, y=y_train_iris) |
165 | 171 | pickle.dump( |
166 | 172 | model, |
167 | 173 | open( |
168 | 174 | os.path.join( |
169 | 175 | os.path.dirname(os.path.dirname(__file__)), |
170 | | - "utils/resources/models/scikit/", |
171 | | - model_name + "iris_unclipped.sav", |
| 176 | + "resources/models/scikit/", |
| 177 | + "scikit-" + model_name + "-iris-unclipped-ge-1.3.0.pickle", |
172 | 178 | ), |
173 | 179 | "wb", |
174 | 180 | ), |
175 | 181 | ) |
176 | 182 |
|
177 | 183 |
|
178 | 184 | if __name__ == "__main__": |
179 | | - main_mnist_binary() |
| 185 | + # main_mnist_binary() |
180 | 186 | create_scikit_model_weights() |
181 | | - main_diabetes() |
| 187 | + # main_diabetes() |
0 commit comments