Skip to content

Commit c378dbd

Browse files
Update llm_utils.py
Move Perplexity to cerebrosllmutils
1 parent 67c74e3 commit c378dbd

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

cerebrosllmutils/llm_utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,3 +278,35 @@ def get_config(self):
278278
def from_config(cls, config):
279279
# Keras handles nested layer restoration automatically
280280
return cls(**config)
281+
282+
283+
@tf.keras.utils.register_keras_serializable()
class Perplexity(tf.keras.metrics.Metric):
    """Streaming perplexity, defined as ``exp(mean categorical crossentropy)``.

    The metric accumulates the (optionally weighted) sum of per-element
    categorical crossentropy values and the matching total weight across
    batches; ``result()`` exponentiates their ratio.

    ``y_true`` and ``y_pred`` are passed unchanged to
    ``tf.keras.losses.categorical_crossentropy`` with its default arguments,
    so ``y_true`` is expected to be one-hot (or a probability distribution)
    and ``y_pred`` to hold probabilities rather than logits.

    Args:
        name: Name of the metric instance.
        **kwargs: Forwarded to ``tf.keras.metrics.Metric``.
    """

    def __init__(self, name='perplexity', **kwargs):
        super().__init__(name=name, **kwargs)
        # Running (weighted) sum of crossentropy values seen so far.
        self.total_crossentropy = self.add_weight(name='total_crossentropy', initializer='zeros')
        # Running total weight: number of loss elements, or sum of sample weights.
        self.count = self.add_weight(name='count', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate crossentropy statistics for one batch.

        Args:
            y_true: Ground-truth distributions, same shape as ``y_pred``.
            y_pred: Predicted probabilities; the last axis is the class axis.
            sample_weight: Optional weights (e.g. a padding mask). May have
                the same shape as the per-element loss, or fewer trailing
                axes (for example one weight per sample), in which case it
                is broadcast over the remaining axes.
        """
        state_dtype = self.total_crossentropy.dtype
        # Cast so accumulation also works when y_pred is not in the state
        # dtype (e.g. float16 outputs under mixed precision).
        crossentropy = tf.cast(
            tf.keras.losses.categorical_crossentropy(y_true, y_pred),
            state_dtype,
        )

        if sample_weight is None:
            weighted_sum = tf.reduce_sum(crossentropy)
            # Count every loss element, not only the batch dimension:
            # sequence inputs yield a (batch, seq) loss tensor whose sum
            # must be averaged over batch * seq elements.
            weight_total = tf.cast(tf.size(crossentropy), state_dtype)
        else:
            weights = tf.cast(sample_weight, state_dtype)
            loss_rank = crossentropy.shape.rank
            weight_rank = weights.shape.rank
            if loss_rank is not None and weight_rank is not None:
                # (batch, 1) weights against a (batch,) loss: drop the
                # trailing singleton axis.
                if weight_rank == loss_rank + 1 and weights.shape[-1] == 1:
                    weights = tf.squeeze(weights, axis=-1)
                    weight_rank -= 1
                # Per-sample weights against a per-token loss: append axes
                # so the weights broadcast over the trailing dimensions.
                for _ in range(loss_rank - weight_rank):
                    weights = tf.expand_dims(weights, axis=-1)
            weights = tf.broadcast_to(weights, tf.shape(crossentropy))
            weighted_sum = tf.reduce_sum(crossentropy * weights)
            weight_total = tf.reduce_sum(weights)

        # Update the running sum of crossentropy and the running total weight.
        self.total_crossentropy.assign_add(weighted_sum)
        self.count.assign_add(weight_total)

    def result(self):
        """Return ``exp(total_crossentropy / count)``.

        Returns 0.0 instead of NaN when nothing has been accumulated yet
        (or when every weight seen so far was zero).
        """
        total = tf.convert_to_tensor(self.total_crossentropy)
        count = tf.convert_to_tensor(self.count)
        # divide_no_nan keeps the unselected branch of tf.where finite.
        average_crossentropy = tf.math.divide_no_nan(total, count)
        perplexity = tf.exp(average_crossentropy)
        return tf.where(count > 0, perplexity, tf.zeros_like(perplexity))

    def reset_state(self):
        """Reset the accumulated crossentropy sum and weight to zero."""
        self.total_crossentropy.assign(0.0)
        self.count.assign(0.0)
312+

0 commit comments

Comments
 (0)