Skip to content

Commit ec46b59

Browse files
Update llm_utils.py
Fix the SparsePerplexity metric.
1 parent 1f64801 commit ec46b59

File tree

1 file changed

+22
-13
lines changed

1 file changed

+22
-13
lines changed

cerebrosllmutils/llm_utils.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -289,8 +289,11 @@ def reset_state(self):
289289
@tf.keras.utils.register_keras_serializable(package='cerebrosllmutils', name='SparsePerplexity')
290290
class SparsePerplexity(tf.keras.metrics.Metric):
291291
"""
292-
Computes perplexity, defined as e^(sparse categorical crossentropy).
293-
Assumes y_true are integer labels (not one-hot encoded).
292+
Computes perplexity for a batch of next-token predictions.
293+
294+
Expects:
295+
y_true: (Batch_Size,) - Integer labels (the actual next token).
296+
y_pred: (Batch_Size, Vocab_Size) - Logits/Probabilities for the next token.
294297
"""
295298

296299
def __init__(self, name='perplexity', **kwargs):
@@ -299,35 +302,41 @@ def __init__(self, name='perplexity', **kwargs):
299302
self.count = self.add_weight(name='count', initializer='zeros')
300303

301304
def update_state(self, y_true, y_pred, sample_weight=None):
    """Accumulate crossentropy statistics for one batch of next-token predictions.

    Args:
        y_true: (Batch_Size,) integer labels (the actual next token).
        y_pred: (Batch_Size, Vocab_Size) raw logits for the next token.
        sample_weight: Optional per-sample weights; when given, the metric
            becomes a weighted mean over samples.
    """
    # Per-sample crossentropy. from_logits=True assumes the model emits raw
    # logits; change to False if the final layer already applies Softmax.
    ce = tf.keras.losses.sparse_categorical_crossentropy(
        y_true, y_pred, from_logits=True
    )

    if sample_weight is None:
        # Unweighted: each sample counts once, so the denominator is the
        # batch size.
        denom = tf.cast(tf.shape(y_true)[0], dtype=tf.float32)
    else:
        # Weighted: scale each sample's loss and accumulate the weight sum
        # as the denominator so result() yields a weighted mean.
        sample_weight = tf.cast(sample_weight, tf.float32)
        ce = ce * sample_weight
        denom = tf.reduce_sum(sample_weight)

    # Fold this batch into the running totals.
    self.total_crossentropy.assign_add(tf.reduce_sum(ce))
    self.count.assign_add(denom)
320331

321332
def result(self):
    """Return perplexity accumulated so far: exp(mean crossentropy).

    divide_no_nan yields 0 when count is 0, so an empty metric reports
    exp(0) == 1 rather than raising on division by zero.
    """
    mean_ce = tf.math.divide_no_nan(self.total_crossentropy, self.count)
    return tf.exp(mean_ce)
328338

329339
def reset_state(self):
    """Clear the accumulated crossentropy sum and sample count between epochs/evaluations."""
    self.total_crossentropy.assign(0.0)
    self.count.assign(0.0)
333342

0 commit comments

Comments
 (0)