
Commit c1b0f3d

hlky authored and yiyixuxu committed
no expand attn mask, two forward pass
1 parent 565fa0c commit c1b0f3d

3 files changed: +24 -32 lines changed

3 files changed

+24
-32
lines changed

src/diffusers/models/normalization.py

Lines changed: 5 additions & 9 deletions
@@ -515,19 +515,15 @@ def forward(self, hidden_states):
             if self.bias is not None:
                 hidden_states = hidden_states + self.bias
         else:
-            input_dtype = hidden_states.dtype
-            variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-            hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
-
             if self.weight is not None:
                 # convert into half-precision if necessary
                 if self.weight.dtype in [torch.float16, torch.bfloat16]:
                     hidden_states = hidden_states.to(self.weight.dtype)
-                hidden_states = hidden_states * self.weight
-                if self.bias is not None:
-                    hidden_states = hidden_states + self.bias
-            else:
-                hidden_states = hidden_states.to(input_dtype)
+            hidden_states = nn.functional.rms_norm(
+                hidden_states, normalized_shape=(hidden_states.shape[-1],), weight=self.weight, eps=self.eps
+            )
+            if self.bias is not None:
+                hidden_states = hidden_states + self.bias
 
         return hidden_states
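The manual float32 upcast, variance computation, and rescale are replaced by PyTorch's fused nn.functional.rms_norm. As a rough sanity check (not part of the commit; assumes torch >= 2.4, where torch.nn.functional.rms_norm is available), the two paths should agree within bfloat16 rounding:

import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 64, dtype=torch.bfloat16)
weight = torch.randn(64, dtype=torch.bfloat16)
eps = 1e-5

# previous path: upcast to float32, normalize by the root mean square, rescale by weight
variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
manual = (x * torch.rsqrt(variance + eps)).to(weight.dtype) * weight

# new path: fused kernel
fused = F.rms_norm(x, normalized_shape=(x.shape[-1],), weight=weight, eps=eps)

# loose tolerances to allow for bfloat16 rounding differences between the two paths
print(torch.allclose(manual.float(), fused.float(), atol=1e-2, rtol=1e-2))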

src/diffusers/models/transformers/transformer_lumina2.py

Lines changed: 7 additions & 5 deletions
@@ -130,7 +130,6 @@ def __call__(
         # scaled_dot_product_attention expects attention_mask shape to be
         # (batch, heads, source_length, target_length)
         attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
-        attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1)
 
         query = query.transpose(1, 2)
         key = key.transpose(1, 2)

@@ -493,10 +492,12 @@ def forward(
 
         # 2. Context & noise refinement
         for layer in self.context_refiner:
-            encoder_hidden_states = layer(encoder_hidden_states, attention_mask, encoder_rotary_emb)
+            # NOTE: mask not used for performance
+            encoder_hidden_states = layer(encoder_hidden_states, None, encoder_rotary_emb)
 
         for layer in self.noise_refiner:
-            hidden_states = layer(hidden_states, hidden_mask, hidden_rotary_emb, temb)
+            # NOTE: mask not used for performance
+            hidden_states = layer(hidden_states, None, hidden_rotary_emb, temb)
 
         # 3. Attention mask preparation
         mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)

@@ -511,10 +512,11 @@ def forward(
 
         # 4. Transformer blocks
         for layer in self.layers:
+            # NOTE: mask not used for performance
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                hidden_states = self._gradient_checkpointing_func(layer, hidden_states, mask, joint_rotary_emb, temb)
+                hidden_states = self._gradient_checkpointing_func(layer, hidden_states, None, joint_rotary_emb, temb)
             else:
-                hidden_states = layer(hidden_states, mask, joint_rotary_emb, temb)
+                hidden_states = layer(hidden_states, None, joint_rotary_emb, temb)
 
         # 5. Output norm & projection & unpatchify
         hidden_states = self.norm_out(hidden_states, temb)
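The dropped expand relied on the fact that scaled_dot_product_attention broadcasts a (batch, 1, 1, key_len) mask across the head and query dimensions, so a full (batch, heads, query_len, key_len) mask never needs to be materialized; the refiner and transformer blocks additionally pass None to skip masking entirely. A minimal sketch (not part of the commit) illustrating the broadcast equivalence:

import torch
import torch.nn.functional as F

batch_size, heads, seq_len, head_dim = 2, 4, 16, 32
query = torch.randn(batch_size, heads, seq_len, head_dim)
key = torch.randn(batch_size, heads, seq_len, head_dim)
value = torch.randn(batch_size, heads, seq_len, head_dim)

# boolean key-padding mask, True = attend; mask out the last 3 positions
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.bool)
attention_mask[:, -3:] = False

# (batch, 1, 1, key_len) broadcasts against (batch, heads, query_len, key_len)
# inside scaled_dot_product_attention, so the explicit expand is redundant
mask_4d = attention_mask.view(batch_size, 1, 1, -1)
out_broadcast = F.scaled_dot_product_attention(query, key, value, attn_mask=mask_4d)
out_expanded = F.scaled_dot_product_attention(
    query, key, value, attn_mask=mask_4d.expand(-1, heads, seq_len, -1)
)
print(torch.allclose(out_broadcast, out_expanded))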

src/diffusers/pipelines/lumina2/pipeline_lumina2.py

Lines changed: 12 additions & 18 deletions
@@ -658,9 +658,6 @@ def __call__(
             max_sequence_length=max_sequence_length,
             system_prompt=system_prompt,
         )
-        if do_classifier_free_guidance:
-            prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0)
-            prompt_attention_mask = torch.cat([prompt_attention_mask, negative_prompt_attention_mask], dim=0)
 
         # 4. Prepare latents.
         latent_channels = self.transformer.config.in_channels

@@ -700,22 +697,13 @@ def __call__(
             for i, t in enumerate(timesteps):
                 # compute whether apply classifier-free truncation on this timestep
                 do_classifier_free_truncation = (i + 1) / num_inference_steps > cfg_trunc_ratio
-
-                # expand the latents if we are doing classifier free guidance
-                latent_model_input = (
-                    torch.cat([latents] * 2)
-                    if do_classifier_free_guidance and not do_classifier_free_truncation
-                    else latents
-                )
-
-                current_timestep = t
-                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                current_timestep = current_timestep.expand(latent_model_input.shape[0])
                 # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image
-                current_timestep = 1 - current_timestep / self.scheduler.config.num_train_timesteps
+                current_timestep = 1 - t / self.scheduler.config.num_train_timesteps
+                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                current_timestep = current_timestep.expand(latents.shape[0])
 
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
+                noise_pred_cond = self.transformer(
+                    hidden_states=latents,
                     timestep=current_timestep,
                     encoder_hidden_states=prompt_embeds,
                     attention_mask=prompt_attention_mask,

@@ -724,7 +712,13 @@ def __call__(
 
                 # perform normalization-based guidance scale on a truncated timestep interval
                 if self.do_classifier_free_guidance and not do_classifier_free_truncation:
-                    noise_pred_cond, noise_pred_uncond = torch.split(noise_pred, len(noise_pred) // 2, dim=0)
+                    noise_pred_uncond = self.transformer(
+                        hidden_states=latents,
+                        timestep=current_timestep,
+                        encoder_hidden_states=negative_prompt_embeds,
+                        attention_mask=negative_prompt_attention_mask,
+                        return_dict=False,
+                    )[0]
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
                     # apply normalization after classifier-free guidance
                     if cfg_normalization:
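Instead of concatenating the conditional and unconditional batches and splitting the result, the pipeline now runs the transformer twice per step. A minimal sketch of the pattern (not part of the commit; the helper name cfg_two_pass and its signature are illustrative, not diffusers API):

import torch

def cfg_two_pass(
    transformer,
    latents,
    timestep,
    prompt_embeds,
    prompt_attention_mask,
    negative_prompt_embeds,
    negative_prompt_attention_mask,
    guidance_scale,
):
    # first pass: conditional prediction
    noise_pred_cond = transformer(
        hidden_states=latents,
        timestep=timestep,
        encoder_hidden_states=prompt_embeds,
        attention_mask=prompt_attention_mask,
        return_dict=False,
    )[0]
    # second pass: unconditional prediction (previously both were batched into one call)
    noise_pred_uncond = transformer(
        hidden_states=latents,
        timestep=timestep,
        encoder_hidden_states=negative_prompt_embeds,
        attention_mask=negative_prompt_attention_mask,
        return_dict=False,
    )[0]
    # classifier-free guidance combination, as in the diff above
    return noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

Per the commit title, the apparent trade-off is two sequential forward passes instead of one doubled batch, which avoids padding the positive and negative prompts into a shared batch and expanding the attention mask.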
