Skip to content

Commit f46ad78

Browse files
Update phishing_email_detection_gpt2.py
Another attempt to resolve TensorFlow v2.19.0 graph-scope compatibility...
1 parent 44854be commit f46ad78

File tree

1 file changed

+57
-40
lines changed

1 file changed

+57
-40
lines changed

phishing_email_detection_gpt2.py

Lines changed: 57 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -251,67 +251,63 @@ def from_config(cls, config):
251251

252252

253253

254+
# --- Updated RotaryEmbedding ---
254255
class RotaryEmbedding(tf.keras.layers.Layer):
255256
def __init__(self, dim, max_seq_len=1024, temperature=10000.0, **kwargs):
256257
super().__init__(**kwargs)
257258
self.dim = dim
258-
self.max_seq_len = max_seq_len # Still useful for potential pre-allocation if needed, but not for caching tensors
259+
# Ensure dim is even right at initialization
260+
if self.dim % 2 != 0:
261+
raise ValueError(f"Embedding dimension `dim` ({self.dim}) must be even for RotaryEmbedding.")
262+
self.max_seq_len = max_seq_len
259263
self.temperature = temperature
260-
# No caching in __init__ or build anymore
264+
# *** No calculation or storage of inv_freq here or in build ***
261265

262266
def build(self, input_shape):
263-
# Build is primarily for creating weights. We don't have trainable weights here.
264-
# We can calculate inv_freq here if desired, as it doesn't depend on input shape directly
265-
# and is constant. However, calculating it in call() is also fine.
266-
# Let's calculate it once here to avoid recomputing constants.
267-
# Ensure dim is even
268-
if self.dim % 2 != 0:
269-
raise ValueError(f"Embedding dimension `dim` ({self.dim}) must be even for RotaryEmbedding.")
270-
271-
inv_freq_base = tf.range(0, self.dim, 2, dtype=tf.float32) # Corrected range for pair dimension
272-
inv_freq = 1.0 / (self.temperature ** (inv_freq_base / self.dim)) # Corrected calculation
273-
self.inv_freq = inv_freq # Store the constant factor
267+
# Build should primarily be for creating trainable weights, which we don't have.
268+
# Call super().build() for Keras compatibility.
274269
super().build(input_shape)
275270

276-
def call(self, x, seq_len=None):
271+
def call(self, x): # Removed seq_len argument, calculate from x
277272
shape = tf.shape(x)
278273
batch_size = shape[0]
279-
# Determine sequence length dynamically from input tensor 'x'
280274
actual_seq_len = shape[1]
281275

276+
# *** Calculate inv_freq inside call ***
277+
inv_freq_base = tf.range(0, self.dim, 2, dtype=tf.float32)
278+
inv_freq = 1.0 / (self.temperature ** (inv_freq_base / self.dim))
279+
# inv_freq has shape [dim/2]; cast it to the input dtype before use
280+
inv_freq = tf.cast(inv_freq, dtype=x.dtype) # Match dtype early
281+
282282
# Use actual_seq_len for calculations
283-
position = tf.range(actual_seq_len, dtype=tf.float32)
283+
position = tf.range(actual_seq_len, dtype=x.dtype) # Match dtype
284+
284285
# Calculate sinusoid input using einsum or broadcasting
285-
# Einsum approach:
286-
sinusoid_inp = tf.einsum("i,j->ij", position, self.inv_freq)
287-
# Broadcasting approach (might be clearer):
288-
# sinusoid_inp = tf.expand_dims(position, axis=-1) * tf.expand_dims(self.inv_freq, axis=0)
286+
# Einsum approach: Ensure correct dimensions [seq_len, dim/2]
287+
sinusoid_inp = tf.einsum("i,j->ij", position, inv_freq)
289288

290289
# Calculate sin and cos based on the actual sequence length
291290
sin = tf.sin(sinusoid_inp)
292291
cos = tf.cos(sinusoid_inp)
293292

