|
46 | 46 |
|
47 | 47 | # TODO: add a better formatter for the logger. Eliminate date
|
48 | 48 |
|
49 |
| -def _create_mlp_model(input_dim: int, model_name: str) -> Model: |
50 |
| - |
51 |
| - # Define a small DNN with one hidden layer |
52 |
| - base_model = Sequential(name=model_name, layers=[ |
53 |
| - Dense(64, activation="relu", name="input_layer", input_shape=(input_dim,)), # FIXME: 28, 1, 1 to input_dim, |
54 |
| - Dense(64, activation="relu", name="hidden_layer"), |
55 |
| - Dense(1, activation="sigmoid", name="output_layer") |
56 |
| - ]) |
57 |
| - base_model.compile(optimizer="adam", |
58 |
| - loss="binary_crossentropy", |
59 |
| - metrics=["accuracy", |
60 |
| - Precision(name="precision"), |
61 |
| - Recall(name="recall"), |
62 |
| - AUC(name="auc")]) |
63 |
| - |
64 |
| - base_model.summary(print_fn=logger.info) |
65 |
| - return base_model |
66 |
| - |
67 |
| - |
68 |
| -def train_art_keras_classifier(x_train: Union[pd.DataFrame, np.ndarray] , y_train: Union[pd.DataFrame, np.ndarray], model_name: str) -> KerasClassifier: |
69 |
| - """Trains a KerasClassifier using the ART wrapper.""" |
70 |
| - |
71 |
| - # Create the Keras model |
72 |
| - mlp_model = _create_mlp_model(x_train.shape[1], model_name) |
73 |
| - |
74 |
| - # Create the ART KerasClassifier wrapper |
75 |
| - mlp_classifier = TensorFlowV2Classifier( |
76 |
| - model=mlp_model, |
77 |
| - loss_object=tf.keras.losses.CategoricalCrossentropy(from_logits=False), |
78 |
| - optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), |
79 |
| - nb_classes=2, # Ensure this is set correctly |
80 |
| - input_shape=(x_train.shape[1],) |
81 |
| - ) |
82 |
| - |
83 |
| - |
84 |
| - # Requires ndarrays, so the dataframes are transformed |
85 |
| - x_values = x_train.values if type(x_train) == pd.DataFrame else x_train |
86 |
| - y_values = np.squeeze(y_train.values) if type(y_train) == pd.DataFrame else y_train |
87 |
| - |
88 |
| - # Train the model |
89 |
| - mlp_classifier.fit(x_values, y_values, batch_size=512, nb_epochs=100, verbose=True) |
90 |
| - |
91 |
| - return mlp_classifier |
92 |
| - |
93 | 49 | class MockClusterer(ClusterMixin):
|
94 | 50 | """
|
95 | 51 | A mock ClusterMixin for testing purposes. This avoids using a real clustering
|
|
0 commit comments