@@ -311,7 +311,7 @@ def reset_state(self):
311311 self .count .assign (0.0 )
312312
313313
314- @keras .saving .register_keras_serializable (package = 'cerebrosllmutils' , name = 'CerebrosNotGPTConfig' )
314+ @tf . keras .saving .register_keras_serializable (package = 'cerebrosllmutils' , name = 'CerebrosNotGPTConfig' )
315315class CerebrosNotGPTConfig :
316316 def __init__ (self , max_sequence_length = 1536 , padding_token = None ):
317317 self .max_sequence_length = max_sequence_length
@@ -327,16 +327,16 @@ def get_config(self):
327327 def from_config (cls , config ):
328328 return cls (** config )
329329
330- @keras .saving .register_keras_serializable (package = 'cerebrosllmutils' , name = 'CerebrosNotGPT' )
331- class CerebrosNotGPT (keras .Model ):
330+ @tf . keras .saving .register_keras_serializable (package = 'cerebrosllmutils' , name = 'CerebrosNotGPT' )
331+ class CerebrosNotGPT (tf . keras .Model ):
332332 def __init__ (self , config , ** kwargs ):
333333 super ().__init__ (** kwargs )
334334 self .config = config
335335 self .max_sequence_length = config .max_sequence_length
336336 self .padding_token = config .padding_token
337337
338338 # This `self.model` attribute is the key. Keras automatically tracks
339- # and serializes any `keras.Layer` or `keras.Model` assigned as an attribute.
339+ # and serializes any `tf. keras.Layer` or `tf. keras.Model` assigned as an attribute.
340340 # It is set during the initial object creation from your functional model.
341341 # During deserialization, Keras will handle the restoration of this nested model.
342342 # Do not manually create or deserialize it in get_config/from_config.
0 commit comments