@@ -195,7 +195,11 @@ def main(argv):
         tf.keras.metrics.SparseCategoricalAccuracy())
     corrupt_metrics['test/ece_{}'.format(
         name)] = ed.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins)
-
+  test_diversity = {
+      'test/disagreement': tf.keras.metrics.Mean(),
+      'test/average_kl': tf.keras.metrics.Mean(),
+      'test/cosine_similarity': tf.keras.metrics.Mean(),
+  }
   # Evaluate model predictions.
   for n, (name, test_dataset) in enumerate(test_datasets.items()):
     logits_dataset = []
@@ -214,6 +218,10 @@ def main(argv):
       negative_log_likelihood = tf.reduce_mean(
           ensemble_negative_log_likelihood(labels, logits))
       per_probs = tf.nn.softmax(logits)
+      diversity_results = ed.metrics.average_pairwise_diversity(
+          per_probs, ensemble_size)
+      for k, v in diversity_results.items():
+        test_diversity['test/' + k].update_state(v)
       probs = tf.reduce_mean(per_probs, axis=0)
       if name == 'clean':
         gibbs_ce = tf.reduce_mean(gibbs_cross_entropy(labels, logits))
@@ -234,11 +242,15 @@ def main(argv):
         (n + 1) / num_datasets, n + 1, num_datasets))
     logging.info(message)

+  total_metrics = metrics.copy()
+  total_metrics.update(test_diversity)
   corrupt_results = utils.aggregate_corrupt_metrics(corrupt_metrics,
                                                     corruption_types,
                                                     max_intensity,
                                                     FLAGS.alexnet_errors_path)
-  total_results = {name: metric.result() for name, metric in metrics.items()}
+  total_results = {
+      name: metric.result() for name, metric in total_metrics.items()
+  }
   total_results.update(corrupt_results)
   logging.info('Metrics: %s', total_results)