@@ -376,6 +376,7 @@ def __call__(
376
376
377
377
# 2. Define call parameters
378
378
batch_size = 1 if isinstance (prompt , str ) else len (prompt )
379
+ device = self ._execution_device
379
380
380
381
if editing_prompt :
381
382
enable_edit_guidance = True
@@ -405,7 +406,7 @@ def __call__(
405
406
f" { self .tokenizer .model_max_length } tokens: { removed_text } "
406
407
)
407
408
text_input_ids = text_input_ids [:, : self .tokenizer .model_max_length ]
408
- text_embeddings = self .text_encoder (text_input_ids .to (self . device ))[0 ]
409
+ text_embeddings = self .text_encoder (text_input_ids .to (device ))[0 ]
409
410
410
411
# duplicate text embeddings for each generation per prompt, using mps friendly method
411
412
bs_embed , seq_len , _ = text_embeddings .shape
@@ -433,9 +434,9 @@ def __call__(
433
434
f" { self .tokenizer .model_max_length } tokens: { removed_text } "
434
435
)
435
436
edit_concepts_input_ids = edit_concepts_input_ids [:, : self .tokenizer .model_max_length ]
436
- edit_concepts = self .text_encoder (edit_concepts_input_ids .to (self . device ))[0 ]
437
+ edit_concepts = self .text_encoder (edit_concepts_input_ids .to (device ))[0 ]
437
438
else :
438
- edit_concepts = editing_prompt_embeddings .to (self . device ).repeat (batch_size , 1 , 1 )
439
+ edit_concepts = editing_prompt_embeddings .to (device ).repeat (batch_size , 1 , 1 )
439
440
440
441
# duplicate text embeddings for each generation per prompt, using mps friendly method
441
442
bs_embed_edit , seq_len_edit , _ = edit_concepts .shape
@@ -476,7 +477,7 @@ def __call__(
476
477
truncation = True ,
477
478
return_tensors = "pt" ,
478
479
)
479
- uncond_embeddings = self .text_encoder (uncond_input .input_ids .to (self . device ))[0 ]
480
+ uncond_embeddings = self .text_encoder (uncond_input .input_ids .to (device ))[0 ]
480
481
481
482
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
482
483
seq_len = uncond_embeddings .shape [1 ]
@@ -493,7 +494,7 @@ def __call__(
493
494
# get the initial random noise unless the user supplied it
494
495
495
496
# 4. Prepare timesteps
496
- self .scheduler .set_timesteps (num_inference_steps , device = self . device )
497
+ self .scheduler .set_timesteps (num_inference_steps , device = device )
497
498
timesteps = self .scheduler .timesteps
498
499
499
500
# 5. Prepare latent variables
@@ -504,7 +505,7 @@ def __call__(
504
505
height ,
505
506
width ,
506
507
text_embeddings .dtype ,
507
- self . device ,
508
+ device ,
508
509
generator ,
509
510
latents ,
510
511
)
@@ -562,12 +563,12 @@ def __call__(
562
563
if enable_edit_guidance :
563
564
concept_weights = torch .zeros (
564
565
(len (noise_pred_edit_concepts ), noise_guidance .shape [0 ]),
565
- device = self . device ,
566
+ device = device ,
566
567
dtype = noise_guidance .dtype ,
567
568
)
568
569
noise_guidance_edit = torch .zeros (
569
570
(len (noise_pred_edit_concepts ), * noise_guidance .shape ),
570
- device = self . device ,
571
+ device = device ,
571
572
dtype = noise_guidance .dtype ,
572
573
)
573
574
# noise_guidance_edit = torch.zeros_like(noise_guidance)
@@ -644,21 +645,19 @@ def __call__(
644
645
645
646
# noise_guidance_edit = noise_guidance_edit + noise_guidance_edit_tmp
646
647
647
- warmup_inds = torch .tensor (warmup_inds ).to (self . device )
648
+ warmup_inds = torch .tensor (warmup_inds ).to (device )
648
649
if len (noise_pred_edit_concepts ) > warmup_inds .shape [0 ] > 0 :
649
650
concept_weights = concept_weights .to ("cpu" ) # Offload to cpu
650
651
noise_guidance_edit = noise_guidance_edit .to ("cpu" )
651
652
652
- concept_weights_tmp = torch .index_select (concept_weights .to (self . device ), 0 , warmup_inds )
653
+ concept_weights_tmp = torch .index_select (concept_weights .to (device ), 0 , warmup_inds )
653
654
concept_weights_tmp = torch .where (
654
655
concept_weights_tmp < 0 , torch .zeros_like (concept_weights_tmp ), concept_weights_tmp
655
656
)
656
657
concept_weights_tmp = concept_weights_tmp / concept_weights_tmp .sum (dim = 0 )
657
658
# concept_weights_tmp = torch.nan_to_num(concept_weights_tmp)
658
659
659
- noise_guidance_edit_tmp = torch .index_select (
660
- noise_guidance_edit .to (self .device ), 0 , warmup_inds
661
- )
660
+ noise_guidance_edit_tmp = torch .index_select (noise_guidance_edit .to (device ), 0 , warmup_inds )
662
661
noise_guidance_edit_tmp = torch .einsum (
663
662
"cb,cbijk->bijk" , concept_weights_tmp , noise_guidance_edit_tmp
664
663
)
@@ -669,8 +668,8 @@ def __call__(
669
668
670
669
del noise_guidance_edit_tmp
671
670
del concept_weights_tmp
672
- concept_weights = concept_weights .to (self . device )
673
- noise_guidance_edit = noise_guidance_edit .to (self . device )
671
+ concept_weights = concept_weights .to (device )
672
+ noise_guidance_edit = noise_guidance_edit .to (device )
674
673
675
674
concept_weights = torch .where (
676
675
concept_weights < 0 , torch .zeros_like (concept_weights ), concept_weights
@@ -679,6 +678,7 @@ def __call__(
679
678
concept_weights = torch .nan_to_num (concept_weights )
680
679
681
680
noise_guidance_edit = torch .einsum ("cb,cbijk->bijk" , concept_weights , noise_guidance_edit )
681
+ noise_guidance_edit = noise_guidance_edit .to (edit_momentum .device )
682
682
683
683
noise_guidance_edit = noise_guidance_edit + edit_momentum_scale * edit_momentum
684
684
@@ -689,7 +689,7 @@ def __call__(
689
689
self .sem_guidance [i ] = noise_guidance_edit .detach ().cpu ()
690
690
691
691
if sem_guidance is not None :
692
- edit_guidance = sem_guidance [i ].to (self . device )
692
+ edit_guidance = sem_guidance [i ].to (device )
693
693
noise_guidance = noise_guidance + edit_guidance
694
694
695
695
noise_pred = noise_pred_uncond + noise_guidance
@@ -705,7 +705,7 @@ def __call__(
705
705
# 8. Post-processing
706
706
if not output_type == "latent" :
707
707
image = self .vae .decode (latents / self .vae .config .scaling_factor , return_dict = False )[0 ]
708
- image , has_nsfw_concept = self .run_safety_checker (image , self . device , text_embeddings .dtype )
708
+ image , has_nsfw_concept = self .run_safety_checker (image , device , text_embeddings .dtype )
709
709
else :
710
710
image = latents
711
711
has_nsfw_concept = None
0 commit comments