Skip to content

Commit 8ba90aa

Browse files
authored
chore: add a memory-cleaning utility that is useful during training. (#9240)
1 parent 9d49b45 commit 8ba90aa

File tree

2 files changed

+26
-15
lines changed

2 files changed

+26
-15
lines changed

examples/dreambooth/train_dreambooth_lora_sd3.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import argparse
1717
import copy
18-
import gc
1918
import itertools
2019
import logging
2120
import math
@@ -56,6 +55,7 @@
5655
from diffusers.training_utils import (
5756
_set_state_dict_into_text_encoder,
5857
cast_training_params,
58+
clear_objs_and_retain_memory,
5959
compute_density_for_timestep_sampling,
6060
compute_loss_weighting_for_sd3,
6161
)
@@ -210,9 +210,7 @@ def log_validation(
210210
}
211211
)
212212

213-
del pipeline
214-
if torch.cuda.is_available():
215-
torch.cuda.empty_cache()
213+
clear_objs_and_retain_memory(objs=[pipeline])
216214

217215
return images
218216

@@ -1107,9 +1105,7 @@ def main(args):
11071105
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
11081106
image.save(image_filename)
11091107

1110-
del pipeline
1111-
if torch.cuda.is_available():
1112-
torch.cuda.empty_cache()
1108+
clear_objs_and_retain_memory(objs=[pipeline])
11131109

11141110
# Handle the repository creation
11151111
if accelerator.is_main_process:
@@ -1455,12 +1451,10 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
14551451

14561452
# Clear the memory here
14571453
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
1458-
del tokenizers, text_encoders
14591454
# Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
1460-
del text_encoder_one, text_encoder_two, text_encoder_three
1461-
gc.collect()
1462-
if torch.cuda.is_available():
1463-
torch.cuda.empty_cache()
1455+
clear_objs_and_retain_memory(
1456+
objs=[tokenizers, text_encoders, text_encoder_one, text_encoder_two, text_encoder_three]
1457+
)
14641458

14651459
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
14661460
# pack the statically computed variables appropriately here. This is so that we don't
@@ -1795,11 +1789,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17951789
pipeline_args=pipeline_args,
17961790
epoch=epoch,
17971791
)
1792+
objs = []
17981793
if not args.train_text_encoder:
1799-
del text_encoder_one, text_encoder_two, text_encoder_three
1794+
objs.extend([text_encoder_one, text_encoder_two, text_encoder_three])
18001795

1801-
torch.cuda.empty_cache()
1802-
gc.collect()
1796+
clear_objs_and_retain_memory(objs=objs)
18031797

18041798
# Save the lora layers
18051799
accelerator.wait_for_everyone()

src/diffusers/training_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import contextlib
22
import copy
3+
import gc
34
import math
45
import random
56
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
@@ -259,6 +260,22 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
259260
return weighting
260261

261262

263+
def clear_objs_and_retain_memory(objs: List[Any]):
    """Drop references held by `objs`, run garbage collection, and empty the cache of the
    available accelerator (CUDA, MPS, or NPU — checked in that order).

    Args:
        objs: Objects to release. The list is emptied in place so that it no longer keeps
            its elements alive.

    Note:
        This function can only release the references the list itself holds. Callers that
        still hold their own names for these objects (e.g. `pipeline`) must `del` them at
        the call site for the objects to actually become collectable.
    """
    # `del obj` inside a loop would only unbind the loop-local name each iteration — the
    # list would still reference every element, keeping them all alive. Clearing the list
    # drops those references in one step.
    objs.clear()

    gc.collect()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_available():
        torch.mps.empty_cache()
    elif is_torch_npu_available():
        # NOTE(review): `is_torch_npu_available` / `torch_npu` are imported elsewhere in
        # this module — this branch is only reachable on NPU builds of torch.
        torch_npu.empty_cache()
277+
278+
262279
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
263280
class EMAModel:
264281
"""

0 commit comments

Comments
 (0)