@@ -57,6 +57,7 @@
     is_wandb_available,
 )
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_torch_npu_available
 from diffusers.utils.torch_utils import is_compiled_module
 
 
@@ -68,6 +69,12 @@
 
 logger = get_logger(__name__)
 
+if is_torch_npu_available():
+    import torch_npu
+
+    torch.npu.config.allow_internal_format = False
+    torch.npu.set_compile_mode(jit_compile=False)
+
 
 def save_model_card(
     repo_id: str,
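The guarded block above runs only when `torch_npu` can be imported: `allow_internal_format = False` keeps tensors in plain ND layout instead of NPU-private formats, and `set_compile_mode(jit_compile=False)` turns off per-operator JIT compilation. Below is a minimal standalone sketch of the same availability-gated setup; the trailing device pick is illustrative and not part of this commit:

```python
import torch
from diffusers.utils.import_utils import is_torch_npu_available

if is_torch_npu_available():
    import torch_npu  # noqa: F401  # registers the "npu" backend on torch

    # Same defaults as the commit: plain ND tensors, no per-op JIT.
    torch.npu.config.allow_internal_format = False
    torch.npu.set_compile_mode(jit_compile=False)

# Illustrative device selection (an assumption, not in the commit):
if torch.cuda.is_available():
    device = torch.device("cuda")
elif is_torch_npu_available():
    device = torch.device("npu")
else:
    device = torch.device("cpu")
```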
@@ -189,6 +196,8 @@ def log_validation(
     del pipeline
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
+    elif is_torch_npu_available():
+        torch_npu.npu.empty_cache()
 
     return images
 
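This cuda/npu cache-clearing branch recurs at every cleanup site in the commit. A hypothetical consolidation is sketched below; the `free_memory` name and helper are my own, not part of this commit:

```python
import gc

import torch
from diffusers.utils.import_utils import is_torch_npu_available

if is_torch_npu_available():
    import torch_npu  # noqa: F401  # makes torch_npu.npu available


def free_memory() -> None:
    """Collect Python garbage, then flush the accelerator's allocator cache."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif is_torch_npu_available():
        torch_npu.npu.empty_cache()
```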
@@ -1035,7 +1044,9 @@ def main(args):
         cur_class_images = len(list(class_images_dir.iterdir()))
 
         if cur_class_images < args.num_class_images:
-            has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
+            has_supported_fp16_accelerator = (
+                torch.cuda.is_available() or torch.backends.mps.is_available() or is_torch_npu_available()
+            )
             torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
             if args.prior_generation_precision == "fp32":
                 torch_dtype = torch.float32
@@ -1073,6 +1084,8 @@ def main(args):
             del pipeline
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
+            elif is_torch_npu_available():
+                torch_npu.npu.empty_cache()
 
     # Handle the repository creation
     if accelerator.is_main_process:
@@ -1354,6 +1367,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         gc.collect()
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
+        elif is_torch_npu_available():
+            torch_npu.npu.empty_cache()
 
     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
@@ -1719,9 +1734,15 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 )
                 if not args.train_text_encoder:
                     del text_encoder_one, text_encoder_two
-                    torch.cuda.empty_cache()
+                    if torch.cuda.is_available():
+                        torch.cuda.empty_cache()
+                    elif is_torch_npu_available():
+                        torch_npu.npu.empty_cache()
                     gc.collect()
 
+                images = None
+                del pipeline
+
     # Save the lora layers
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
@@ -1780,6 +1801,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
             ignore_patterns=["step_*", "epoch_*"],
         )
 
+        images = None
+        del pipeline
+
     accelerator.end_training()
 
 
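The added `images = None` / `del pipeline` lines pair with the cache flushes: `empty_cache()` on either backend can only return memory whose tensors are no longer referenced, so the last Python references must be dropped first. A standalone illustration of that ordering (an assumed example, not taken from this file):

```python
import gc

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
big = torch.empty(4096, 4096, device=device)  # ~64 MB of float32 on the device

big = None   # drop the reference so the allocation becomes reclaimable
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()  # only now can the cached block be returned to the device
```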