@@ -211,7 +211,7 @@ def from_config(cls, config):
211211
212212
213213
214- class RotaryEmbedding (tf .keras .layers .Layer ):
214+ def RotaryEmbedding (tf .keras .layers .Layer ):
215215 def __init__ (self , dim , max_seq_len = 1024 , temperature = 10000.0 , ** kwargs ):
216216 super ().__init__ (** kwargs )
217217 self .dim = dim
@@ -233,11 +233,18 @@ def call(self, x, seq_len=None):
233233 seq_len = tf .shape (x )[1 ] if seq_len is None else seq_len
234234 sin = self .sin_cache [:seq_len ]
235235 cos = self .cos_cache [:seq_len ]
236- sin = tf .cast (tf .repeat (sin [..., tf .newaxis ], self .dim // 2 , axis = - 1 ), x .dtype )
237- cos = tf .cast (tf .repeat (cos [..., tf .newaxis ], self .dim // 2 , axis = - 1 ), x .dtype )
236+ sin = tf .cast (tf .repeat (sin [..., tf .newaxis ], 2 , axis = - 1 ), x .dtype )
237+ cos = tf .cast (tf .repeat (cos [..., tf .newaxis ], 2 , axis = - 1 ), x .dtype )
238+ sin = tf .reshape (sin , [seq_len , self .dim ])
239+ cos = tf .reshape (cos , [seq_len , self .dim ])
240+ sin = tf .expand_dims (sin , axis = 0 )
241+ cos = tf .expand_dims (cos , axis = 0 )
242+ sin = tf .tile (sin , [batch_size , 1 , 1 ])
243+ cos = tf .tile (cos , [batch_size , 1 , 1 ])
238244 return sin , cos
239245
240246
247+
241248def split_alternate (x ):
242249 shape = tf .shape (x )
243250 x = tf .reshape (x , [shape [0 ], shape [1 ], shape [2 ] // 2 , 2 ])
0 commit comments