Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions extensions_built_in/sd_trainer/DiffusionTrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):

if self.is_ui_trainer:
self.is_stopping = False
self._is_saving = False
# Create a thread pool for database operations
self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
# Track all async tasks
Expand Down Expand Up @@ -140,7 +141,19 @@ def _check_return_to_queue():
def maybe_stop(self):
    """Poll for a user-requested stop; on stop, checkpoint and mark stopping.

    Only active for UI-managed trainers (no-op otherwise). Re-entrancy is
    controlled by two flags: `_is_saving` prevents the emergency save below
    from re-entering this method via save() -> maybe_stop(), and
    `is_stopping` ensures the stop sequence runs at most once.
    """
    if not self.is_ui_trainer:
        return
    # Guard clauses: bail out while an emergency save is already in
    # progress, or once the stop sequence has already been initiated.
    if self._is_saving or self.is_stopping:
        return
    if self.should_stop():
        # Trigger an immediate save of LoRA weights and optimizer.pt.
        # This ensures the current step is recorded in the Metadata,
        # allowing the user to resume from this exact point later.
        print(f"\n[Stop Signal Detected] Saving emergency checkpoint at step {self.step_num}...")
        self._is_saving = True
        try:
            self.save(self.step_num)
        finally:
            # Always clear the flag, even if the save raises, so later
            # maybe_stop() calls are not permanently suppressed.
            self._is_saving = False

        # Report the stop asynchronously (status update runs on the
        # trainer's async machinery), then flag the stop for the loop.
        self._run_async_operation(
            self._update_status("stopped", "Job stopped"))
        self.is_stopping = True
Expand Down Expand Up @@ -308,8 +321,10 @@ def sample(self, step=None, is_first=False):
self.update_status("running", "Training")

def save(self, step=None):
self.maybe_stop()
if not self._is_saving:
self.maybe_stop()
self.update_status("running", "Saving model")
super().save(step)
self.maybe_stop()
self.update_status("running", "Training")
if not self._is_saving:
self.maybe_stop()
self.update_status("running", "Training")
13 changes: 12 additions & 1 deletion jobs/process/BaseSDTrainProcess.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,12 +684,23 @@ def save(self, step=None):
try:
filename = f'optimizer.pt'
file_path = os.path.join(self.save_root, filename)
backup_optimizer_path = os.path.join(self.save_root, 'optimizer_prev.pt')

# If optimizer.pt already exists, rename it to optimizer_prev.pt
if os.path.exists(file_path):
# If an old backup also exists, delete it first (to ensure a successful rename)
if os.path.exists(backup_optimizer_path):
os.remove(backup_optimizer_path)

os.rename(file_path, backup_optimizer_path)
print_acc(f"Existing optimizer.pt moved to optimizer_prev.pt")

try:
state_dict = unwrap_model(self.optimizer).state_dict()
except Exception as e:
state_dict = self.optimizer.state_dict()
torch.save(state_dict, file_path)
print_acc(f"Saved optimizer to {file_path}")
print_acc(f"Saved latest optimizer to {file_path}")
except Exception as e:
print_acc(e)
print_acc("Could not save optimizer")
Expand Down