@@ -289,8 +289,11 @@ def reset_state(self):
289289@tf .keras .utils .register_keras_serializable (package = 'cerebrosllmutils' , name = 'SparsePerplexity' )
290290class SparsePerplexity (tf .keras .metrics .Metric ):
291291 """
292- Computes perplexity, defined as e^(sparse categorical crossentropy).
293- Assumes y_true are integer labels (not one-hot encoded).
292+ Computes perplexity for a batch of next-token predictions.
293+
294+ Expects:
295+ y_true: (Batch_Size,) - Integer labels (the actual next token).
296+ y_pred: (Batch_Size, Vocab_Size) - Logits/Probabilities for the next token.
294297 """
295298
296299 def __init__ (self , name = 'perplexity' , ** kwargs ):
@@ -299,35 +302,41 @@ def __init__(self, name='perplexity', **kwargs):
299302 self .count = self .add_weight (name = 'count' , initializer = 'zeros' )
300303
301304 def update_state (self , y_true , y_pred , sample_weight = None ):
305+ # y_true shape: (Batch_Size,)
306+ # y_pred shape: (Batch_Size, Vocab_Size)
307+
302308 # Calculate sparse categorical crossentropy
303- # This function expects y_true to be integers and y_pred to be probabilities/logits
304- crossentropy = tf .keras .losses .sparse_categorical_crossentropy (y_true , y_pred )
305-
306- # Apply sample weighting if provided
309+ # from_logits=True is safer for raw model outputs.
310+ # If your final layer is Softmax, change to False.
311+ crossentropy = tf .keras .losses .sparse_categorical_crossentropy (
312+ y_true ,
313+ y_pred ,
314+ from_logits = True
315+ )
316+
317+ # Handle sample weighting
307318 if sample_weight is not None :
308- # Ensure sample_weight is float32 for multiplication
309319 sample_weight = tf .cast (sample_weight , tf .float32 )
310320 crossentropy = crossentropy * sample_weight
311- # If sample_weight is used, we sum the weights to get the correct average
312321 batch_weight_sum = tf .reduce_sum (sample_weight )
313322 else :
314- # If no sample_weight, the count is the batch size
323+ # Count is the Batch Size
315324 batch_weight_sum = tf .cast (tf .shape (y_true )[0 ], dtype = tf .float32 )
316325
317- # Update the running sum of crossentropy and the count of samples
326+ # Update the running sum of crossentropy
318327 self .total_crossentropy .assign_add (tf .reduce_sum (crossentropy ))
328+
329+ # Update the running count
319330 self .count .assign_add (batch_weight_sum )
320331
321332 def result (self ):
322333 # Compute the average crossentropy
323- # Avoid division by zero
324334 average_crossentropy = tf .math .divide_no_nan (self .total_crossentropy , self .count )
325-
335+
326336 # Compute perplexity as e^(average crossentropy)
327337 return tf .exp (average_crossentropy )
328338
329339 def reset_state (self ):
330- # Reset the state variables
331340 self .total_crossentropy .assign (0.0 )
332341 self .count .assign (0.0 )
333342
0 commit comments