Skip to content

Commit f88afbd

Browse files
Update phishing_email_detection_gpt2.py
1 parent 3bd57f3 commit f88afbd

File tree

1 file changed

+17
-12
lines changed

1 file changed

+17
-12
lines changed

phishing_email_detection_gpt2.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ def from_config(cls, config):
209209
return cls(max_seq_length=config['max_seq_length'])
210210

211211

212+
212213
class RotaryPositionEmbedding(tf.keras.layers.Layer):
213214
def __init__(self, max_seq_length, d_model, **kwargs):
214215
super().__init__(**kwargs)
@@ -218,32 +219,36 @@ def __init__(self, max_seq_length, d_model, **kwargs):
218219

219220
# Precompute rotation matrices
220221
inv_freq = 1.0 / (10000 ** (tf.range(0, d_model, 2, dtype=tf.float32) / d_model))
222+
self.inv_freq = tf.cast(inv_freq, tf.float32)
221223
positions = tf.range(max_seq_length, dtype=tf.float32)
222-
sinusoid = tf.einsum('i,j->ij', positions, inv_freq)
223-
224-
self.sin = tf.sin(sinusoid)
225-
self.cos = tf.cos(sinusoid)
224+
self.sin = tf.sin(tf.einsum('i,j->ij', positions, inv_freq))
225+
self.cos = tf.cos(tf.einsum('i,j->ij', positions, inv_freq))
226226

227227
def call(self, x):
228228
batch_size = tf.shape(x)[0]
229229
seq_len = tf.shape(x)[1]
230230

231-
# Split dimensions into pairs
232-
x = tf.reshape(x, [batch_size, seq_len, self.d_model//2, 2])
231+
# Compute sine and cosine matrices for current sequence length
232+
sinusoid = tf.einsum('i,j->ij', tf.range(seq_len, dtype=tf.float32), self.inv_freq)
233+
current_sin = tf.sin(sinusoid)
234+
current_cos = tf.cos(sinusoid)
233235

234-
# Apply rotation
235-
x_rot = tf.stack([
236-
x[..., 0] * self.cos[:seq_len] - x[..., 1] * self.sin[:seq_len],
237-
x[..., 0] * self.sin[:seq_len] + x[..., 1] * self.cos[:seq_len]
236+
# Split the feature dimension into pairs and apply the rotation element-wise
237+
x = tf.reshape(x, [batch_size, seq_len, self.d_model//2, 2])
238+
rotated = tf.stack([
239+
x[..., 0] * current_cos - x[..., 1] * current_sin,
240+
x[..., 0] * current_sin + x[..., 1] * current_cos
238241
], axis=-1)
239242

240-
return tf.reshape(x_rot, [batch_size, seq_len, self.d_model])
243+
# Reshape back to [batch_size, seq_len, d_model]
244+
return tf.reshape(rotated, [batch_size, seq_len, self.d_model])
245+
246+
241247

242248

243249
# GPT2 configurables
244250

245251
# Optimal for accuracy thus far:
246-
# max_seq_length = 900
247252
max_seq_length = 1024
248253

249254
inp = tf.keras.layers.Input(shape=(), dtype=tf.string)

0 commit comments

Comments
 (0)