Skip to content

Commit 25eab2d

Browse files
Update llm_utils.py
Try to fix config for serialization ...
1 parent 6a25a89 commit 25eab2d

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

cerebrosllmutils/llm_utils.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -345,15 +345,18 @@ def __init__(self, config, model_0=None, **kwargs):
345345
pass
346346

347347
def get_config(self):
348-
return {
349-
'config': self.config.get_config()
350-
# NO model reference here!
351-
}
348+
base_config = super().get_config()
349+
base_config.update({
350+
'config': self.config.get_config(),
351+
'model': tf.keras.utils.serialize_keras_object(self.model)
352+
})
353+
return base_config
352354

353355
@classmethod
354356
def from_config(cls, config):
355357
config_obj = CerebrosNotGPTConfig.from_config(config['config'])
356-
return cls(config=config_obj) # Keras will handle model restoration
358+
model_0 = tf.keras.utils.deserialize_keras_object(config['model'])
359+
return cls(config=config_obj, model_0=model_0)
357360

358361
def call(self, inputs):
359362
return self.model(inputs)

0 commit comments

Comments
 (0)