| 
57 | 57 |     is_wandb_available,  | 
58 | 58 | )  | 
59 | 59 | from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card  | 
 | 60 | +from diffusers.utils.import_utils import is_torch_npu_available  | 
60 | 61 | from diffusers.utils.torch_utils import is_compiled_module  | 
61 | 62 | 
 
  | 
62 | 63 | 
 
  | 
 | 
68 | 69 | 
 
  | 
# Module-level logger for this training script.
# NOTE(review): get_logger presumably comes from accelerate.logging (import is
# outside this chunk) — confirm; if so it is the multi-process-aware variant.
logger = get_logger(__name__)
70 | 71 | 
 
  | 
# Optional Ascend NPU support: only import torch_npu and apply NPU-global
# settings when the accelerator is actually present, so the script still runs
# unchanged on CUDA/MPS/CPU machines.
if is_torch_npu_available():
    import torch_npu

    # NOTE(review): presumably forces tensors to keep standard (public) memory
    # formats instead of NPU-internal private formats — confirm against the
    # torch_npu documentation.
    torch.npu.config.allow_internal_format = False
    # Disable JIT kernel compilation on the NPU (eager execution).
    torch.npu.set_compile_mode(jit_compile=False)
71 | 78 | 
 
  | 
72 | 79 | def save_model_card(  | 
73 | 80 |     repo_id: str,  | 
@@ -189,6 +196,8 @@ def log_validation(  | 
189 | 196 |     del pipeline  | 
190 | 197 |     if torch.cuda.is_available():  | 
191 | 198 |         torch.cuda.empty_cache()  | 
 | 199 | +    elif is_torch_npu_available():  | 
 | 200 | +        torch_npu.npu.empty_cache()  | 
192 | 201 | 
 
  | 
193 | 202 |     return images  | 
194 | 203 | 
 
  | 
@@ -1035,7 +1044,9 @@ def main(args):  | 
1035 | 1044 |         cur_class_images = len(list(class_images_dir.iterdir()))  | 
1036 | 1045 | 
 
  | 
1037 | 1046 |         if cur_class_images < args.num_class_images:  | 
1038 |  | -            has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()  | 
 | 1047 | +            has_supported_fp16_accelerator = (  | 
 | 1048 | +                torch.cuda.is_available() or torch.backends.mps.is_available() or is_torch_npu_available()  | 
 | 1049 | +            )  | 
1039 | 1050 |             torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32  | 
1040 | 1051 |             if args.prior_generation_precision == "fp32":  | 
1041 | 1052 |                 torch_dtype = torch.float32  | 
@@ -1073,6 +1084,8 @@ def main(args):  | 
1073 | 1084 |             del pipeline  | 
1074 | 1085 |             if torch.cuda.is_available():  | 
1075 | 1086 |                 torch.cuda.empty_cache()  | 
 | 1087 | +            elif is_torch_npu_available():  | 
 | 1088 | +                torch_npu.npu.empty_cache()  | 
1076 | 1089 | 
 
  | 
1077 | 1090 |     # Handle the repository creation  | 
1078 | 1091 |     if accelerator.is_main_process:  | 
@@ -1354,6 +1367,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):  | 
1354 | 1367 |         gc.collect()  | 
1355 | 1368 |         if torch.cuda.is_available():  | 
1356 | 1369 |             torch.cuda.empty_cache()  | 
 | 1370 | +        elif is_torch_npu_available():  | 
 | 1371 | +            torch_npu.npu.empty_cache()  | 
1357 | 1372 | 
 
  | 
1358 | 1373 |     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),  | 
1359 | 1374 |     # pack the statically computed variables appropriately here. This is so that we don't  | 
@@ -1719,9 +1734,15 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):  | 
1719 | 1734 |                 )  | 
1720 | 1735 |                 if not args.train_text_encoder:  | 
1721 | 1736 |                     del text_encoder_one, text_encoder_two  | 
1722 |  | -                    torch.cuda.empty_cache()  | 
 | 1737 | +                    if torch.cuda.is_available():  | 
 | 1738 | +                        torch.cuda.empty_cache()  | 
 | 1739 | +                    elif is_torch_npu_available():  | 
 | 1740 | +                        torch_npu.npu.empty_cache()  | 
1723 | 1741 |                     gc.collect()  | 
1724 | 1742 | 
 
  | 
 | 1743 | +                images = None  | 
 | 1744 | +                del pipeline  | 
 | 1745 | + | 
1725 | 1746 |     # Save the lora layers  | 
1726 | 1747 |     accelerator.wait_for_everyone()  | 
1727 | 1748 |     if accelerator.is_main_process:  | 
@@ -1780,6 +1801,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):  | 
1780 | 1801 |                 ignore_patterns=["step_*", "epoch_*"],  | 
1781 | 1802 |             )  | 
1782 | 1803 | 
 
  | 
 | 1804 | +        images = None  | 
 | 1805 | +        del pipeline  | 
 | 1806 | + | 
1783 | 1807 |     accelerator.end_training()  | 
1784 | 1808 | 
 
  | 
1785 | 1809 | 
 
  | 
 | 
0 commit comments