Skip to content

Commit cc1e4fa

Browse files
edwardzhou130recml authors
authored andcommitted
Support loading ckpt for the eval job.
PiperOrigin-RevId: 752480017
1 parent f5e68e2 commit cc1e4fa

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

recml/core/training/keras_trainer.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import abc
1919
from collections.abc import Mapping
20+
import dataclasses
2021
import gc
2122
import os
2223
import time
@@ -232,6 +233,27 @@ def evaluate(self, task: KerasTask) -> core.Logs:
232233
model = task.create_model_for_eval(
233234
**self._maybe_get_model_kws(task, dataset)
234235
)
236+
237+
if keras.backend.backend() == "jax":
238+
[tb_cbk] = [
239+
cbk
240+
for cbk in self._eval_callbacks
241+
if isinstance(cbk, keras_utils.EpochSummaryCallback)
242+
]
243+
epoch_start_time = time.time()
244+
history = model.evaluate(
245+
dataset,
246+
steps=self._steps_per_eval,
247+
callbacks=self._eval_callbacks,
248+
return_dict=True,
249+
)
250+
epoch_dt = time.time() - epoch_start_time
251+
steps_per_second = self._steps_per_eval / epoch_dt
252+
val_logs = {"val_" + k: v for k, v in history.items()}
253+
val_logs["val_steps_per_second"] = steps_per_second
254+
tb_cbk.on_epoch_end(0, val_logs)
255+
return history
256+
235257
return model.evaluate(
236258
dataset,
237259
steps=self._steps_per_eval,

0 commit comments

Comments
 (0)