diff --git a/extensions_built_in/sd_trainer/DiffusionTrainer.py b/extensions_built_in/sd_trainer/DiffusionTrainer.py
index f39611b13..ae4a28c54 100644
--- a/extensions_built_in/sd_trainer/DiffusionTrainer.py
+++ b/extensions_built_in/sd_trainer/DiffusionTrainer.py
@@ -30,6 +30,7 @@ def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
 
         if self.is_ui_trainer:
             self.is_stopping = False
+            self._is_saving = False
             # Create a thread pool for database operations
             self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
             # Track all async tasks
@@ -140,7 +141,19 @@ def _check_return_to_queue():
     def maybe_stop(self):
         if not self.is_ui_trainer:
             return
+        if self._is_saving or self.is_stopping:
+            return
         if self.should_stop():
+            # Trigger an immediate save of the LoRA weights and optimizer.pt.
+            # This records the current step in the saved metadata, so the
+            # user can resume from this exact point later.
+            print(f"\n[Stop Signal Detected] Saving emergency checkpoint at step {self.step_num}...")
+            self._is_saving = True
+            try:
+                self.save(self.step_num)
+            finally:
+                self._is_saving = False
+
             self._run_async_operation(
                 self._update_status("stopped", "Job stopped"))
             self.is_stopping = True
@@ -308,8 +321,10 @@ def sample(self, step=None, is_first=False):
         self.update_status("running", "Training")
 
     def save(self, step=None):
-        self.maybe_stop()
+        if not self._is_saving:
+            self.maybe_stop()
         self.update_status("running", "Saving model")
         super().save(step)
-        self.maybe_stop()
-        self.update_status("running", "Training")
+        if not self._is_saving:
+            self.maybe_stop()
+            self.update_status("running", "Training")
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index e55525aa1..a76070ff5 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -684,12 +684,23 @@ def save(self, step=None):
         try:
             filename = f'optimizer.pt'
             file_path = os.path.join(self.save_root, filename)
+            backup_optimizer_path = os.path.join(self.save_root, 'optimizer_prev.pt')
+
+            # If optimizer.pt already exists, rename it to optimizer_prev.pt
+            if os.path.exists(file_path):
+                # If an old backup also exists, delete it first (os.rename fails on Windows when the target exists)
+                if os.path.exists(backup_optimizer_path):
+                    os.remove(backup_optimizer_path)
+
+                os.rename(file_path, backup_optimizer_path)
+                print_acc("Existing optimizer.pt moved to optimizer_prev.pt")
+
             try:
                 state_dict = unwrap_model(self.optimizer).state_dict()
             except Exception as e:
                 state_dict = self.optimizer.state_dict()
             torch.save(state_dict, file_path)
-            print_acc(f"Saved optimizer to {file_path}")
+            print_acc(f"Saved latest optimizer to {file_path}")
         except Exception as e:
             print_acc(e)
             print_acc("Could not save optimizer")
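
The backup rotation in the second hunk guards against a crash or stop signal landing mid-torch.save, which would leave optimizer.pt half-written. Below is a minimal standalone sketch of that rotate-then-save pattern together with a resume-side loader that falls back to the backup; save_with_backup and load_with_fallback are hypothetical helper names used for illustration only and are not part of this patch.

import os
import torch

def save_with_backup(state_dict, save_root):
    # Rotate the previous checkpoint aside before writing, so an
    # interrupted torch.save leaves optimizer_prev.pt intact.
    file_path = os.path.join(save_root, "optimizer.pt")
    backup_path = os.path.join(save_root, "optimizer_prev.pt")
    if os.path.exists(file_path):
        if os.path.exists(backup_path):
            # os.rename fails on Windows when the target exists
            os.remove(backup_path)
        os.rename(file_path, backup_path)
    torch.save(state_dict, file_path)

def load_with_fallback(save_root, device="cpu"):
    # Prefer the latest checkpoint; fall back to the backup if the
    # latest file is missing or half-written.
    for name in ("optimizer.pt", "optimizer_prev.pt"):
        path = os.path.join(save_root, name)
        if os.path.exists(path):
            try:
                return torch.load(path, map_location=device)
            except Exception as e:
                print(f"Could not load {name}: {e}")
    raise FileNotFoundError(f"No usable optimizer checkpoint in {save_root}")

Note that os.replace would collapse the remove-and-rename into a single call that overwrites the target on every platform; the explicit two-step in the patch has the same effect and keeps each failure mode easy to log separately.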