@@ -141,7 +141,7 @@ def prepare_data(
141141
142142
143143# --- Base Rotary Positional Embedding
144- @tf .keras .utils .register_keras_serializable ()
144+ @tf .keras .utils .register_keras_serializable (package = 'cerebrosllmutils' , name = 'RotaryEmbedding' )
145145class RotaryEmbedding (tf .keras .layers .Layer ):
146146 def __init__ (self , dim , max_seq_len = 1024 , temperature = 10000.0 , ** kwargs ):
147147 super ().__init__ (** kwargs )
@@ -218,7 +218,7 @@ def from_config(cls, config):
218218
219219# iRoPE helper functions
220220
221- @tf .keras .utils .register_keras_serializable ()
221+ @tf .keras .utils .register_keras_serializable (package = 'cerebrosllmutils' , name = 'split_alternate' )
222222def split_alternate (x ):
223223 shape = tf .shape (x )
224224 x = tf .reshape (x , [shape [0 ], shape [1 ], shape [2 ] // 2 , 2 ])
@@ -227,15 +227,15 @@ def split_alternate(x):
227227 return x
228228
229229
230- @tf .keras .utils .register_keras_serializable ()
230+ @tf .keras .utils .register_keras_serializable (package = 'cerebrosllmutils' , name = 'rotate_half' )
231231def rotate_half (x ):
232232 x = split_alternate (x )
233233 d = tf .shape (x )[- 1 ]
234234 rotated_x = tf .concat ([- x [..., d // 2 :], x [..., :d // 2 ]], axis = - 1 )
235235 return tf .reshape (rotated_x , tf .shape (x ))
236236
237237
238- @tf .keras .utils .register_keras_serializable ()
238+ @tf .keras .utils .register_keras_serializable (package = 'cerebrosllmutils' , name = 'apply_rotary_pos_emb' )
239239def apply_rotary_pos_emb (x , sin , cos ):
240240 cos = tf .reshape (cos , [tf .shape (cos )[0 ], tf .shape (cos )[1 ], - 1 ])
241241 sin = tf .reshape (sin , [tf .shape (sin )[0 ], tf .shape (sin )[1 ], - 1 ])
@@ -244,7 +244,7 @@ def apply_rotary_pos_emb(x, sin, cos):
244244
245245
246246# interleaved Rotary Postional Embedding (iRoPE)
247- @tf .keras .utils .register_keras_serializable ()
247+ @tf .keras .utils .register_keras_serializable (package = 'cerebrosllmutils' , name = 'InterleavedRoPE' )
248248class InterleavedRoPE (tf .keras .layers .Layer ):
249249 def __init__ (self , dim , max_seq_len = 1024 , ** kwargs ):
250250 super ().__init__ (** kwargs )
@@ -280,7 +280,7 @@ def from_config(cls, config):
280280 return cls (** config )
281281
282282
283- @tf .keras .utils .register_keras_serializable ()
283+ @tf .keras .utils .register_keras_serializable (package = 'cerebrosllmutils' , name = 'Perplexity' )
284284class Perplexity (tf .keras .metrics .Metric ):
285285 """
286286 Computes perplexity, defined as e^(categorical crossentropy).
@@ -310,7 +310,7 @@ def reset_state(self):
310310 self .total_crossentropy .assign (0.0 )
311311 self .count .assign (0.0 )
312312
313- @tf .keras .utils .register_keras_serializable ()
313+ @tf .keras .utils .register_keras_serializable (package = 'cerebrosllmutils' , name = 'CerebrosNotGPTConfig' )
314314class CerebrosNotGPTConfig :
315315 def __init__ (self , max_sequence_length = 1536 , padding_token = None ):
316316 self .max_sequence_length = max_sequence_length
@@ -328,7 +328,7 @@ def from_config(cls, config):
328328 return cls (** config ) # No model_0 to handle
329329
330330
331- @tf .keras .utils .register_keras_serializable ()
331+ @tf .keras .utils .register_keras_serializable (package = 'cerebrosllmutils' , name = 'CerebrosNotGPT' )
332332class CerebrosNotGPT (tf .keras .Model ):
333333 def __init__ (self , config , model_0 = None , ** kwargs ):
334334 super ().__init__ (** kwargs )
0 commit comments