@@ -251,67 +251,63 @@ def from_config(cls, config):



+# --- Updated RotaryEmbedding ---
 class RotaryEmbedding(tf.keras.layers.Layer):
     def __init__(self, dim, max_seq_len=1024, temperature=10000.0, **kwargs):
         super().__init__(**kwargs)
         self.dim = dim
-        self.max_seq_len = max_seq_len  # Still useful for potential pre-allocation if needed, but not for caching tensors
+        # Ensure dim is even right at initialization
+        if self.dim % 2 != 0:
+            raise ValueError(f"Embedding dimension `dim` ({self.dim}) must be even for RotaryEmbedding.")
+        self.max_seq_len = max_seq_len
         self.temperature = temperature
-        # No caching in __init__ or build anymore
+        # *** No calculation or storage of inv_freq here or in build ***

     def build(self, input_shape):
-        # Build is primarily for creating weights. We don't have trainable weights here.
-        # We can calculate inv_freq here if desired, as it doesn't depend on input shape directly
-        # and is constant. However, calculating it in call() is also fine.
-        # Let's calculate it once here to avoid recomputing constants.
-        # Ensure dim is even
-        if self.dim % 2 != 0:
-            raise ValueError(f"Embedding dimension `dim` ({self.dim}) must be even for RotaryEmbedding.")
-
-        inv_freq_base = tf.range(0, self.dim, 2, dtype=tf.float32)  # Corrected range for pair dimension
-        inv_freq = 1.0 / (self.temperature ** (inv_freq_base / self.dim))  # Corrected calculation
-        self.inv_freq = inv_freq  # Store the constant factor
+        # Build should primarily be for creating trainable weights, which we don't have.
+        # Call super().build() for Keras compatibility.
         super().build(input_shape)

-    def call(self, x, seq_len=None):
+    def call(self, x):  # Removed seq_len argument; calculate it from x
         shape = tf.shape(x)
         batch_size = shape[0]
-        # Determine sequence length dynamically from input tensor 'x'
         actual_seq_len = shape[1]

+        # *** Calculate inv_freq inside call ***
+        inv_freq_base = tf.range(0, self.dim, 2, dtype=tf.float32)
+        inv_freq = 1.0 / (self.temperature ** (inv_freq_base / self.dim))
+        # inv_freq has shape [dim/2]
+        inv_freq = tf.cast(inv_freq, dtype=x.dtype)  # Match dtype early
+
         # Use actual_seq_len for calculations
-        position = tf.range(actual_seq_len, dtype=tf.float32)
+        position = tf.range(actual_seq_len, dtype=x.dtype)  # Match dtype
+
         # Calculate sinusoid input using einsum or broadcasting
-        # Einsum approach:
-        sinusoid_inp = tf.einsum("i,j->ij", position, self.inv_freq)
-        # Broadcasting approach (might be clearer):
-        # sinusoid_inp = tf.expand_dims(position, axis=-1) * tf.expand_dims(self.inv_freq, axis=0)
+        # Einsum approach: result has shape [seq_len, dim/2]
+        sinusoid_inp = tf.einsum("i,j->ij", position, inv_freq)

         # Calculate sin and cos based on the actual sequence length
         sin = tf.sin(sinusoid_inp)
         cos = tf.cos(sinusoid_inp)

         # Repeat sin/cos for interleaving: [a, b] -> [a, a, b, b]
-        # Original code used repeat then reshape, which might be slightly different
-        # from direct interleaving depending on interpretation. Let's stick to the
-        # original logic's apparent intent, which leads to pairing.
-        # We need shape [actual_seq_len, dim]
-        # sin/cos currently [actual_seq_len, dim/2]
-        sin = tf.repeat(sin, 2, axis=-1)  # Repeat along the last dimension
-        cos = tf.repeat(cos, 2, axis=-1)  # Repeat along the last dimension
+        # Result needs shape [actual_seq_len, dim]
+        sin = tf.repeat(sin, 2, axis=-1)
+        cos = tf.repeat(cos, 2, axis=-1)

         # Expand dims for batch and tile
         # Output shape needs to be [batch_size, actual_seq_len, dim]
-        sin = tf.expand_dims(sin, axis=0)  # Shape [1, actual_seq_len, dim]
-        cos = tf.expand_dims(cos, axis=0)  # Shape [1, actual_seq_len, dim]
+        # Add batch dimension: [1, actual_seq_len, dim]
+        sin = tf.expand_dims(sin, axis=0)
+        cos = tf.expand_dims(cos, axis=0)

-        # Tile to match the batch size
+        # Tile to match the batch size: [batch_size, actual_seq_len, dim]
         sin = tf.tile(sin, [batch_size, 1, 1])
         cos = tf.tile(cos, [batch_size, 1, 1])

-        # Ensure dtype matches input tensor x
-        sin = tf.cast(sin, x.dtype)
-        cos = tf.cast(cos, x.dtype)
+        # Casting to x.dtype was already done for inv_freq; sin/cos inherit it
+        # sin = tf.cast(sin, x.dtype)  # Already done via calculation chain
+        # cos = tf.cast(cos, x.dtype)  # Already done via calculation chain

         # Return sin and cos needed by InterleavedRoPE
         return sin, cos
@@ -332,6 +328,7 @@ def from_config(cls, config):



+
 def split_alternate(x):
     shape = tf.shape(x)
     x = tf.reshape(x, [shape[0], shape[1], shape[2] // 2, 2])
@@ -357,17 +354,37 @@ def apply_rotary_pos_emb(x, sin, cos):
 class InterleavedRoPE(tf.keras.layers.Layer):
     def __init__(self, dim, max_seq_len=1024, **kwargs):
         super().__init__(**kwargs)
+        if dim % 2 != 0:
+            raise ValueError(f"Embedding dimension `dim` ({dim}) must be even for InterleavedRoPE.")
         self.dim = dim
         self.max_seq_len = max_seq_len
-        self.rotary_emb = RotaryEmbedding(dim, max_seq_len)
+        # Instantiate the RotaryEmbedding layer
+        # Keep the name stable for saving/loading
+        self.rotary_emb = RotaryEmbedding(dim, max_seq_len, name="rotary_embedding")

     def call(self, x):
-        batch_size = tf.shape(x)[0]
-        seq_len = tf.shape(x)[1]
-
-        sin, cos = self.rotary_emb(x, seq_len)
-        x = apply_rotary_pos_emb(x, sin, cos)
-        return x
+        # Get sin and cos from the RotaryEmbedding layer's call method
+        # *** Pass only 'x'. RotaryEmbedding calculates seq_len internally. ***
+        sin, cos = self.rotary_emb(x)
+
+        # Apply the positional embeddings
+        x_embedded = apply_rotary_pos_emb(x, sin, cos)
+        return x_embedded
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({
+            "dim": self.dim,
+            "max_seq_len": self.max_seq_len,
+        })
+        # Keras handles nested layer serialization automatically
+        return config
+
+    @classmethod
+    def from_config(cls, config):
+        # Keras handles nested layer restoration automatically
+        return cls(**config)
+



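For context, a minimal smoke-test sketch of how the refactored layers could be exercised after this change. The shapes, values, and asserts below are illustrative and not part of the patch; the shape-preservation check assumes `apply_rotary_pos_emb` (defined earlier in this file) returns a tensor shaped like its input, as the diff's comments indicate.

import tensorflow as tf

# Illustrative only: run the updated layers on a dummy batch.
dim, seq_len, batch = 8, 16, 2
layer = InterleavedRoPE(dim=dim, max_seq_len=1024)

x = tf.random.normal([batch, seq_len, dim])
y = layer(x)  # sin/cos are now derived from x's dynamic shape inside RotaryEmbedding
assert y.shape == (batch, seq_len, dim)  # rotation should preserve the embedding shape

# Round-trip the wrapper's config to exercise the new get_config/from_config path.
restored = InterleavedRoPE.from_config(layer.get_config())
assert restored.dim == dim and restored.max_seq_len == 1024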