Skip to content

Commit 39541d7

Browse files
author
Fampai
committed
Fixes a race condition in training when the VAE is unloaded:
set_current_image can attempt to use the VAE after it has been offloaded to the CPU during training.
1 parent f2b6970 commit 39541d7

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

modules/hypernetworks/hypernetwork.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
390390
with torch.autocast("cuda"):
391391
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
392392

393+
old_parallel_processing_allowed = shared.parallel_processing_allowed
394+
393395
if unload:
396+
shared.parallel_processing_allowed = False
394397
shared.sd_model.cond_stage_model.to(devices.cpu)
395398
shared.sd_model.first_stage_model.to(devices.cpu)
396399

@@ -531,6 +534,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
531534

532535
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
533536
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
537+
shared.parallel_processing_allowed = old_parallel_processing_allowed
534538

535539
return hypernetwork, filename
536540

modules/textual_inversion/textual_inversion.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,11 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
273273
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
274274
with torch.autocast("cuda"):
275275
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
276+
277+
old_parallel_processing_allowed = shared.parallel_processing_allowed
278+
276279
if unload:
280+
shared.parallel_processing_allowed = False
277281
shared.sd_model.first_stage_model.to(devices.cpu)
278282

279283
embedding.vec.requires_grad = True
@@ -410,6 +414,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
410414
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
411415
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
412416
shared.sd_model.first_stage_model.to(devices.device)
417+
shared.parallel_processing_allowed = old_parallel_processing_allowed
413418

414419
return embedding, filename
415420

0 commit comments

Comments (0)