@@ -376,6 +376,7 @@ def __call__(

         # 2. Define call parameters
         batch_size = 1 if isinstance(prompt, str) else len(prompt)
+        device = self._execution_device

         if editing_prompt:
             enable_edit_guidance = True
@@ -405,7 +406,7 @@ def __call__(
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+        text_embeddings = self.text_encoder(text_input_ids.to(device))[0]

         # duplicate text embeddings for each generation per prompt, using mps friendly method
         bs_embed, seq_len, _ = text_embeddings.shape
@@ -433,9 +434,9 @@ def __call__(
                         f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                     )
                     edit_concepts_input_ids = edit_concepts_input_ids[:, : self.tokenizer.model_max_length]
-                edit_concepts = self.text_encoder(edit_concepts_input_ids.to(self.device))[0]
+                edit_concepts = self.text_encoder(edit_concepts_input_ids.to(device))[0]
             else:
-                edit_concepts = editing_prompt_embeddings.to(self.device).repeat(batch_size, 1, 1)
+                edit_concepts = editing_prompt_embeddings.to(device).repeat(batch_size, 1, 1)

             # duplicate text embeddings for each generation per prompt, using mps friendly method
             bs_embed_edit, seq_len_edit, _ = edit_concepts.shape
@@ -476,7 +477,7 @@ def __call__(
                 truncation=True,
                 return_tensors="pt",
             )
-            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]

             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = uncond_embeddings.shape[1]
@@ -493,7 +494,7 @@ def __call__(
         # get the initial random noise unless the user supplied it

         # 4. Prepare timesteps
-        self.scheduler.set_timesteps(num_inference_steps, device=self.device)
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = self.scheduler.timesteps

         # 5. Prepare latent variables
@@ -504,7 +505,7 @@ def __call__(
             height,
             width,
             text_embeddings.dtype,
-            self.device,
+            device,
             generator,
             latents,
         )
@@ -562,12 +563,12 @@ def __call__(
                 if enable_edit_guidance:
                     concept_weights = torch.zeros(
                         (len(noise_pred_edit_concepts), noise_guidance.shape[0]),
-                        device=self.device,
+                        device=device,
                         dtype=noise_guidance.dtype,
                     )
                     noise_guidance_edit = torch.zeros(
                         (len(noise_pred_edit_concepts), *noise_guidance.shape),
-                        device=self.device,
+                        device=device,
                         dtype=noise_guidance.dtype,
                     )
                     # noise_guidance_edit = torch.zeros_like(noise_guidance)
@@ -644,21 +645,19 @@ def __call__(

                         # noise_guidance_edit = noise_guidance_edit + noise_guidance_edit_tmp

-                    warmup_inds = torch.tensor(warmup_inds).to(self.device)
+                    warmup_inds = torch.tensor(warmup_inds).to(device)
                     if len(noise_pred_edit_concepts) > warmup_inds.shape[0] > 0:
                         concept_weights = concept_weights.to("cpu")  # Offload to cpu
                         noise_guidance_edit = noise_guidance_edit.to("cpu")

-                        concept_weights_tmp = torch.index_select(concept_weights.to(self.device), 0, warmup_inds)
+                        concept_weights_tmp = torch.index_select(concept_weights.to(device), 0, warmup_inds)
                         concept_weights_tmp = torch.where(
                             concept_weights_tmp < 0, torch.zeros_like(concept_weights_tmp), concept_weights_tmp
                         )
                         concept_weights_tmp = concept_weights_tmp / concept_weights_tmp.sum(dim=0)
                         # concept_weights_tmp = torch.nan_to_num(concept_weights_tmp)

-                        noise_guidance_edit_tmp = torch.index_select(
-                            noise_guidance_edit.to(self.device), 0, warmup_inds
-                        )
+                        noise_guidance_edit_tmp = torch.index_select(noise_guidance_edit.to(device), 0, warmup_inds)
                         noise_guidance_edit_tmp = torch.einsum(
                             "cb,cbijk->bijk", concept_weights_tmp, noise_guidance_edit_tmp
                         )
@@ -669,8 +668,8 @@ def __call__(

                         del noise_guidance_edit_tmp
                         del concept_weights_tmp
-                        concept_weights = concept_weights.to(self.device)
-                        noise_guidance_edit = noise_guidance_edit.to(self.device)
+                        concept_weights = concept_weights.to(device)
+                        noise_guidance_edit = noise_guidance_edit.to(device)

                     concept_weights = torch.where(
                         concept_weights < 0, torch.zeros_like(concept_weights), concept_weights
@@ -679,6 +678,7 @@ def __call__(
                     concept_weights = torch.nan_to_num(concept_weights)

                     noise_guidance_edit = torch.einsum("cb,cbijk->bijk", concept_weights, noise_guidance_edit)
+                    noise_guidance_edit = noise_guidance_edit.to(edit_momentum.device)

                     noise_guidance_edit = noise_guidance_edit + edit_momentum_scale * edit_momentum

@@ -689,7 +689,7 @@ def __call__(
                     self.sem_guidance[i] = noise_guidance_edit.detach().cpu()

                 if sem_guidance is not None:
-                    edit_guidance = sem_guidance[i].to(self.device)
+                    edit_guidance = sem_guidance[i].to(device)
                     noise_guidance = noise_guidance + edit_guidance

                 noise_pred = noise_pred_uncond + noise_guidance
@@ -705,7 +705,7 @@ def __call__(
         # 8. Post-processing
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
-            image, has_nsfw_concept = self.run_safety_checker(image, self.device, text_embeddings.dtype)
+            image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
         else:
             image = latents
             has_nsfw_concept = None
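
Note on the change (not part of the diff): `_execution_device` is preferred over `self.device` because, once CPU offload is enabled via accelerate hooks, the pipeline modules sit on the CPU between forward passes, so `self.device` reports `cpu` while `_execution_device` resolves to the accelerator the hooks actually run on. A minimal sketch of that difference, assuming a single-GPU CUDA machine and using `runwayml/stable-diffusion-v1-5` purely as an illustrative checkpoint:

# Illustration only (not from this PR): why tensors are created on `device`
# rather than `self.device` in the pipeline above.
import torch
from diffusers import SemanticStableDiffusionPipeline

pipe = SemanticStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16  # assumed checkpoint/dtype
)
pipe.enable_model_cpu_offload()  # accelerate hooks keep modules on CPU between calls

print(pipe.device)             # typically "cpu" -- misleading for tensor placement
print(pipe._execution_device)  # typically "cuda:0" -- where inputs should be created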