5959
6060logger = get_logger (__name__ )
6161if is_torch_npu_available ():
62+ import torch_npu
6263 torch .npu .config .allow_internal_format = False
64+ torch .npu .set_compile_mode (jit_compile = False )
6365
6466DATASET_NAME_MAPPING = {
6567 "lambdalabs/naruto-blip-captions" : ("image" , "text" ),
@@ -531,7 +533,7 @@ def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, ca
531533 return {"prompt_embeds" : prompt_embeds .cpu (), "pooled_prompt_embeds" : pooled_prompt_embeds .cpu ()}
532534
533535
534- def compute_vae_encodings (batch , vae ):
536+ def compute_vae_encodings (batch , accelerator , vae ):
535537 images = batch .pop ("pixel_values" )
536538 pixel_values = torch .stack (list (images ))
537539 pixel_values = pixel_values .to (memory_format = torch .contiguous_format ).float ()
@@ -540,7 +542,7 @@ def compute_vae_encodings(batch, vae):
540542 with torch .no_grad ():
541543 model_input = vae .encode (pixel_values ).latent_dist .sample ()
542544 model_input = model_input * vae .config .scaling_factor
543- return {"model_input" : model_input . cpu ( )}
545+ return {"model_input" : accelerator . gather ( model_input )}
544546
545547
546548def generate_timestep_weights (args , num_timesteps ):
@@ -910,7 +912,7 @@ def preprocess_train(examples):
910912 proportion_empty_prompts = args .proportion_empty_prompts ,
911913 caption_column = args .caption_column ,
912914 )
913- compute_vae_encodings_fn = functools .partial (compute_vae_encodings , vae = vae )
915+ compute_vae_encodings_fn = functools .partial (compute_vae_encodings , accelerator = accelerator , vae = vae )
914916 with accelerator .main_process_first ():
915917 from datasets .fingerprint import Hasher
916918
@@ -935,7 +937,10 @@ def preprocess_train(examples):
935937 del compute_vae_encodings_fn , compute_embeddings_fn , text_encoder_one , text_encoder_two
936938 del text_encoders , tokenizers , vae
937939 gc .collect ()
938- torch .cuda .empty_cache ()
940+ if is_torch_npu_available ():
941+ torch_npu .npu .empty_cache ()
942+ else :
943+ torch .cuda .empty_cache ()
939944
940945 def collate_fn (examples ):
941946 model_input = torch .stack ([torch .tensor (example ["model_input" ]) for example in examples ])
@@ -1091,8 +1096,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
10911096 # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
10921097 target_size = (args .resolution , args .resolution )
10931098 add_time_ids = list (original_size + crops_coords_top_left + target_size )
1094- add_time_ids = torch .tensor ([add_time_ids ])
1095- add_time_ids = add_time_ids .to (accelerator .device , dtype = weight_dtype )
1099+ add_time_ids = torch .tensor ([add_time_ids ], device = accelerator .device , dtype = weight_dtype )
10961100 return add_time_ids
10971101
10981102 add_time_ids = torch .cat (
@@ -1261,7 +1265,10 @@ def compute_time_ids(original_size, crops_coords_top_left):
12611265 )
12621266
12631267 del pipeline
1264- torch .cuda .empty_cache ()
1268+ if is_torch_npu_available ():
1269+ torch_npu .npu .empty_cache ()
1270+ else :
1271+ torch .cuda .empty_cache ()
12651272
12661273 if args .use_ema :
12671274 # Switch back to the original UNet parameters.
0 commit comments