@@ -211,7 +211,7 @@ def from_config(cls, config):
211211
212212
213213
214- class RotaryEmbedding (tf .keras .layers .Layer ):
214+ def RotaryEmbedding (tf .keras .layers .Layer ):
215215 def __init__ (self , dim , max_seq_len = 1024 , temperature = 10000.0 , ** kwargs ):
216216 super ().__init__ (** kwargs )
217217 self .dim = dim
@@ -225,15 +225,18 @@ def build(self, input_shape):
225225 sinusoid = tf .einsum ("i,j->ij" , position , inv_freq )
226226 sin = tf .sin (sinusoid )
227227 cos = tf .cos (sinusoid )
228- self .sin_cache = tf . concat ([ sin , sin ], axis = - 1 )
229- self .cos_cache = tf . concat ([ cos , cos ], axis = - 1 )
228+ self .sin_cache = sin
229+ self .cos_cache = cos
230230
231231 def call (self , x , seq_len = None ):
232232 batch_size = tf .shape (x )[0 ]
233233 seq_len = tf .shape (x )[1 ] if seq_len is None else seq_len
234234 sin = self .sin_cache [:seq_len ]
235235 cos = self .cos_cache [:seq_len ]
236- return tf .cast (sin , x .dtype ), tf .cast (cos , x .dtype )
236+ sin = tf .cast (tf .repeat (sin [..., tf .newaxis ], self .dim // 2 , axis = - 1 ), x .dtype )
237+ cos = tf .cast (tf .repeat (cos [..., tf .newaxis ], self .dim // 2 , axis = - 1 ), x .dtype )
238+ return sin , cos
239+
237240
238241def split_alternate (x ):
239242 shape = tf .shape (x )
0 commit comments