Skip to content

Commit 552e131

Browse files
Update phishing_email_detection_gpt2.py
Make the custom metric Perplexity searializable.
1 parent edf5100 commit 552e131

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

phishing_email_detection_gpt2.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,7 @@ def from_config(cls, config):
430430

431431
# iRoPE helper functions
432432

433+
@tf.keras.utils.register_keras_serializable()
433434
def split_alternate(x):
434435
shape = tf.shape(x)
435436
x = tf.reshape(x, [shape[0], shape[1], shape[2] // 2, 2])
@@ -438,13 +439,15 @@ def split_alternate(x):
438439
return x
439440

440441

442+
@tf.keras.utils.register_keras_serializable()
441443
def rotate_half(x):
442444
x = split_alternate(x)
443445
d = tf.shape(x)[-1]
444446
rotated_x = tf.concat([-x[..., d//2:], x[..., :d//2]], axis=-1)
445447
return tf.reshape(rotated_x, tf.shape(x))
446448

447449

450+
@tf.keras.utils.register_keras_serializable()
448451
def apply_rotary_pos_emb(x, sin, cos):
449452
cos = tf.reshape(cos, [tf.shape(cos)[0], tf.shape(cos)[1], -1])
450453
sin = tf.reshape(sin, [tf.shape(sin)[0], tf.shape(sin)[1], -1])
@@ -533,6 +536,7 @@ def from_config(cls, config):
533536

534537
# Custom metric: Perplexity:
535538

539+
@tf.keras.utils.register_keras_serializable()
536540
class Perplexity(tf.keras.metrics.Metric):
537541
"""
538542
Computes perplexity, defined as e^(categorical crossentropy).

0 commit comments

Comments
 (0)