@@ -208,6 +208,38 @@ def get_config(self):
     def from_config(cls, config):
         return cls(max_seq_length=config['max_seq_length'])
 
+
+class RotaryPositionEmbedding(tf.keras.layers.Layer):
+    def __init__(self, max_seq_length, d_model, **kwargs):
+        super().__init__(**kwargs)
+        self.max_seq_length = max_seq_length
+        self.d_model = d_model
+        assert d_model % 2 == 0, "d_model must be even"
+
+        # Precompute the rotation tables once: geometrically spaced inverse
+        # frequencies, one per feature pair.
+        inv_freq = 1.0 / (10000 ** (tf.range(0, d_model, 2, dtype=tf.float32) / d_model))
+        positions = tf.range(max_seq_length, dtype=tf.float32)
+        # Outer product: angle[pos, k] = pos * inv_freq[k]
+        sinusoid = tf.einsum('i,j->ij', positions, inv_freq)
+
+        self.sin = tf.sin(sinusoid)
+        self.cos = tf.cos(sinusoid)
+
+    def call(self, x):
+        batch_size = tf.shape(x)[0]
+        seq_len = tf.shape(x)[1]
+
+        # Split the feature dimension into d_model // 2 pairs.
+        x = tf.reshape(x, [batch_size, seq_len, self.d_model // 2, 2])
+
+        # Rotate each pair by its position-dependent angle:
+        # (x0, x1) -> (x0*cos - x1*sin, x0*sin + x1*cos)
+        x_rot = tf.stack([
+            x[..., 0] * self.cos[:seq_len] - x[..., 1] * self.sin[:seq_len],
+            x[..., 0] * self.sin[:seq_len] + x[..., 1] * self.cos[:seq_len]
+        ], axis=-1)
+
+        return tf.reshape(x_rot, [batch_size, seq_len, self.d_model])
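+
+    # Serialization hook in the same style as the surrounding layers; a
+    # minimal sketch mirroring the get_config/from_config pattern above.
+    def get_config(self):
+        config = super().get_config()
+        config.update({
+            'max_seq_length': self.max_seq_length,
+            'd_model': self.d_model,
+        })
+        return config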
+
+
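+# Quick shape sanity check (hypothetical values, illustrative only):
+#   rope = RotaryPositionEmbedding(max_seq_length=8, d_model=4)
+#   rope(tf.zeros([1, 8, 4])).shape  # -> (1, 8, 4); rotation preserves shape
+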
 # GPT2 configurables
 
 # Optimal for accuracy thus far:
@@ -221,17 +253,19 @@ def from_config(cls, config):
 
 # On larger hardware, this could probably be increased considerably and
 # probably would improve performance ...
-EMBEDDING_DIM = 23  # Define EMBEDDING_DIM here, to match your embedding layer.
+EMBEDDING_N = 12  # Half of the embedding width.
+EMBEDDING_DIM = EMBEDDING_N * 2  # Define EMBEDDING_DIM here, to match your embedding layer.
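+# Doubling from EMBEDDING_N keeps EMBEDDING_DIM even, which the
+# RotaryPositionEmbedding assert above requires.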
 
 embedded = tf.keras.layers.Embedding(
     input_dim=VOCABULARY_SIZE,
     output_dim=EMBEDDING_DIM,
     input_length=max_seq_length,
     mask_zero=True)(tokens)
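+# NOTE (assumption): `input_length` is honored by tf.keras in TF 2.x but was
+# removed in Keras 3, so it may need to be dropped on newer versions.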
 
-position_embedding = PositionEmbedding(
-    sequence_length=max_seq_length,
-    initializer="uniform",
+position_embedding = RotaryPositionEmbedding(
+    max_seq_length=max_seq_length,
+    d_model=EMBEDDING_DIM,
 )(embedded)
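+# Unlike the additive PositionEmbedding it replaces, the rotary layer returns
+# the already-rotated embeddings, so no separate position add is needed.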
 
 # As an FYI, we tried an add layer both with and without