@@ -283,7 +283,8 @@ def aggregate_corrupt_metrics(metrics,
     Dictionary of aggregated results.

   """
-  diversity_keys = ['disagreement', 'cosine_similarity', 'average_kl']
+  diversity_keys = ['disagreement', 'cosine_similarity', 'average_kl',
+                    'outputs_similarity']
   results = {
       'test/nll_mean_corrupted': 0.,
       'test/accuracy_mean_corrupted': 0.,
@@ -305,7 +306,7 @@ def aggregate_corrupt_metrics(metrics,
   disagreement = np.zeros(len(corruption_types))
   cosine_similarity = np.zeros(len(corruption_types))
   average_kl = np.zeros(len(corruption_types))
-
+  outputs_similarity = np.zeros(len(corruption_types))
   for i in range(len(corruption_types)):
     dataset_name = '{0}_{1}'.format(corruption_types[i], intensity)
     nll[i] = metrics['test/nll_{}'.format(dataset_name)].result()
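
Each entry of the metrics dict (and, in the next hunk, corrupt_diversity) appears to be an accumulating tf.keras metric whose value is read out with .result(). A minimal standalone sketch of that pattern, using a hypothetical metric name:

    import tensorflow as tf

    # Hypothetical stand-in for one entry of the metrics dict.
    metric = tf.keras.metrics.Mean(name='test/nll_gaussian_noise_1')
    metric.update_state([2.3, 1.9, 2.1])  # accumulate per-batch values
    print(float(metric.result()))  # running mean -> 2.1
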
@@ -321,17 +322,21 @@ def aggregate_corrupt_metrics(metrics,
           dataset_name)].result()
       member_ece[i] = 0.
     if corrupt_diversity is not None:
+      error = 1 - acc[i] + tf.keras.backend.epsilon()
       disagreement[i] = (
           corrupt_diversity['corrupt_diversity/disagreement_{}'.format(
-              dataset_name)].result())
+              dataset_name)].result()) / error
       # Normalize the corrupt disagreement by its error rate.
-      error = 1 - acc[i] + tf.keras.backend.epsilon()
       cosine_similarity[i] = (
           corrupt_diversity['corrupt_diversity/cosine_similarity_{}'.format(
-              dataset_name)].result()) / error
+              dataset_name)].result())
       average_kl[i] = (
           corrupt_diversity['corrupt_diversity/average_kl_{}'.format(
               dataset_name)].result())
+      outputs_similarity[i] = (
+          corrupt_diversity['corrupt_diversity/outputs_similarity_{}'.format(
+              dataset_name)].result())
+
     if log_fine_metrics or output_dir is not None:
       fine_metrics_results['test/nll_{}'.format(dataset_name)] = nll[i]
       fine_metrics_results['test/accuracy_{}'.format(dataset_name)] = acc[i]
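
This hunk fixes a bug: the error-rate normalization described by the comment was being applied to cosine_similarity instead of disagreement. It also tracks the new outputs_similarity metric. A standalone sketch of the corrected normalization, with illustrative values; 1e-7 stands in for the default tf.keras.backend.epsilon():

    def normalized_disagreement(raw_disagreement, accuracy, epsilon=1e-7):
      # Divide raw ensemble disagreement by the error rate so values stay
      # comparable across accuracy levels; epsilon guards accuracy == 1.
      return raw_disagreement / (1.0 - accuracy + epsilon)

    # e.g. 12% raw disagreement at 80% accuracy -> about 0.12 / 0.2 = 0.6
    print(normalized_disagreement(0.12, 0.80))
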
@@ -343,6 +348,9 @@ def aggregate_corrupt_metrics(metrics,
           dataset_name)] = cosine_similarity[i]
       fine_metrics_results['corrupt_diversity/average_kl_{}'.format(
           dataset_name)] = average_kl[i]
+      fine_metrics_results['corrupt_diversity/outputs_similarity_{}'.format(
+          dataset_name)] = outputs_similarity[i]
+
   avg_nll = np.mean(nll)
   avg_accuracy = np.mean(acc)
   avg_ece = np.mean(ece)
@@ -363,7 +371,7 @@ def aggregate_corrupt_metrics(metrics,
   results['test/member_ece_mean_corrupted'] += avg_member_ece
   if corrupt_diversity is not None:
     avg_diversity_metrics = [np.mean(disagreement), np.mean(
-        cosine_similarity), np.mean(average_kl)]
+        cosine_similarity), np.mean(average_kl), np.mean(outputs_similarity)]
     for key, avg in zip(diversity_keys, avg_diversity_metrics):
       results['corrupt_diversity/{}_mean_{}'.format(
           key, intensity)] = avg
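
diversity_keys (from the first hunk) and avg_diversity_metrics are paired positionally by zip, which is why 'outputs_similarity' must be appended to both lists in the same order. A small sketch of the resulting result keys, with dummy averages and a hypothetical intensity level:

    diversity_keys = ['disagreement', 'cosine_similarity', 'average_kl',
                      'outputs_similarity']
    avg_diversity_metrics = [0.65, 0.85, 0.15, 0.45]  # dummy averages
    intensity = 3  # hypothetical corruption intensity

    results = {}
    for key, avg in zip(diversity_keys, avg_diversity_metrics):
      results['corrupt_diversity/{}_mean_{}'.format(key, intensity)] = avg
    # -> results['corrupt_diversity/outputs_similarity_mean_3'] == 0.45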