|
15 | 15 |
|
16 | 16 | import argparse
|
17 | 17 | import copy
|
18 |
| -import gc |
19 | 18 | import itertools
|
20 | 19 | import logging
|
21 | 20 | import math
|
|
56 | 55 | from diffusers.training_utils import (
|
57 | 56 | _set_state_dict_into_text_encoder,
|
58 | 57 | cast_training_params,
|
| 58 | + clear_objs_and_retain_memory, |
59 | 59 | compute_density_for_timestep_sampling,
|
60 | 60 | compute_loss_weighting_for_sd3,
|
61 | 61 | )
|
@@ -210,9 +210,7 @@ def log_validation(
|
210 | 210 | }
|
211 | 211 | )
|
212 | 212 |
|
213 |
| - del pipeline |
214 |
| - if torch.cuda.is_available(): |
215 |
| - torch.cuda.empty_cache() |
| 213 | + clear_objs_and_retain_memory(objs=[pipeline]) |
216 | 214 |
|
217 | 215 | return images
|
218 | 216 |
|
@@ -1107,9 +1105,7 @@ def main(args):
|
1107 | 1105 | image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
|
1108 | 1106 | image.save(image_filename)
|
1109 | 1107 |
|
1110 |
| - del pipeline |
1111 |
| - if torch.cuda.is_available(): |
1112 |
| - torch.cuda.empty_cache() |
| 1108 | + clear_objs_and_retain_memory(objs=[pipeline]) |
1113 | 1109 |
|
1114 | 1110 | # Handle the repository creation
|
1115 | 1111 | if accelerator.is_main_process:
|
@@ -1455,12 +1451,10 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
|
1455 | 1451 |
|
1456 | 1452 | # Clear the memory here
|
1457 | 1453 | if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
|
1458 |
| - del tokenizers, text_encoders |
1459 | 1454 | # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
|
1460 |
| - del text_encoder_one, text_encoder_two, text_encoder_three |
1461 |
| - gc.collect() |
1462 |
| - if torch.cuda.is_available(): |
1463 |
| - torch.cuda.empty_cache() |
| 1455 | + clear_objs_and_retain_memory( |
| 1456 | + objs=[tokenizers, text_encoders, text_encoder_one, text_encoder_two, text_encoder_three] |
| 1457 | + ) |
1464 | 1458 |
|
1465 | 1459 | # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
|
1466 | 1460 | # pack the statically computed variables appropriately here. This is so that we don't
|
@@ -1795,11 +1789,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
|
1795 | 1789 | pipeline_args=pipeline_args,
|
1796 | 1790 | epoch=epoch,
|
1797 | 1791 | )
|
| 1792 | + objs = [] |
1798 | 1793 | if not args.train_text_encoder:
|
1799 |
| - del text_encoder_one, text_encoder_two, text_encoder_three |
| 1794 | + objs.extend([text_encoder_one, text_encoder_two, text_encoder_three]) |
1800 | 1795 |
|
1801 |
| - torch.cuda.empty_cache() |
1802 |
| - gc.collect() |
| 1796 | + clear_objs_and_retain_memory(objs=objs) |
1803 | 1797 |
|
1804 | 1798 | # Save the lora layers
|
1805 | 1799 | accelerator.wait_for_everyone()
|
|
0 commit comments