Skip to content

Commit 5267414

Browse files
Merge pull request #4271 from MarkovInequality/racecond_fix
Fixes #4137, caused by a race condition in training when the VAE is unloaded
2 parents 5cd5a67 + c9a2cfd commit 5267414

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

modules/hypernetworks/hypernetwork.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
433433

434434
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
435435

436+
old_parallel_processing_allowed = shared.parallel_processing_allowed
437+
436438
if unload:
439+
shared.parallel_processing_allowed = False
437440
shared.sd_model.cond_stage_model.to(devices.cpu)
438441
shared.sd_model.first_stage_model.to(devices.cpu)
439442

@@ -612,10 +615,12 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
612615
if shared.opts.save_optimizer_state:
613616
hypernetwork.optimizer_state_dict = optimizer.state_dict()
614617
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
618+
615619
del optimizer
616620
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
617621
shared.sd_model.cond_stage_model.to(devices.device)
618622
shared.sd_model.first_stage_model.to(devices.device)
623+
shared.parallel_processing_allowed = old_parallel_processing_allowed
619624

620625
return hypernetwork, filename
621626

modules/textual_inversion/textual_inversion.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
269269

270270
# dataset loading may take a while, so input validations and early returns should be done before this
271271
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
272+
old_parallel_processing_allowed = shared.parallel_processing_allowed
272273

273274
pin_memory = shared.opts.pin_memory
274275

@@ -279,6 +280,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
279280
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
280281

281282
if unload:
283+
shared.parallel_processing_allowed = False
282284
shared.sd_model.first_stage_model.to(devices.cpu)
283285

284286
embedding.vec.requires_grad = True
@@ -450,6 +452,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
450452
pbar.leave = False
451453
pbar.close()
452454
shared.sd_model.first_stage_model.to(devices.device)
455+
shared.parallel_processing_allowed = old_parallel_processing_allowed
453456

454457
return embedding, filename
455458

0 commit comments

Comments (0)