File tree Expand file tree Collapse file tree 1 file changed +22
-0
lines changed
Expand file tree Collapse file tree 1 file changed +22
-0
lines changed Original file line number Diff line number Diff line change 1717
1818import abc
1919from collections .abc import Mapping
20+ import dataclasses
2021import gc
2122import os
2223import time
@@ -232,6 +233,27 @@ def evaluate(self, task: KerasTask) -> core.Logs:
232233 model = task .create_model_for_eval (
233234 ** self ._maybe_get_model_kws (task , dataset )
234235 )
236+
237+ if keras .backend .backend () == "jax" :
238+ [tb_cbk ] = [
239+ cbk
240+ for cbk in self ._eval_callbacks
241+ if isinstance (cbk , keras_utils .EpochSummaryCallback )
242+ ]
243+ epoch_start_time = time .time ()
244+ history = model .evaluate (
245+ dataset ,
246+ steps = self ._steps_per_eval ,
247+ callbacks = self ._eval_callbacks ,
248+ return_dict = True ,
249+ )
250+ epoch_dt = time .time () - epoch_start_time
251+ steps_per_second = self ._steps_per_eval / epoch_dt
252+ val_logs = {"val_" + k : v for k , v in history .items ()}
253+ val_logs ["val_steps_per_second" ] = steps_per_second
254+ tb_cbk .on_epoch_end (0 , val_logs )
255+ return history
256+
235257 return model .evaluate (
236258 dataset ,
237259 steps = self ._steps_per_eval ,
You can’t perform that action at this time.
0 commit comments