@@ -531,6 +531,37 @@ def from_config(cls, config):
531531
532532meta_trial_number = 42 # irrelevant unless in distributed training
533533
534+ # Custom metric: Perplexity:
535+
536+ class Perplexity (tf .keras .metrics .Metric ):
537+ """
538+ Computes perplexity, defined as e^(categorical crossentropy).
539+ """
540+ def __init__ (self , name = 'perplexity' , ** kwargs ):
541+ super ().__init__ (name = name , ** kwargs )
542+ self .total_crossentropy = self .add_weight (name = 'total_crossentropy' , initializer = 'zeros' )
543+ self .count = self .add_weight (name = 'count' , initializer = 'zeros' )
544+
545+ def update_state (self , y_true , y_pred , sample_weight = None ):
546+ # Calculate categorical crossentropy
547+ crossentropy = tf .keras .losses .categorical_crossentropy (y_true , y_pred )
548+
549+ # Update the running sum of crossentropy and the count of samples
550+ self .total_crossentropy .assign_add (tf .reduce_sum (crossentropy ))
551+ self .count .assign_add (tf .cast (tf .shape (y_true )[0 ], dtype = tf .float32 ))
552+
553+ def result (self ):
554+ # Compute the average crossentropy
555+ average_crossentropy = self .total_crossentropy / self .count
556+ # Compute perplexity as e^(average crossentropy)
557+ return tf .exp (average_crossentropy )
558+
559+ def reset_state (self ):
560+ # Reset the state variables
561+ self .total_crossentropy .assign (0.0 )
562+ self .count .assign (0.0 )
563+
564+ perplexity_metric = Perplexity ()
534565
535566cerebros_automl = SimpleCerebrosRandomSearch (
536567 unit_type = DenseUnit ,
@@ -568,7 +599,7 @@ def from_config(cls, config):
568599 learning_rate = learning_rate ,
569600 loss = tf .keras .losses .CategoricalCrossentropy (),
570601 metrics = [tf .keras .metrics .CategoricalAccuracy (),
571- # tf.keras.metrics.Perplexity(name='perplexity') ,
602+ perplexity_metric ,
572603 # tf.keras.metrics.Accuracy()
573604 ],
574605 epochs = epochs ,
0 commit comments