
Commit 0028c34

yiyixuxu and sayakpaul authored
fix SEGA pipeline (#8467)
* fix

* style

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Sayak Paul <[email protected]>
1 parent d457bee commit 0028c34

File tree

1 file changed, +17 -17 lines changed

src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py

Lines changed: 17 additions & 17 deletions
@@ -376,6 +376,7 @@ def __call__(
 
         # 2. Define call parameters
         batch_size = 1 if isinstance(prompt, str) else len(prompt)
+        device = self._execution_device
 
         if editing_prompt:
             enable_edit_guidance = True
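The new `device = self._execution_device` line is the heart of the fix: `self.device` reports where the pipeline's modules currently sit, which is `cpu` when offloading is enabled, while `_execution_device` resolves to the accelerator the forward passes actually run on. A minimal sketch of the difference, assuming a CUDA machine, that offloading is available for this pipeline, and using `runwayml/stable-diffusion-v1-5` only as an example checkpoint:

```python
import torch
from diffusers import SemanticStableDiffusionPipeline

pipe = SemanticStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
# With model offloading, submodules idle on the CPU and are moved to the
# accelerator one at a time, just in time for their forward pass.
pipe.enable_model_cpu_offload()

print(pipe.device)             # may report cpu, where the modules idle
print(pipe._execution_device)  # resolves to where compute runs, e.g. cuda:0
```

Every remaining hunk in this commit applies the same substitution, so tensors created during the denoising loop land on the execution device rather than wherever the modules happen to be parked.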
@@ -405,7 +406,7 @@ def __call__(
                     f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                 )
                 text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-            text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+            text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
 
             # duplicate text embeddings for each generation per prompt, using mps friendly method
             bs_embed, seq_len, _ = text_embeddings.shape
@@ -433,9 +434,9 @@ def __call__(
                         f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                     )
                     edit_concepts_input_ids = edit_concepts_input_ids[:, : self.tokenizer.model_max_length]
-                edit_concepts = self.text_encoder(edit_concepts_input_ids.to(self.device))[0]
+                edit_concepts = self.text_encoder(edit_concepts_input_ids.to(device))[0]
             else:
-                edit_concepts = editing_prompt_embeddings.to(self.device).repeat(batch_size, 1, 1)
+                edit_concepts = editing_prompt_embeddings.to(device).repeat(batch_size, 1, 1)
 
             # duplicate text embeddings for each generation per prompt, using mps friendly method
             bs_embed_edit, seq_len_edit, _ = edit_concepts.shape
@@ -476,7 +477,7 @@ def __call__(
                 truncation=True,
                 return_tensors="pt",
             )
-            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
 
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = uncond_embeddings.shape[1]
@@ -493,7 +494,7 @@ def __call__(
         # get the initial random noise unless the user supplied it
 
         # 4. Prepare timesteps
-        self.scheduler.set_timesteps(num_inference_steps, device=self.device)
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = self.scheduler.timesteps
 
         # 5. Prepare latent variables
@@ -504,7 +505,7 @@ def __call__(
             height,
             width,
             text_embeddings.dtype,
-            self.device,
+            device,
             generator,
             latents,
         )
@@ -562,12 +563,12 @@ def __call__(
                 if enable_edit_guidance:
                     concept_weights = torch.zeros(
                         (len(noise_pred_edit_concepts), noise_guidance.shape[0]),
-                        device=self.device,
+                        device=device,
                         dtype=noise_guidance.dtype,
                     )
                     noise_guidance_edit = torch.zeros(
                         (len(noise_pred_edit_concepts), *noise_guidance.shape),
-                        device=self.device,
+                        device=device,
                         dtype=noise_guidance.dtype,
                     )
                     # noise_guidance_edit = torch.zeros_like(noise_guidance)
@@ -644,21 +645,19 @@ def __call__(
 
                         # noise_guidance_edit = noise_guidance_edit + noise_guidance_edit_tmp
 
-                    warmup_inds = torch.tensor(warmup_inds).to(self.device)
+                    warmup_inds = torch.tensor(warmup_inds).to(device)
                     if len(noise_pred_edit_concepts) > warmup_inds.shape[0] > 0:
                         concept_weights = concept_weights.to("cpu")  # Offload to cpu
                         noise_guidance_edit = noise_guidance_edit.to("cpu")
 
-                        concept_weights_tmp = torch.index_select(concept_weights.to(self.device), 0, warmup_inds)
+                        concept_weights_tmp = torch.index_select(concept_weights.to(device), 0, warmup_inds)
                         concept_weights_tmp = torch.where(
                             concept_weights_tmp < 0, torch.zeros_like(concept_weights_tmp), concept_weights_tmp
                         )
                         concept_weights_tmp = concept_weights_tmp / concept_weights_tmp.sum(dim=0)
                         # concept_weights_tmp = torch.nan_to_num(concept_weights_tmp)
 
-                        noise_guidance_edit_tmp = torch.index_select(
-                            noise_guidance_edit.to(self.device), 0, warmup_inds
-                        )
+                        noise_guidance_edit_tmp = torch.index_select(noise_guidance_edit.to(device), 0, warmup_inds)
                         noise_guidance_edit_tmp = torch.einsum(
                             "cb,cbijk->bijk", concept_weights_tmp, noise_guidance_edit_tmp
                         )
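The `torch.einsum("cb,cbijk->bijk", ...)` contraction in this hunk folds the per-concept guidance tensors into a single batch-shaped tensor, weighting each concept's contribution and summing over the concept axis. A standalone illustration with made-up shapes (`c` concepts, `b` batch entries, `ijk` latent dimensions; the sizes below are arbitrary):

```python
import torch

c, b, ch, h, w = 3, 2, 4, 8, 8                     # illustrative sizes only
concept_weights = torch.rand(c, b)                 # one weight per (concept, batch)
noise_guidance_edit = torch.randn(c, b, ch, h, w)  # per-concept guidance tensors

# Weight each concept's guidance by its (concept, batch) weight and
# sum over the concept axis, leaving one guidance tensor per batch entry.
combined = torch.einsum("cb,cbijk->bijk", concept_weights, noise_guidance_edit)
assert combined.shape == (b, ch, h, w)
```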
@@ -669,8 +668,8 @@ def __call__(
 
                         del noise_guidance_edit_tmp
                         del concept_weights_tmp
-                        concept_weights = concept_weights.to(self.device)
-                        noise_guidance_edit = noise_guidance_edit.to(self.device)
+                        concept_weights = concept_weights.to(device)
+                        noise_guidance_edit = noise_guidance_edit.to(device)
 
                     concept_weights = torch.where(
                         concept_weights < 0, torch.zeros_like(concept_weights), concept_weights
@@ -679,6 +678,7 @@ def __call__(
                     concept_weights = torch.nan_to_num(concept_weights)
 
                     noise_guidance_edit = torch.einsum("cb,cbijk->bijk", concept_weights, noise_guidance_edit)
+                    noise_guidance_edit = noise_guidance_edit.to(edit_momentum.device)
 
                     noise_guidance_edit = noise_guidance_edit + edit_momentum_scale * edit_momentum
 
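The added `.to(edit_momentum.device)` is the one genuinely new line in this hunk: once parts of the loop run on the execution device while `edit_momentum` may live elsewhere, the momentum update on the next line would otherwise mix devices, and PyTorch does not broadcast across devices. A tiny reproduction of the failure mode, assuming a CUDA device is present:

```python
import torch

if torch.cuda.is_available():
    guidance = torch.zeros(2, device="cuda")
    momentum = torch.zeros(2, device="cpu")
    try:
        guidance + momentum                   # cross-device add raises
    except RuntimeError as err:
        print(err)                            # "Expected all tensors to be on the same device..."
    guidance = guidance.to(momentum.device)   # align first, as the fix does
    print(guidance + momentum)                # now fine
```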
@@ -689,7 +689,7 @@ def __call__(
                     self.sem_guidance[i] = noise_guidance_edit.detach().cpu()
 
                 if sem_guidance is not None:
-                    edit_guidance = sem_guidance[i].to(self.device)
+                    edit_guidance = sem_guidance[i].to(device)
                     noise_guidance = noise_guidance + edit_guidance
 
                 noise_pred = noise_pred_uncond + noise_guidance
@@ -705,7 +705,7 @@ def __call__(
         # 8. Post-processing
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
-            image, has_nsfw_concept = self.run_safety_checker(image, self.device, text_embeddings.dtype)
+            image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
         else:
             image = latents
             has_nsfw_concept = None
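With every `self.device` reference replaced by the resolved execution device, the SEGA pipeline should behave the same whether its modules are moved to the accelerator up front or offloaded on demand. A usage sketch in the spirit of the pipeline's docstring example; the checkpoint id and all editing keyword values below are illustrative, not prescriptive:

```python
import torch
from diffusers import SemanticStableDiffusionPipeline

pipe = SemanticStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

out = pipe(
    prompt="a photo of the face of a woman",
    guidance_scale=7.0,
    editing_prompt=["smiling, smile", "glasses, wearing glasses"],
    edit_warmup_steps=[10, 10],    # steps before each concept's guidance kicks in
    edit_guidance_scale=[6, 6],    # per-concept guidance strength
    edit_threshold=[0.99, 0.99],   # keep only the strongest latent regions per concept
    edit_momentum_scale=0.3,       # weight of the accumulated momentum term
    edit_mom_beta=0.6,             # momentum decay
)
image = out.images[0]
```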
