|
59 | 59 |
|
60 | 60 | logger = get_logger(__name__) |
61 | 61 | if is_torch_npu_available(): |
| 62 | + import torch_npu |
| 63 | + |
62 | 64 | torch.npu.config.allow_internal_format = False |
63 | 65 |
|
64 | 66 | DATASET_NAME_MAPPING = { |
@@ -540,6 +542,9 @@ def compute_vae_encodings(batch, vae): |
540 | 542 | with torch.no_grad(): |
541 | 543 | model_input = vae.encode(pixel_values).latent_dist.sample() |
542 | 544 | model_input = model_input * vae.config.scaling_factor |
| 545 | + |
| 546 | + # There might be a slight performance improvement |
| 547 | + # from changing model_input.cpu() to accelerator.gather(model_input) |
543 | 548 | return {"model_input": model_input.cpu()} |
544 | 549 |
|
545 | 550 |
|
@@ -935,7 +940,10 @@ def preprocess_train(examples): |
935 | 940 | del compute_vae_encodings_fn, compute_embeddings_fn, text_encoder_one, text_encoder_two |
936 | 941 | del text_encoders, tokenizers, vae |
937 | 942 | gc.collect() |
938 | | - torch.cuda.empty_cache() |
| 943 | + if is_torch_npu_available(): |
| 944 | + torch_npu.npu.empty_cache() |
| 945 | + elif torch.cuda.is_available(): |
| 946 | + torch.cuda.empty_cache() |
939 | 947 |
|
940 | 948 | def collate_fn(examples): |
941 | 949 | model_input = torch.stack([torch.tensor(example["model_input"]) for example in examples]) |
@@ -1091,8 +1099,7 @@ def compute_time_ids(original_size, crops_coords_top_left): |
1091 | 1099 | # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids |
1092 | 1100 | target_size = (args.resolution, args.resolution) |
1093 | 1101 | add_time_ids = list(original_size + crops_coords_top_left + target_size) |
1094 | | - add_time_ids = torch.tensor([add_time_ids]) |
1095 | | - add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype) |
| 1102 | + add_time_ids = torch.tensor([add_time_ids], device=accelerator.device, dtype=weight_dtype) |
1096 | 1103 | return add_time_ids |
1097 | 1104 |
|
1098 | 1105 | add_time_ids = torch.cat( |
@@ -1261,7 +1268,10 @@ def compute_time_ids(original_size, crops_coords_top_left): |
1261 | 1268 | ) |
1262 | 1269 |
|
1263 | 1270 | del pipeline |
1264 | | - torch.cuda.empty_cache() |
| 1271 | + if is_torch_npu_available(): |
| 1272 | + torch_npu.npu.empty_cache() |
| 1273 | + elif torch.cuda.is_available(): |
| 1274 | + torch.cuda.empty_cache() |
1265 | 1275 |
|
1266 | 1276 | if args.use_ema: |
1267 | 1277 | # Switch back to the original UNet parameters. |
|
0 commit comments