Skip to content

Commit ac6c295

Browse files
Update phishing_email_detection_gpt2.py
Try manually creating a custom perplexity metric ....
1 parent 5c1ef56 commit ac6c295

File tree

1 file changed

+32
-1
lines changed

1 file changed

+32
-1
lines changed

phishing_email_detection_gpt2.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,37 @@ def from_config(cls, config):
531531

532532
meta_trial_number = 42 # irrelevant unless in distributed training
533533

534+
# Custom metric: perplexity, i.e. exp(mean categorical crossentropy).
535+
536+
class Perplexity(tf.keras.metrics.Metric):
    """Streaming perplexity, defined as ``exp(mean categorical crossentropy)``.

    The crossentropy is accumulated across batches and only exponentiated
    in ``result()``, so the reported value is the perplexity of everything
    seen since the last ``reset_state()``, not an average of per-batch
    perplexities.

    ``tf.keras.losses.categorical_crossentropy`` is called with its default
    arguments, so ``y_true`` is expected to be one-hot (or a probability
    distribution) over the last axis and ``y_pred`` to hold probabilities
    rather than logits.

    Args:
        name: Name under which the metric is reported.
        **kwargs: Forwarded to ``tf.keras.metrics.Metric``.
    """

    def __init__(self, name='perplexity', **kwargs):
        super().__init__(name=name, **kwargs)
        # Running sum of per-prediction crossentropy values.
        self.total_crossentropy = self.add_weight(
            name='total_crossentropy', initializer='zeros')
        # Number of crossentropy values folded into the running sum.
        self.count = self.add_weight(name='count', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate the crossentropy of one batch.

        Args:
            y_true: Ground-truth distributions over the last axis.
            y_pred: Predicted probabilities, same shape as ``y_true``.
            sample_weight: Accepted for Keras API compatibility only;
                it is not applied, every prediction counts equally.
        """
        crossentropy = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
        # Keep the accumulation in the accumulator's dtype so that a
        # lower-precision y_pred cannot make assign_add fail.
        crossentropy = tf.cast(crossentropy, self.total_crossentropy.dtype)

        self.total_crossentropy.assign_add(tf.reduce_sum(crossentropy))
        # Count every crossentropy value that was summed, not just the
        # batch dimension: for inputs with extra leading axes (for example
        # batch x sequence x vocab) the two differ, and dividing by the
        # batch size alone would inflate the average. For 2-D inputs this
        # equals the batch size.
        self.count.assign_add(
            tf.cast(tf.size(crossentropy), self.count.dtype))

    def result(self):
        """Return ``exp`` of the mean crossentropy accumulated so far.

        With no accumulated samples the mean is taken as 0, which gives a
        perplexity of 1.0 instead of NaN from a 0/0 division.
        """
        average_crossentropy = tf.math.divide_no_nan(
            self.total_crossentropy, self.count)
        return tf.exp(average_crossentropy)

    def reset_state(self):
        """Zero both accumulators (called by Keras between epochs/evaluations)."""
        self.total_crossentropy.assign(0.0)
        self.count.assign(0.0)
563+
564+
# Single Perplexity instance; it is passed to the search below through its
# `metrics=[...]` argument alongside CategoricalAccuracy.
perplexity_metric = Perplexity()
534565

535566
cerebros_automl = SimpleCerebrosRandomSearch(
536567
unit_type=DenseUnit,
@@ -568,7 +599,7 @@ def from_config(cls, config):
568599
learning_rate=learning_rate,
569600
loss=tf.keras.losses.CategoricalCrossentropy(),
570601
metrics=[tf.keras.metrics.CategoricalAccuracy(),
571-
# tf.keras.metrics.Perplexity(name='perplexity'),
602+
perplexity_metric,
572603
# tf.keras.metrics.Accuracy()
573604
],
574605
epochs=epochs,

0 commit comments

Comments
 (0)