@@ -255,34 +255,81 @@ class RotaryEmbedding(tf.keras.layers.Layer):
def __init__(self, dim, max_seq_len=1024, temperature=10000.0, **kwargs):
    """Configure the rotary embedding layer.

    Args:
        dim: Size of the last axis of the generated sin/cos tables. Must be
            even, because each frequency is emitted twice.
        max_seq_len: Stored on the layer and written to its config.
        temperature: Base of the power used to derive the inverse
            frequencies.
        **kwargs: Forwarded unchanged to the base layer constructor.

    Raises:
        ValueError: If ``dim`` is odd.
    """
    # Fail at construction time rather than on the first forward pass, so a
    # misconfigured layer is reported where it is created.
    if dim % 2 != 0:
        raise ValueError(
            f"Embedding dimension `dim` ({dim}) must be even for RotaryEmbedding."
        )
    super().__init__(**kwargs)
    self.dim = dim
    self.max_seq_len = max_seq_len
    self.temperature = temperature
260261
def build(self, input_shape):
    """Precompute the constant inverse-frequency vector.

    The layer has no trainable weights; the only state created here is
    ``self.inv_freq``, a float32 vector with one entry per even index in
    ``[0, dim)``.

    Args:
        input_shape: Passed through to the base class ``build``.

    Raises:
        ValueError: If ``self.dim`` is odd.
    """
    if self.dim % 2 != 0:
        raise ValueError(f"Embedding dimension `dim` ({self.dim} ) must be even for RotaryEmbedding.")

    # Even indices 0, 2, ..., dim - 2 scaled by dim give the exponents.
    even_indices = tf.range(0, self.dim, 2, dtype=tf.float32)
    exponents = even_indices / self.dim
    self.inv_freq = 1.0 / (self.temperature ** exponents)

    super().build(input_shape)
275+
def call(self, x, seq_len=None):
    """Return rotary sine and cosine tables tiled to the input batch.

    Args:
        x: Input tensor. Axis 0 is used as the batch size and, when
            ``seq_len`` is not given, axis 1 as the sequence length. Its
            dtype determines the dtype of the returned tables.
        seq_len: Optional number of positions to generate. When ``None``
            the length is read from ``x``; otherwise positions
            ``0 .. seq_len - 1`` are produced regardless of ``x``'s length.

    Returns:
        A ``(sin, cos)`` pair, each of shape ``[batch_size, seq_len, dim]``
        with dtype ``x.dtype``.
    """
    input_shape = tf.shape(x)
    batch_size = input_shape[0]
    # Honour an explicit seq_len; it is part of the public signature and
    # must not be silently ignored.
    if seq_len is None:
        seq_len = input_shape[1]

    # Angles are computed in float32 and only cast to x.dtype at the end.
    positions = tf.range(seq_len, dtype=tf.float32)
    angles = tf.einsum("i,j->ij", positions, self.inv_freq)

    # Emit every frequency twice in adjacent slots: [a, b] -> [a, a, b, b],
    # which widens the last axis from dim / 2 to dim.
    sin = tf.repeat(tf.sin(angles), 2, axis=-1)
    cos = tf.repeat(tf.cos(angles), 2, axis=-1)

    # Cast before tiling so the cast touches the [1, seq_len, dim] table
    # once instead of every batch copy.
    sin = tf.cast(tf.expand_dims(sin, axis=0), x.dtype)
    cos = tf.cast(tf.expand_dims(cos, axis=0), x.dtype)

    multiples = [batch_size, 1, 1]
    return tf.tile(sin, multiples), tf.tile(cos, multiples)
285318
def get_config(self):
    """Return the base layer config extended with this layer's arguments.

    Returns:
        A dict containing the base class config plus the ``dim``,
        ``max_seq_len`` and ``temperature`` values stored on the instance.
    """
    base_config = super().get_config()
    layer_config = {
        "dim": self.dim,
        "max_seq_len": self.max_seq_len,
        "temperature": self.temperature,
    }
    return {**base_config, **layer_config}
327+
@classmethod
def from_config(cls, config):
    """Instantiate the class from a config mapping.

    Args:
        config: Mapping whose entries are passed to the constructor as
            keyword arguments.

    Returns:
        A new instance of ``cls``.
    """
    layer = cls(**config)
    return layer
331+
332+
286333
287334
288335def split_alternate (x ):
0 commit comments