Skip to content

Commit 7235805

Browse files
authored
Revert cond + uncond batching
1 parent abf8a33 commit 7235805

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

src/diffusers/pipelines/chroma/pipeline_chroma.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -694,9 +694,6 @@ def __call__(
694694
max_sequence_length=max_sequence_length,
695695
lora_scale=lora_scale,
696696
)
697-
698-
if self.do_classifier_free_guidance:
699-
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
700697

701698
# 4. Prepare latent variables
702699
num_channels_latents = self.transformer.config.in_channels // 4
@@ -773,13 +770,11 @@ def __call__(
773770
if image_embeds is not None:
774771
self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
775772

776-
# expand the latents if we are doing classifier free guidance
777-
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
778773
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
779-
timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
774+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
780775

781776
noise_pred = self.transformer(
782-
hidden_states=latent_model_input,
777+
hidden_states=latents,
783778
timestep=timestep / 1000,
784779
encoder_hidden_states=prompt_embeds,
785780
txt_ids=text_ids,
@@ -791,8 +786,16 @@ def __call__(
791786
if self.do_classifier_free_guidance:
792787
if negative_image_embeds is not None:
793788
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
794-
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
795-
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
789+
neg_noise_pred = self.transformer(
790+
hidden_states=latents,
791+
timestep=timestep / 1000,
792+
encoder_hidden_states=negative_prompt_embeds,
793+
txt_ids=negative_text_ids,
794+
img_ids=latent_image_ids,
795+
joint_attention_kwargs=self.joint_attention_kwargs,
796+
return_dict=False,
797+
)[0]
798+
noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred)
796799

797800
# compute the previous noisy sample x_t -> x_t-1
798801
latents_dtype = latents.dtype

0 commit comments

Comments (0)