294293
# Repeat sin/cos for interleaving: [a, b] -> [a, a, b, b]
295-
# Original code used repeat then reshape, which might be slightly different
296-
# from direct interleaving depending on interpretation. Let's stick to the
297-
# original logic's apparent intent which leads to pairing.
298-
# We need shape [actual_seq_len, dim]
299-
# sin/cos currently [actual_seq_len, dim/2]
300-
sin = tf.repeat(sin, 2, axis=-1) # Repeat along the last dimension
301-
cos = tf.repeat(cos, 2, axis=-1) # Repeat along the last dimension
294+
# Result needs shape [actual_seq_len, dim]
295+
sin = tf.repeat(sin, 2, axis=-1)
296+
cos = tf.repeat(cos, 2, axis=-1)
302297

303298
# Expand dims for batch and tile
304299
# Output shape needs to be [batch_size, actual_seq_len, dim]
305-
sin = tf.expand_dims(sin, axis=0) # Shape [1, actual_seq_len, dim]
306-
cos = tf.expand_dims(cos, axis=0) # Shape [1, actual_seq_len, dim]
300+
# Add batch dimension: [1, actual_seq_len, dim]
301+
sin = tf.expand_dims(sin, axis=0)
302+
cos = tf.expand_dims(cos, axis=0)
307303

308-
# Tile to match the batch size
304+
# Tile to match the batch size: [batch_size, actual_seq_len, dim]
309305
sin = tf.tile(sin, [batch_size, 1, 1])
310306
cos = tf.tile(cos, [batch_size, 1, 1])
311307

312-
# Ensure dtype matches input tensor x
313-
sin = tf.cast(sin, x.dtype)
314-
cos = tf.cast(cos, x.dtype)
308+
# inv_freq and position were already created in x.dtype, so sin/cos inherit that dtype
309+
# sin = tf.cast(sin, x.dtype) # Already done via calculation chain
310+
# cos = tf.cast(cos, x.dtype) # Already done via calculation chain
315311

316312
# Return sin and cos needed by InterleavedRoPE
317313
return sin, cos
@@ -332,6 +328,7 @@ def from_config(cls, config):
332328

333329

334330

331+
335332
def split_alternate(x):
336333
shape = tf.shape(x)
337334
x = tf.reshape(x, [shape[0], shape[1], shape[2] // 2, 2])
@@ -357,17 +354,37 @@ def apply_rotary_pos_emb(x, sin, cos):
357354
class InterleavedRoPE(tf.keras.layers.Layer):
358355
def __init__(self, dim, max_seq_len=1024, **kwargs):
359356
super().__init__(**kwargs)
357+
if dim % 2 != 0:
358+
raise ValueError(f"Embedding dimension `dim` ({dim}) must be even for InterleavedRoPE.")
360359
self.dim = dim
361360
self.max_seq_len = max_seq_len
362-
self.rotary_emb = RotaryEmbedding(dim, max_seq_len)
361+
# Instantiate the RotaryEmbedding layer
362+
# Ensure the name is consistent if needed for saving/loading
363+
self.rotary_emb = RotaryEmbedding(dim, max_seq_len, name="rotary_embedding")
363364

364365
def call(self, x):
365-
batch_size = tf.shape(x)[0]
366-
seq_len = tf.shape(x)[1]
367-
368-
sin, cos = self.rotary_emb(x, seq_len)
369-
x = apply_rotary_pos_emb(x, sin, cos)
370-
return x
366+
# Get sin and cos from the RotaryEmbedding layer's call method
367+
# *** Pass only 'x'. RotaryEmbedding calculates seq_len internally. ***
368+
sin, cos = self.rotary_emb(x)
369+
370+
# Apply the positional embeddings
371+
x_embedded = apply_rotary_pos_emb(x, sin, cos)
372+
return x_embedded
373+
374+
def get_config(self):
375+
config = super().get_config()
376+
config.update({
377+
"dim": self.dim,
378+
"max_seq_len": self.max_seq_len,
379+
})
380+
# The nested RotaryEmbedding is not stored in the config; __init__ recreates it from dim and max_seq_len
381+
return config
382+
383+
@classmethod
384+
def from_config(cls, config):
385+
# cls(**config) re-runs __init__, which recreates the nested RotaryEmbedding
386+
return cls(**config)
387+
371388

372389

373390

0 commit comments

Comments
 (0)