Skip to content

Commit 300ed89

Browse files
committed
[WIP] unit test cleanup
Signed-off-by: alvaro <[email protected]>
1 parent 50cfc7e commit 300ed89

File tree

1 file changed

+0
-44
lines changed

1 file changed

+0
-44
lines changed

tests/defences/detector/poison/test_clustering_centroid_analysis.py

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -46,50 +46,6 @@
4646

4747
# TODO: add a better formatter for the logger. Eliminate date
4848

49-
def _create_mlp_model(input_dim: int, model_name: str) -> Model:
50-
51-
# Define a small DNN with one hidden layer
52-
base_model = Sequential(name=model_name, layers=[
53-
Dense(64, activation="relu", name="input_layer", input_shape=(input_dim,)), # FIXME: 28, 1, 1 to input_dim,
54-
Dense(64, activation="relu", name="hidden_layer"),
55-
Dense(1, activation="sigmoid", name="output_layer")
56-
])
57-
base_model.compile(optimizer="adam",
58-
loss="binary_crossentropy",
59-
metrics=["accuracy",
60-
Precision(name="precision"),
61-
Recall(name="recall"),
62-
AUC(name="auc")])
63-
64-
base_model.summary(print_fn=logger.info)
65-
return base_model
66-
67-
68-
def train_art_keras_classifier(x_train: Union[pd.DataFrame, np.ndarray] , y_train: Union[pd.DataFrame, np.ndarray], model_name: str) -> KerasClassifier:
69-
"""Trains a KerasClassifier using the ART wrapper."""
70-
71-
# Create the Keras model
72-
mlp_model = _create_mlp_model(x_train.shape[1], model_name)
73-
74-
# Create the ART KerasClassifier wrapper
75-
mlp_classifier = TensorFlowV2Classifier(
76-
model=mlp_model,
77-
loss_object=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
78-
optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
79-
nb_classes=2, # Ensure this is set correctly
80-
input_shape=(x_train.shape[1],)
81-
)
82-
83-
84-
# Requires ndarrays, so the dataframes are transformed
85-
x_values = x_train.values if type(x_train) == pd.DataFrame else x_train
86-
y_values = np.squeeze(y_train.values) if type(y_train) == pd.DataFrame else y_train
87-
88-
# Train the model
89-
mlp_classifier.fit(x_values, y_values, batch_size=512, nb_epochs=100, verbose=True)
90-
91-
return mlp_classifier
92-
9349
class MockClusterer(ClusterMixin):
9450
"""
9551
A mock ClusterMixin for testing purposes. This avoids using a real clustering

0 commit comments

Comments
 (0)