Commit 862f338

GhassenJ authored and edward-bot committed
Track corrupted output_similarity.
PiperOrigin-RevId: 317342104
1 parent: 990e3e7, commit: 862f338

File tree

1 file changed: +14 -6 lines changed

baselines/cifar/utils.py (14 additions, 6 deletions)

@@ -283,7 +283,8 @@ def aggregate_corrupt_metrics(metrics,
     Dictionary of aggregated results.

   """
-  diversity_keys = ['disagreement', 'cosine_similarity', 'average_kl']
+  diversity_keys = ['disagreement', 'cosine_similarity', 'average_kl',
+                    'outputs_similarity']
   results = {
       'test/nll_mean_corrupted': 0.,
       'test/accuracy_mean_corrupted': 0.,
@@ -305,7 +306,7 @@ def aggregate_corrupt_metrics(metrics,
     disagreement = np.zeros(len(corruption_types))
     cosine_similarity = np.zeros(len(corruption_types))
     average_kl = np.zeros(len(corruption_types))
-
+    outputs_similarity = np.zeros(len(corruption_types))
     for i in range(len(corruption_types)):
       dataset_name = '{0}_{1}'.format(corruption_types[i], intensity)
       nll[i] = metrics['test/nll_{}'.format(dataset_name)].result()
@@ -321,17 +322,21 @@ def aggregate_corrupt_metrics(metrics,
             dataset_name)].result()
         member_ece[i] = 0.
       if corrupt_diversity is not None:
+        error = 1 - acc[i] + tf.keras.backend.epsilon()
         disagreement[i] = (
             corrupt_diversity['corrupt_diversity/disagreement_{}'.format(
-                dataset_name)].result())
+                dataset_name)].result()) / error
         # Normalize the corrupt disagreement by its error rate.
-        error = 1 - acc[i] + tf.keras.backend.epsilon()
         cosine_similarity[i] = (
             corrupt_diversity['corrupt_diversity/cosine_similarity_{}'.format(
-                dataset_name)].result()) / error
+                dataset_name)].result())
         average_kl[i] = (
             corrupt_diversity['corrupt_diversity/average_kl_{}'.format(
                 dataset_name)].result())
+        outputs_similarity[i] = (
+            corrupt_diversity['corrupt_diversity/outputs_similarity_{}'.format(
+                dataset_name)].result())
+
       if log_fine_metrics or output_dir is not None:
         fine_metrics_results['test/nll_{}'.format(dataset_name)] = nll[i]
         fine_metrics_results['test/accuracy_{}'.format(dataset_name)] = acc[i]
@@ -343,6 +348,9 @@ def aggregate_corrupt_metrics(metrics,
               dataset_name)] = cosine_similarity[i]
           fine_metrics_results['corrupt_diversity/average_kl_{}'.format(
               dataset_name)] = average_kl[i]
+          fine_metrics_results['corrupt_diversity/outputs_similarity_{}'.format(
+              dataset_name)] = outputs_similarity[i]
+
     avg_nll = np.mean(nll)
     avg_accuracy = np.mean(acc)
     avg_ece = np.mean(ece)
@@ -363,7 +371,7 @@ def aggregate_corrupt_metrics(metrics,
     results['test/member_ece_mean_corrupted'] += avg_member_ece
     if corrupt_diversity is not None:
       avg_diversity_metrics = [np.mean(disagreement), np.mean(
-          cosine_similarity), np.mean(average_kl)]
+          cosine_similarity), np.mean(average_kl), np.mean(outputs_similarity)]
       for key, avg in zip(diversity_keys, avg_diversity_metrics):
         results['corrupt_diversity/{}_mean_{}'.format(
             key, intensity)] = avg
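
For context, here is a minimal, self-contained sketch (not the repository's code) of the aggregation pattern these hunks implement: per-corruption disagreement is divided by the error rate 1 - accuracy + epsilon, and the new outputs_similarity metric is averaged across corruption types alongside the existing diversity metrics. The corruption names, the toy metric values, and the EPSILON constant below are illustrative stand-ins for the tf.keras metric objects and tf.keras.backend.epsilon() used in baselines/cifar/utils.py.

import numpy as np

EPSILON = 1e-7  # stand-in for tf.keras.backend.epsilon()

diversity_keys = ['disagreement', 'cosine_similarity', 'average_kl',
                  'outputs_similarity']

# Toy per-corruption results at one intensity: accuracy plus raw diversity values.
corrupt_results = {
    'gaussian_noise': {'accuracy': 0.72, 'disagreement': 0.10,
                       'cosine_similarity': 0.95, 'average_kl': 0.08,
                       'outputs_similarity': 0.90},
    'motion_blur': {'accuracy': 0.65, 'disagreement': 0.14,
                    'cosine_similarity': 0.93, 'average_kl': 0.11,
                    'outputs_similarity': 0.87},
}

corruption_types = sorted(corrupt_results)
disagreement = np.zeros(len(corruption_types))
cosine_similarity = np.zeros(len(corruption_types))
average_kl = np.zeros(len(corruption_types))
outputs_similarity = np.zeros(len(corruption_types))

for i, name in enumerate(corruption_types):
  r = corrupt_results[name]
  # As in the diff: disagreement (not cosine similarity) is normalized by the
  # error rate, 1 - accuracy + epsilon.
  error = 1. - r['accuracy'] + EPSILON
  disagreement[i] = r['disagreement'] / error
  cosine_similarity[i] = r['cosine_similarity']
  average_kl[i] = r['average_kl']
  outputs_similarity[i] = r['outputs_similarity']

# Average each diversity metric over corruption types, keyed by intensity.
intensity = 1
avg_diversity_metrics = [np.mean(disagreement), np.mean(cosine_similarity),
                         np.mean(average_kl), np.mean(outputs_similarity)]
results = {}
for key, avg in zip(diversity_keys, avg_diversity_metrics):
  results['corrupt_diversity/{}_mean_{}'.format(key, intensity)] = avg

print(results)

With the toy values above, this prints a dict with keys such as 'corrupt_diversity/disagreement_mean_1' and 'corrupt_diversity/outputs_similarity_mean_1', mirroring how the commit adds the outputs_similarity average to the aggregated results.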
