
Commit 7c5a46f

use_mask_in_transformer, is_torch_version

hlkyyiyixuxu authored and committed
1 parent c1b0f3d

File tree

3 files changed (+31, -5 lines)


src/diffusers/models/normalization.py

Lines changed: 15 additions & 1 deletion
@@ -514,7 +514,7 @@ def forward(self, hidden_states):
             hidden_states = torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.eps)[0]
             if self.bias is not None:
                 hidden_states = hidden_states + self.bias
-        else:
+        elif is_torch_version(">=", "2.4"):
             if self.weight is not None:
                 # convert into half-precision if necessary
                 if self.weight.dtype in [torch.float16, torch.bfloat16]:
@@ -524,6 +524,20 @@ def forward(self, hidden_states):
             )
             if self.bias is not None:
                 hidden_states = hidden_states + self.bias
+        else:
+            input_dtype = hidden_states.dtype
+            variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+            hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+
+            if self.weight is not None:
+                # convert into half-precision if necessary
+                if self.weight.dtype in [torch.float16, torch.bfloat16]:
+                    hidden_states = hidden_states.to(self.weight.dtype)
+                hidden_states = hidden_states * self.weight
+                if self.bias is not None:
+                    hidden_states = hidden_states + self.bias
+            else:
+                hidden_states = hidden_states.to(input_dtype)

         return hidden_states
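The new `else:` branch is a manual RMS-norm fallback for torch versions older than 2.4, where `torch.nn.functional.rms_norm` is not available. A minimal standalone sketch of the two paths (not the diffusers module itself; the function name, shapes, and eps value are made up for illustration):

import torch
from diffusers.utils import is_torch_version


def rms_norm_fallback(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Manual RMS norm: compute the variance in float32, then rescale.
    variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    # Cast back to half precision before applying the learned scale, as in the diff above.
    if weight.dtype in (torch.float16, torch.bfloat16):
        x = x.to(weight.dtype)
    return x * weight


x = torch.randn(2, 8, 16)
weight = torch.ones(16)

if is_torch_version(">=", "2.4"):
    # Fused path used on recent torch versions.
    fused = torch.nn.functional.rms_norm(x, normalized_shape=(16,), weight=weight, eps=1e-5)
    print(torch.allclose(rms_norm_fallback(x, weight), fused, atol=1e-5))
else:
    print(rms_norm_fallback(x, weight).shape)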

src/diffusers/models/transformers/transformer_lumina2.py

Lines changed: 11 additions & 4 deletions
@@ -469,6 +469,7 @@ def forward(
         timestep: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
+        use_mask_in_transformer: bool = True,
         return_dict: bool = True,
     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         batch_size = hidden_states.size(0)
@@ -493,11 +494,15 @@ def forward(
         # 2. Context & noise refinement
         for layer in self.context_refiner:
             # NOTE: mask not used for performance
-            encoder_hidden_states = layer(encoder_hidden_states, None, encoder_rotary_emb)
+            encoder_hidden_states = layer(
+                encoder_hidden_states, attention_mask if use_mask_in_transformer else None, encoder_rotary_emb
+            )

         for layer in self.noise_refiner:
             # NOTE: mask not used for performance
-            hidden_states = layer(hidden_states, None, hidden_rotary_emb, temb)
+            hidden_states = layer(
+                hidden_states, hidden_mask if use_mask_in_transformer else None, hidden_rotary_emb, temb
+            )

         # 3. Attention mask preparation
         mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
@@ -514,9 +519,11 @@ def forward(
         for layer in self.layers:
             # NOTE: mask not used for performance
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                hidden_states = self._gradient_checkpointing_func(layer, hidden_states, None, joint_rotary_emb, temb)
+                hidden_states = self._gradient_checkpointing_func(
+                    layer, hidden_states, mask if use_mask_in_transformer else None, joint_rotary_emb, temb
+                )
             else:
-                hidden_states = layer(hidden_states, None, joint_rotary_emb, temb)
+                hidden_states = layer(hidden_states, mask if use_mask_in_transformer else None, joint_rotary_emb, temb)

         # 5. Output norm & projection & unpatchify
         hidden_states = self.norm_out(hidden_states, temb)
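For context on the "# NOTE: mask not used for performance" comments: when no token is padded, passing attn_mask=None lets scaled_dot_product_attention pick a fused kernel, while an all-True boolean mask yields the same output but may force a slower fallback. A small illustration of that equivalence, not taken from this commit; shapes are arbitrary:

import torch
import torch.nn.functional as F

# Query/key/value with no padding anywhere: (batch, heads, seq, head_dim).
q = k = v = torch.randn(1, 8, 128, 64)
# An all-True boolean mask means "attend to every position".
mask = torch.ones(1, 1, 128, 128, dtype=torch.bool)

out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
out_unmasked = F.scaled_dot_product_attention(q, k, v, attn_mask=None)

# Same result either way; passing None is roughly what use_mask_in_transformer=False
# selects inside the transformer blocks.
print(torch.allclose(out_masked, out_unmasked, atol=1e-5))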

src/diffusers/pipelines/lumina2/pipeline_lumina2.py

Lines changed: 5 additions & 0 deletions
@@ -525,6 +525,7 @@ def __call__(
         system_prompt: Optional[str] = None,
         cfg_trunc_ratio: float = 1.0,
         cfg_normalization: bool = True,
+        use_mask_in_transformer: bool = True,
         max_sequence_length: int = 256,
     ) -> Union[ImagePipelineOutput, Tuple]:
         """
@@ -596,6 +597,8 @@ def __call__(
                 The ratio of the timestep interval to apply normalization-based guidance scale.
             cfg_normalization (`bool`, *optional*, defaults to `True`):
                 Whether to apply normalization-based guidance scale.
+            use_mask_in_transformer (`bool`, *optional*, defaults to `True`):
+                Whether to use attention mask in `Lumina2Transformer2DModel`. Set `False` for performance gain.
             max_sequence_length (`int`, defaults to `256`):
                 Maximum sequence length to use with the `prompt`.
@@ -707,6 +710,7 @@ def __call__(
                     timestep=current_timestep,
                     encoder_hidden_states=prompt_embeds,
                     attention_mask=prompt_attention_mask,
+                    use_mask_in_transformer=use_mask_in_transformer,
                     return_dict=False,
                 )[0]

@@ -717,6 +721,7 @@ def __call__(
                     timestep=current_timestep,
                     encoder_hidden_states=negative_prompt_embeds,
                     attention_mask=negative_prompt_attention_mask,
+                    use_mask_in_transformer=use_mask_in_transformer,
                     return_dict=False,
                 )[0]
             noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
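A minimal usage sketch of the new pipeline argument. The class name Lumina2Pipeline and the Alpha-VLLM/Lumina-Image-2.0 checkpoint id are assumptions for illustration and are not taken from this diff:

import torch
from diffusers import Lumina2Pipeline  # pipeline class name assumed

pipe = Lumina2Pipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Image-2.0",  # checkpoint id assumed for illustration
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

image = pipe(
    prompt="a tiny astronaut hatching from an egg on the moon",
    use_mask_in_transformer=False,  # skip the attention mask in the transformer for speed
).images[0]
image.save("lumina2.png")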
