@@ -209,6 +209,7 @@ def from_config(cls, config):
209209 return cls (max_seq_length = config ['max_seq_length' ])
210210
211211
212+
212213class RotaryPositionEmbedding (tf .keras .layers .Layer ):
213214 def __init__ (self , max_seq_length , d_model , ** kwargs ):
214215 super ().__init__ (** kwargs )
@@ -218,32 +219,36 @@ def __init__(self, max_seq_length, d_model, **kwargs):
218219
219220 # Precompute rotation matrices
220221 inv_freq = 1.0 / (10000 ** (tf .range (0 , d_model , 2 , dtype = tf .float32 ) / d_model ))
222+ self .inv_freq = tf .cast (inv_freq , tf .float32 )
221223 positions = tf .range (max_seq_length , dtype = tf .float32 )
222- sinusoid = tf .einsum ('i,j->ij' , positions , inv_freq )
223-
224- self .sin = tf .sin (sinusoid )
225- self .cos = tf .cos (sinusoid )
224+ self .sin = tf .sin (tf .einsum ('i,j->ij' , positions , inv_freq ))
225+ self .cos = tf .cos (tf .einsum ('i,j->ij' , positions , inv_freq ))
226226
227227 def call (self , x ):
228228 batch_size = tf .shape (x )[0 ]
229229 seq_len = tf .shape (x )[1 ]
230230
231- # Split dimensions into pairs
232- x = tf .reshape (x , [batch_size , seq_len , self .d_model // 2 , 2 ])
231+ # Compute sine and cosine matrices for current sequence length
232+ sinusoid = tf .einsum ('i,j->ij' , tf .range (seq_len , dtype = tf .float32 ), self .inv_freq )
233+ current_sin = tf .sin (sinusoid )
234+ current_cos = tf .cos (sinusoid )
233235
234- # Apply rotation
235- x_rot = tf .stack ([
236- x [..., 0 ] * self .cos [:seq_len ] - x [..., 1 ] * self .sin [:seq_len ],
237- x [..., 0 ] * self .sin [:seq_len ] + x [..., 1 ] * self .cos [:seq_len ]
236+ # Split the feature axis into adjacent pairs and rotate each pair element-wise
237+ x = tf .reshape (x , [batch_size , seq_len , self .d_model // 2 , 2 ])
238+ rotated = tf .stack ([
239+ x [..., 0 ] * current_cos - x [..., 1 ] * current_sin ,
240+ x [..., 0 ] * current_sin + x [..., 1 ] * current_cos
238241 ], axis = - 1 )
239242
240- return tf .reshape (x_rot , [batch_size , seq_len , self .d_model ])
243+ # Merge the rotated pairs back into a single d_model feature axis
244+ return tf .reshape (rotated , [batch_size , seq_len , self .d_model ])
245+
246+
241247
242248
243249# GPT2 configurables
244250
245251# Optimal for accuracy thus far:
246- # max_seq_length = 900
247252max_seq_length = 1024
248253
249254inp = tf .keras .layers .Input (shape = (), dtype = tf .string )
0 commit comments