Commit f071a1d

Merge pull request #4056 from MarkovInequality/TI_optimizations
Allow TI training using 6GB VRAM when xformers is available
2 parents: 0e5d239 + 890e68a
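In short: this commit cuts VRAM use during textual inversion (TI) training in two ways. First, the unload_models_when_training option now applies to TI as well: the VAE (first_stage_model) is parked on the CPU except when preview images are generated. Second, a new training_xattention_optimizations option lets cross-attention optimizations such as xformers stay active during training instead of being unconditionally undone.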

3 files changed: +17 −3

modules/shared.py

Lines changed: 2 additions & 1 deletion
@@ -288,11 +288,12 @@ def options_section(section_identifier, options_dict):
 }))

 options_templates.update(options_section(('training', "Training"), {
-    "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training hypernetwork. Saves VRAM."),
+    "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
     "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
     "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
     "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
     "training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
+    "training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
 }))

 options_templates.update(options_section(('sd', "Stable Diffusion"), {
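Options registered through options_section become attributes on shared.opts, which is how the two training options above are read back elsewhere in the diff. A minimal sketch of the consuming side (variable names here are illustrative):

import modules.shared as shared

# both options default to False (see the OptionInfo defaults above)
unload = shared.opts.unload_models_when_training
keep_xattn = shared.opts.training_xattention_optimizations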

modules/textual_inversion/textual_inversion.py

Lines changed: 10 additions & 0 deletions
@@ -235,6 +235,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
     filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')

     log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
+    unload = shared.opts.unload_models_when_training

     if save_embedding_every > 0:
         embedding_dir = os.path.join(log_directory, "embeddings")
@@ -272,6 +273,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
     with torch.autocast("cuda"):
         ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
+    if unload:
+        shared.sd_model.first_stage_model.to(devices.cpu)

     embedding.vec.requires_grad = True
     optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
@@ -328,6 +331,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
         if images_dir is not None and steps_done % create_image_every == 0:
             forced_filename = f'{embedding_name}-{steps_done}'
             last_saved_image = os.path.join(images_dir, forced_filename)
+
+            shared.sd_model.first_stage_model.to(devices.device)
+
             p = processing.StableDiffusionProcessingTxt2Img(
                 sd_model=shared.sd_model,
                 do_not_save_grid=True,
@@ -355,6 +361,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
             processed = processing.process_images(p)
             image = processed.images[0]

+            if unload:
+                shared.sd_model.first_stage_model.to(devices.cpu)
+
             shared.state.current_image = image

             if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
@@ -400,6 +409,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc

     filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
     save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
+    shared.sd_model.first_stage_model.to(devices.device)

     return embedding, filename
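The pattern above — park the VAE on the CPU while only the embedding vector is optimized, and pull it back to the GPU just long enough to decode preview images — is the main VRAM saving. A self-contained sketch of the same idea in plain PyTorch (module and function names here are illustrative, not webui API):

import torch

gpu = torch.device("cuda")
cpu = torch.device("cpu")

# stand-in for shared.sd_model.first_stage_model (the VAE)
vae = torch.nn.Sequential(torch.nn.Conv2d(4, 3, kernel_size=1))
vae.to(cpu)  # parked on the CPU during training steps

def decode_preview(latents: torch.Tensor) -> torch.Tensor:
    """Temporarily move the decoder to the GPU, decode, then free the VRAM."""
    vae.to(gpu)
    try:
        with torch.no_grad():
            return vae(latents.to(gpu)).cpu()
    finally:
        vae.to(cpu)  # give the VRAM back to the training loop

Note that in the commit itself the move back to the CPU only happens when unload is set, and the decoder is unconditionally restored to devices.device before the final save so later inference still works.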

modules/textual_inversion/ui.py

Lines changed: 5 additions & 2 deletions
@@ -25,8 +25,10 @@ def train_embedding(*args):

     assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'

+    apply_optimizations = shared.opts.training_xattention_optimizations
     try:
-        sd_hijack.undo_optimizations()
+        if not apply_optimizations:
+            sd_hijack.undo_optimizations()

         embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)

@@ -38,5 +40,6 @@ def train_embedding(*args):
     except Exception:
         raise
     finally:
-        sd_hijack.apply_optimizations()
+        if not apply_optimizations:
+            sd_hijack.apply_optimizations()
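This change turns the unconditional swap-out of the optimized cross-attention into an opt-in: when training_xattention_optimizations is set, the xformers (or other) attention implementation stays in place during training, which is what brings TI under 6GB when xformers is available. The shape of the guard, as a minimal standalone sketch (the callables are stand-ins for the sd_hijack functions):

from typing import Callable

def run_with_optional_fallback(train: Callable[[], object],
                               undo: Callable[[], None],
                               redo: Callable[[], None],
                               keep_optimizations: bool):
    """Swap attention implementations around training only when asked to."""
    if not keep_optimizations:
        undo()  # fall back to the exact (unoptimized) attention for training
    try:
        return train()
    finally:
        if not keep_optimizations:
            redo()  # restore the optimized attention even if training raises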
