5959
6060logger  =  get_logger (__name__ )
6161if  is_torch_npu_available ():
62+     import  torch_npu 
6263    torch .npu .config .allow_internal_format  =  False 
64+     torch .npu .set_compile_mode (jit_compile = False )
6365
6466DATASET_NAME_MAPPING  =  {
6567    "lambdalabs/naruto-blip-captions" : ("image" , "text" ),
@@ -531,7 +533,7 @@ def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, ca
531533    return  {"prompt_embeds" : prompt_embeds .cpu (), "pooled_prompt_embeds" : pooled_prompt_embeds .cpu ()}
532534
533535
534- def  compute_vae_encodings (batch , vae ):
536+ def  compute_vae_encodings (batch , accelerator ,  vae ):
535537    images  =  batch .pop ("pixel_values" )
536538    pixel_values  =  torch .stack (list (images ))
537539    pixel_values  =  pixel_values .to (memory_format = torch .contiguous_format ).float ()
@@ -540,7 +542,7 @@ def compute_vae_encodings(batch, vae):
540542    with  torch .no_grad ():
541543        model_input  =  vae .encode (pixel_values ).latent_dist .sample ()
542544    model_input  =  model_input  *  vae .config .scaling_factor 
543-     return  {"model_input" : model_input . cpu ( )}
545+     return  {"model_input" : accelerator . gather ( model_input )}
544546
545547
546548def  generate_timestep_weights (args , num_timesteps ):
@@ -910,7 +912,7 @@ def preprocess_train(examples):
910912        proportion_empty_prompts = args .proportion_empty_prompts ,
911913        caption_column = args .caption_column ,
912914    )
913-     compute_vae_encodings_fn  =  functools .partial (compute_vae_encodings , vae = vae )
915+     compute_vae_encodings_fn  =  functools .partial (compute_vae_encodings , accelerator = accelerator ,  vae = vae )
914916    with  accelerator .main_process_first ():
915917        from  datasets .fingerprint  import  Hasher 
916918
@@ -935,7 +937,10 @@ def preprocess_train(examples):
935937    del  compute_vae_encodings_fn , compute_embeddings_fn , text_encoder_one , text_encoder_two 
936938    del  text_encoders , tokenizers , vae 
937939    gc .collect ()
938-     torch .cuda .empty_cache ()
940+     if  is_torch_npu_available ():
941+         torch_npu .npu .empty_cache ()
942+     else :
943+         torch .cuda .empty_cache ()
939944
940945    def  collate_fn (examples ):
941946        model_input  =  torch .stack ([torch .tensor (example ["model_input" ]) for  example  in  examples ])
@@ -1091,8 +1096,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
10911096                    # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids 
10921097                    target_size  =  (args .resolution , args .resolution )
10931098                    add_time_ids  =  list (original_size  +  crops_coords_top_left  +  target_size )
1094-                     add_time_ids  =  torch .tensor ([add_time_ids ])
1095-                     add_time_ids  =  add_time_ids .to (accelerator .device , dtype = weight_dtype )
1099+                     add_time_ids  =  torch .tensor ([add_time_ids ], device = accelerator .device , dtype = weight_dtype )
10961100                    return  add_time_ids 
10971101
10981102                add_time_ids  =  torch .cat (
@@ -1261,7 +1265,10 @@ def compute_time_ids(original_size, crops_coords_top_left):
12611265                        )
12621266
12631267                del  pipeline 
1264-                 torch .cuda .empty_cache ()
1268+                 if  is_torch_npu_available ():
1269+                     torch_npu .npu .empty_cache ()
1270+                 else :
1271+                     torch .cuda .empty_cache ()
12651272
12661273                if  args .use_ema :
12671274                    # Switch back to the original UNet parameters. 
0 commit comments