
Commit 7db4b15

Update phishing_email_detection_gpt2.py
First attempt to integrate a rotary positional embedding.
1 parent 713ac96 commit 7db4b15

File tree

1 file changed: +37 -3 lines changed


phishing_email_detection_gpt2.py

Lines changed: 37 additions & 3 deletions
@@ -208,6 +208,38 @@ def get_config(self):
     def from_config(cls, config):
         return cls(max_seq_length=config['max_seq_length'])
 
+
+class RotaryPositionEmbedding(tf.keras.layers.Layer):
+    def __init__(self, max_seq_length, d_model, **kwargs):
+        super().__init__(**kwargs)
+        self.max_seq_length = max_seq_length
+        self.d_model = d_model
+        assert d_model % 2 == 0, "d_model must be even"
+
+        # Precompute rotation matrices
+        inv_freq = 1.0 / (10000 ** (tf.range(0, d_model, 2, dtype=tf.float32) / d_model))
+        positions = tf.range(max_seq_length, dtype=tf.float32)
+        sinusoid = tf.einsum('i,j->ij', positions, inv_freq)
+
+        self.sin = tf.sin(sinusoid)
+        self.cos = tf.cos(sinusoid)
+
+    def call(self, x):
+        batch_size = tf.shape(x)[0]
+        seq_len = tf.shape(x)[1]
+
+        # Split dimensions into pairs
+        x = tf.reshape(x, [batch_size, seq_len, self.d_model // 2, 2])
+
+        # Apply rotation
+        x_rot = tf.stack([
+            x[..., 0] * self.cos[:seq_len] - x[..., 1] * self.sin[:seq_len],
+            x[..., 0] * self.sin[:seq_len] + x[..., 1] * self.cos[:seq_len]
+        ], axis=-1)
+
+        return tf.reshape(x_rot, [batch_size, seq_len, self.d_model])
+
+
 # GPT2 configurables
 
 # Optimal for accuracy thus far:
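
For a quick sanity check of the RotaryPositionEmbedding layer added above, the following sketch (not part of the commit; rope and dummy are illustrative names) runs the layer on a small random batch. The output should keep the input's [batch, seq_len, d_model] shape, and position 0 should pass through unchanged because its rotation angle is zero.

    import tensorflow as tf

    # Illustrative smoke test; assumes the RotaryPositionEmbedding class
    # from this commit is in scope. 'rope' and 'dummy' are placeholder names.
    rope = RotaryPositionEmbedding(max_seq_length=128, d_model=24)

    dummy = tf.random.normal([2, 10, 24])  # [batch, seq_len, d_model]
    rotated = rope(dummy)

    print(rotated.shape)  # expected: (2, 10, 24)

    # Position 0 is rotated by angle 0, so it should come back unchanged.
    tf.debugging.assert_near(rotated[:, 0, :], dummy[:, 0, :], atol=1e-5)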
@@ -221,17 +253,19 @@ def from_config(cls, config):
 
 # On larger hardware, this could probably be increased considerably and
 # Probably would improve performance ...
-EMBEDDING_DIM = 23 # Define EMBEDDING_DIM here, to match your embedding layer.
+EMBEDDING_N = 12 # Define EMBEDDING_DIM here, to match your embedding layer.
+EMBEDDING_DIM = int(EMBEDDING_N * 2)
 
 embedded = tf.keras.layers.Embedding(
     input_dim=VOCABULARY_SIZE,
     output_dim=EMBEDDING_DIM,
     input_length=max_seq_length,
     mask_zero=True)(tokens)
 
-position_embedding = PositionEmbedding(
+position_embedding = RotaryPositionEmbedding(
     sequence_length=max_seq_length,
-    initializer="uniform",
+    sequence_length=EMBEDDING_DIM
+    # initializer="uniform",
 )(embedded)
 
 # As an FYI, we tried an add layer both with and without
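
As committed, the new call site passes sequence_length twice (once on the context line kept from the old PositionEmbedding call, once for EMBEDDING_DIM), while the constructor defined earlier in this diff takes max_seq_length and d_model. A minimal sketch of how the arguments would line up, assuming the intent is to rotate the EMBEDDING_DIM-wide token embeddings over max_seq_length positions (not part of the commit):

    # Hypothetical wiring matching the RotaryPositionEmbedding constructor above;
    # not part of this commit.
    position_embedding = RotaryPositionEmbedding(
        max_seq_length=max_seq_length,
        d_model=EMBEDDING_DIM,
    )(embedded)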
