Skip to content

Commit 44854be

Browse files
Update phishing_email_detection_gpt2.py
Attempt to correct a graph-scope issue seen with TensorFlow v2.19.0.
1 parent 2c417fb commit 44854be

File tree

1 file changed

+66
-19
lines changed

1 file changed

+66
-19
lines changed

phishing_email_detection_gpt2.py

Lines changed: 66 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -255,34 +255,81 @@ class RotaryEmbedding(tf.keras.layers.Layer):
255255
def __init__(self, dim, max_seq_len=1024, temperature=10000.0, **kwargs):
    """Rotary position embedding layer.

    Args:
        dim: Embedding dimension; must be even, since features are rotated
            in pairs (validated in `build`).
        max_seq_len: Retained for configuration/serialization compatibility;
            the sin/cos tables are recomputed per call rather than cached,
            so this value no longer bounds the input length.
        temperature: Base of the inverse-frequency schedule (default 10000.0).
        **kwargs: Forwarded to `tf.keras.layers.Layer`.
    """
    super().__init__(**kwargs)
    # Plain attribute assignments only — no tensors are created or cached
    # here, so nothing can be bound to the wrong graph.
    self.dim = dim
    self.max_seq_len = max_seq_len
    self.temperature = temperature
260261

261262
def build(self, input_shape):
    """Validate `dim` and precompute the inverse-frequency schedule.

    The layer has no trainable weights. `inv_freq` is deliberately stored
    as a plain Python list of floats rather than a `tf.Tensor`: a tensor
    created during `build` is bound to whatever graph (or eager context)
    `build` executed in, and reusing it inside a `tf.function`-traced
    `call` raises the "tensor from a different graph" error this layer
    has been hitting under TF 2.19. TensorFlow converts the list into a
    constant inside the correct graph at every trace instead.

    Raises:
        ValueError: If `self.dim` is odd (features are rotated in pairs).
    """
    if self.dim % 2 != 0:
        raise ValueError(f"Embedding dimension `dim` ({self.dim}) must be even for RotaryEmbedding.")
    # inv_freq[j] = 1 / temperature**(2j / dim), one entry per feature pair
    # (equivalent to the usual tf.range(0, dim, 2) / dim formulation).
    self.inv_freq = [
        1.0 / (self.temperature ** (i / self.dim))
        for i in range(0, self.dim, 2)
    ]
    super().build(input_shape)
271276
def call(self, x, seq_len=None):
    """Return per-position sin/cos tensors for rotary attention.

    Args:
        x: Tensor of shape [batch, seq, ...]; only its shape and dtype
            are consulted here.
        seq_len: Accepted for interface compatibility but ignored — the
            sequence length is always taken from `x` itself.

    Returns:
        Tuple `(sin, cos)`, each of shape [batch, seq, dim], cast to
        `x.dtype`, as consumed by InterleavedRoPE.
    """
    dims = tf.shape(x)
    n_batch = dims[0]
    n_pos = dims[1]

    # Angle table: outer product of positions (rows) and inverse
    # frequencies (columns) -> shape [n_pos, dim // 2].
    positions = tf.range(n_pos, dtype=tf.float32)
    angles = tf.einsum("i,j->ij", positions, self.inv_freq)

    # Duplicate each frequency column so adjacent feature pairs share an
    # angle: [a, b, ...] -> [a, a, b, b, ...], giving shape [n_pos, dim].
    sin_t = tf.repeat(tf.sin(angles), 2, axis=-1)
    cos_t = tf.repeat(tf.cos(angles), 2, axis=-1)

    # Prepend a batch axis and materialize one copy per batch element,
    # then match the caller's dtype.
    sin_t = tf.tile(sin_t[tf.newaxis, ...], [n_batch, 1, 1])
    cos_t = tf.tile(cos_t[tf.newaxis, ...], [n_batch, 1, 1])
    sin_t = tf.cast(sin_t, x.dtype)
    cos_t = tf.cast(cos_t, x.dtype)

    return sin_t, cos_t
285318

319+
def get_config(self):
    """Serialize constructor arguments so the layer can be re-created."""
    base = super().get_config()
    base.update(
        dim=self.dim,
        max_seq_len=self.max_seq_len,
        temperature=self.temperature,
    )
    return base
327+
328+
@classmethod
def from_config(cls, config):
    """Recreate the layer from a `get_config` dictionary.

    NOTE(review): this mirrors the default Keras implementation and could
    be dropped, but is kept for explicitness.
    """
    return cls(**config)
331+
332+
286333

287334

288335
def split_alternate(x):

0 commit comments

Comments
 (0)