Skip to content

Commit d2d0f72

Browse files
Update phishing_email_detection_gpt2.py
...
1 parent 4315c51 commit d2d0f72

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

phishing_email_detection_gpt2.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def from_config(cls, config):
211211

212212

213213

214-
class RotaryEmbedding(tf.keras.layers.Layer):
214+
def RotaryEmbedding(tf.keras.layers.Layer):
215215
def __init__(self, dim, max_seq_len=1024, temperature=10000.0, **kwargs):
216216
super().__init__(**kwargs)
217217
self.dim = dim
@@ -233,11 +233,18 @@ def call(self, x, seq_len=None):
233233
seq_len = tf.shape(x)[1] if seq_len is None else seq_len
234234
sin = self.sin_cache[:seq_len]
235235
cos = self.cos_cache[:seq_len]
236-
sin = tf.cast(tf.repeat(sin[..., tf.newaxis], self.dim // 2, axis=-1), x.dtype)
237-
cos = tf.cast(tf.repeat(cos[..., tf.newaxis], self.dim // 2, axis=-1), x.dtype)
236+
sin = tf.cast(tf.repeat(sin[..., tf.newaxis], 2, axis=-1), x.dtype)
237+
cos = tf.cast(tf.repeat(cos[..., tf.newaxis], 2, axis=-1), x.dtype)
238+
sin = tf.reshape(sin, [seq_len, self.dim])
239+
cos = tf.reshape(cos, [seq_len, self.dim])
240+
sin = tf.expand_dims(sin, axis=0)
241+
cos = tf.expand_dims(cos, axis=0)
242+
sin = tf.tile(sin, [batch_size, 1, 1])
243+
cos = tf.tile(cos, [batch_size, 1, 1])
238244
return sin, cos
239245

240246

247+
241248
def split_alternate(x):
242249
shape = tf.shape(x)
243250
x = tf.reshape(x, [shape[0], shape[1], shape[2] // 2, 2])

0 commit comments

Comments
 (0)