@@ -278,3 +278,35 @@ def get_config(self):
278278 def from_config (cls , config ):
279279 # Keras handles nested layer restoration automatically
280280 return cls (** config )
281+
282+
283+ @tf .keras .utils .register_keras_serializable ()
284+ class Perplexity (tf .keras .metrics .Metric ):
285+ """
286+ Computes perplexity, defined as e^(categorical crossentropy).
287+ """
288+
289+ def __init__ (self , name = 'perplexity' , ** kwargs ):
290+ super ().__init__ (name = name , ** kwargs )
291+ self .total_crossentropy = self .add_weight (name = 'total_crossentropy' , initializer = 'zeros' )
292+ self .count = self .add_weight (name = 'count' , initializer = 'zeros' )
293+
294+ def update_state (self , y_true , y_pred , sample_weight = None ):
295+ # Calculate categorical crossentropy
296+ crossentropy = tf .keras .losses .categorical_crossentropy (y_true , y_pred )
297+
298+ # Update the running sum of crossentropy and the count of samples
299+ self .total_crossentropy .assign_add (tf .reduce_sum (crossentropy ))
300+ self .count .assign_add (tf .cast (tf .shape (y_true )[0 ], dtype = tf .float32 ))
301+
302+ def result (self ):
303+ # Compute the average crossentropy
304+ average_crossentropy = self .total_crossentropy / self .count
305+ # Compute perplexity as e^(average crossentropy)
306+ return tf .exp (average_crossentropy )
307+
308+ def reset_state (self ):
309+ # Reset the state variables
310+ self .total_crossentropy .assign (0.0 )
311+ self .count .assign (0.0 )
312+
0 commit comments