Skip to content

Commit 4e9fa19

Browse files
committed
Merge branch 'main' of github.com:AshishKumar4/FlaxDiff
2 parents 41677f2 + b1a158b commit 4e9fa19

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

flaxdiff/trainer/simple_trainer.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,19 @@ def fit(self, data, train_steps_per_epoch, epochs, train_step_args={}, val_steps
642642
self.rngstate = rng_state
643643
total_time = end_time - start_time
644644
avg_time_per_step = total_time / train_steps_per_epoch
645+
646+
if val_steps_per_epoch > 0:
647+
print(f"Validation started for process index {process_index}")
648+
# Validation step
649+
self.validation_loop(
650+
train_state,
651+
val_step,
652+
val_ds,
653+
val_steps_per_epoch,
654+
current_step,
655+
)
656+
print(colored(f"Validation done on process index {process_index}", PROCESS_COLOR_MAP[process_index]))
657+
645658
avg_loss = epoch_loss / train_steps_per_epoch
646659
if avg_loss < self.best_loss:
647660
self.best_loss = avg_loss
@@ -659,17 +672,6 @@ def fit(self, data, train_steps_per_epoch, epochs, train_step_args={}, val_steps
659672
}, step=current_step)
660673
print(colored(f"\n\tEpoch {current_epoch} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss}", 'green'))
661674

662-
if val_steps_per_epoch > 0:
663-
print(f"Validation started for process index {process_index}")
664-
# Validation step
665-
self.validation_loop(
666-
train_state,
667-
val_step,
668-
val_ds,
669-
val_steps_per_epoch,
670-
current_step,
671-
)
672-
print(colored(f"Validation done on process index {process_index}", PROCESS_COLOR_MAP[process_index]))
673675

674676
self.save(epochs)#
675677
return self.state

0 commit comments

Comments
 (0)