Skip to content

Commit 0eac01b

Browse files
merveenoyannateraw
authored andcommitted
Keras: Saving history in a JSON file (#861)
* added test and history saving * Update src/huggingface_hub/keras_mixin.py Co-authored-by: Nathan Raw <[email protected]> Co-authored-by: Nathan Raw <[email protected]>
1 parent a8b6f14 commit 0eac01b

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

src/huggingface_hub/keras_mixin.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,13 @@ def _extract_hyperparameters_from_keras(model):
3939
return hyperparameters
4040

4141

42-
def _parse_model_history(model):
42+
def _parse_model_history(model, save_directory):
4343
lines = None
4444
if model.history is not None:
4545
if model.history.history != {}:
46+
path = os.path.join(save_directory, "history.json")
47+
with open(path, "w", encoding="utf-8") as f:
48+
json.dump(model.history.history, f, indent=2, sort_keys=True)
4649
lines = []
4750
logs = model.history.history
4851
num_epochs = len(logs["loss"])
@@ -79,8 +82,8 @@ def _plot_network(model, save_directory):
7982
)
8083

8184

82-
def _write_metrics(model, model_card):
83-
lines = _parse_model_history(model)
85+
def _write_metrics(model, model_card, save_directory):
86+
lines = _parse_model_history(model, save_directory)
8487
if lines is not None:
8588
model_card += "\n| Epochs |"
8689

@@ -128,7 +131,7 @@ def _create_model_card(
128131
)
129132
model_card += "\n"
130133
model_card += "\n ## Training Metrics\n"
131-
model_card = _write_metrics(model, model_card)
134+
model_card = _write_metrics(model, model_card, repo_dir)
132135
if plot_model and os.path.exists(f"{repo_dir}/model.png"):
133136
model_card += "\n ## Model Plot\n"
134137
model_card += "\n<details>"

tests/test_keras_integration.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,8 @@ def test_save_pretrained_model_card_fit(self):
225225
self.assertIn("keras_metadata.pb", files)
226226
self.assertIn("model.png", files)
227227
self.assertIn("README.md", files)
228-
self.assertEqual(len(files), 6)
228+
self.assertIn("history.json", files)
229+
self.assertEqual(len(files), 7)
229230

230231
def test_save_pretrained_optimizer_state(self):
231232
REPO_NAME = repo_name("save")
@@ -490,4 +491,4 @@ def test_save_pretrained_fit(self):
490491

491492
self.assertIn("saved_model.pb", files)
492493
self.assertIn("keras_metadata.pb", files)
493-
self.assertEqual(len(files), 6)
494+
self.assertEqual(len(files), 7)

0 commit comments

Comments
 (0)