Skip to content

Commit 9a923ca

Browse files
Update llm_utils.py
Fix AI generated mistakes ...
1 parent de444e4 commit 9a923ca

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

cerebrosllmutils/llm_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ def reset_state(self):
311311
self.count.assign(0.0)
312312

313313

314-
@keras.saving.register_keras_serializable(package='cerebrosllmutils', name='CerebrosNotGPTConfig')
314+
@tf.keras.saving.register_keras_serializable(package='cerebrosllmutils', name='CerebrosNotGPTConfig')
315315
class CerebrosNotGPTConfig:
316316
def __init__(self, max_sequence_length=1536, padding_token=None):
317317
self.max_sequence_length = max_sequence_length
@@ -327,16 +327,16 @@ def get_config(self):
327327
def from_config(cls, config):
328328
return cls(**config)
329329

330-
@keras.saving.register_keras_serializable(package='cerebrosllmutils', name='CerebrosNotGPT')
331-
class CerebrosNotGPT(keras.Model):
330+
@tf.keras.saving.register_keras_serializable(package='cerebrosllmutils', name='CerebrosNotGPT')
331+
class CerebrosNotGPT(tf.keras.Model):
332332
def __init__(self, config, **kwargs):
333333
super().__init__(**kwargs)
334334
self.config = config
335335
self.max_sequence_length = config.max_sequence_length
336336
self.padding_token = config.padding_token
337337

338338
# This `self.model` attribute is the key. Keras automatically tracks
339-
# and serializes any `keras.Layer` or `keras.Model` assigned as an attribute.
339+
# and serializes any `tf.keras.Layer` or `tf.keras.Model` assigned as an attribute.
340340
# It is set during the initial object creation from your functional model.
341341
# During deserialization, Keras will handle the restoration of this nested model.
342342
# Do not manually create or deserialize it in get_config/from_config.

0 commit comments

Comments
 (0)