
Commit a330fe0 ("update")
1 parent: 16c955c

1 file changed: +54 −15 lines

src/diffusers/models/transformers/transformer_hidream_image.py

@@ -1,3 +1,4 @@
+from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -6,9 +7,8 @@
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...models.modeling_outputs import Transformer2DModelOutput
 from ...models.modeling_utils import ModelMixin
-from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import Attention
 from ..embeddings import TimestepEmbedding, Timesteps
@@ -17,6 +17,29 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+@dataclass
+class HiDreamImageModelOutput(BaseOutput):
+    sample: torch.Tensor
+    double_blocks_auxiliary_loss: Optional[Tuple[torch.Tensor, ...]] = None
+    single_blocks_auxiliary_loss: Optional[Tuple[torch.Tensor, ...]] = None
+
+
+class AddAuxiliaryLoss(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, loss):
+        assert loss.numel() == 1
+        ctx.dtype = loss.dtype
+        ctx.required_aux_loss = loss.requires_grad
+        return x
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_loss = None
+        if ctx.required_aux_loss:
+            grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
+        return grad_output, grad_loss
+
+
 class HiDreamImageFeedForwardSwiGLU(nn.Module):
     def __init__(
         self,
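
Note on the hunk above: AddAuxiliaryLoss is the usual straight-through trick for attaching an auxiliary loss to an activation. The forward pass is the identity on x; the backward pass injects a gradient of 1 into loss, exactly as if loss had been added to the training objective, without the loss value ever changing the activations. A minimal standalone sketch (plain PyTorch, outside diffusers; the tensor w and the demo values are made up for illustration):

import torch


class AddAuxiliaryLoss(torch.autograd.Function):
    # Same class as in the diff above, reproduced so the sketch runs on its own.
    @staticmethod
    def forward(ctx, x, loss):
        assert loss.numel() == 1
        ctx.dtype = loss.dtype
        ctx.required_aux_loss = loss.requires_grad
        return x

    @staticmethod
    def backward(ctx, grad_output):
        grad_loss = None
        if ctx.required_aux_loss:
            # d(objective)/d(loss) == 1, as if `loss` were added to the objective.
            grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
        return grad_output, grad_loss


w = torch.randn(3, requires_grad=True)
x = 2.0 * w                   # stand-in for the MoE layer output
aux = (w ** 2).sum()          # stand-in for the gate's balancing loss
y = AddAuxiliaryLoss.apply(x, aux)
y.sum().backward()
# w.grad == 2 (from y) + 2 * w (from aux): the aux loss trained w even though
# only x was returned from the op.
print(torch.allclose(w.grad, 2.0 + 2.0 * w.detach()))  # True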
@@ -332,7 +355,6 @@ def forward(self, hidden_states):
             else:
                 mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
                 ce = mask_ce.float().mean(0)
-
                 Pi = scores_for_aux.mean(0)
                 fi = ce * self.n_routed_experts
                 aux_loss = (Pi * fi).sum() * self.alpha
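
Note on the hunk above: it only drops a blank line, but the surrounding gate code is the source of the aux_loss that this commit threads through the model. It is the standard load-balancing auxiliary loss, aux_loss = alpha * sum_i(P_i * f_i), where P_i (`Pi`) is the mean router probability assigned to expert i over the batch and f_i (`fi`) is n_routed_experts times the fraction of top-k assignments that landed on expert i (the one-hot counts in `mask_ce`). The re-enabled AddAuxiliaryLoss.apply call in the next hunk feeds this loss back into the router during backward.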
@@ -379,11 +401,11 @@ def forward(self, x):
                 y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype)
             y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
             y = y.view(*orig_shape).to(dtype=wtype)
-            # y = AddAuxiliaryLoss.apply(y, aux_loss)
+            y = AddAuxiliaryLoss.apply(y, aux_loss)
         else:
             y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
         y = y + self.shared_experts(identity)
-        return y
+        return y, aux_loss
 
     @torch.no_grad()
     def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
@@ -481,9 +503,10 @@ def forward(
         # 2. Feed-forward
         norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype)
         norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i
-        ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states.to(dtype=wtype))
+        ff_output_i, aux_loss = self.ff_i(norm_hidden_states.to(dtype=wtype))
+        ff_output_i = gate_mlp_i * ff_output_i
         hidden_states = ff_output_i + hidden_states
-        return hidden_states
+        return hidden_states, aux_loss
 
 
 @maybe_allow_in_graph
@@ -573,11 +596,12 @@ def forward(
         norm_encoder_hidden_states = self.norm3_t(encoder_hidden_states).to(dtype=wtype)
         norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + scale_mlp_t) + shift_mlp_t
 
-        ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states)
+        ff_output_i, aux_loss = self.ff_i(norm_hidden_states)
+        ff_output_i = gate_mlp_i * ff_output_i
         ff_output_t = gate_mlp_t * self.ff_t(norm_encoder_hidden_states)
         hidden_states = ff_output_i + hidden_states
         encoder_hidden_states = ff_output_t + encoder_hidden_states
-        return hidden_states, encoder_hidden_states
+        return hidden_states, encoder_hidden_states, aux_loss
 
 
 class HiDreamBlock(nn.Module):
@@ -785,6 +809,7 @@ def forward(
         hidden_states_masks: Optional[torch.Tensor] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
+        return_auxiliary_loss: bool = False,
         **kwargs,
     ):
         encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
@@ -866,15 +891,19 @@ def forward(
 
         # 2. Blocks
         block_id = 0
+        double_blocks_aux_losses = []
+        single_blocks_aux_losses = []
+
         initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
         initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
+
         for bid, block in enumerate(self.double_stream_blocks):
             cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
             cur_encoder_hidden_states = torch.cat(
                 [initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1
             )
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                hidden_states, initial_encoder_hidden_states = self._gradient_checkpointing_func(
+                hidden_states, initial_encoder_hidden_states, aux_loss = self._gradient_checkpointing_func(
                     block,
                     hidden_states,
                     hidden_states_masks,
@@ -883,14 +912,15 @@ def forward(
                     image_rotary_emb,
                 )
             else:
-                hidden_states, initial_encoder_hidden_states = block(
+                hidden_states, initial_encoder_hidden_states, aux_loss = block(
                     hidden_states=hidden_states,
                     hidden_states_masks=hidden_states_masks,
                     encoder_hidden_states=cur_encoder_hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
                 )
             initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
+            double_blocks_aux_losses.append(aux_loss)
             block_id += 1
 
         image_tokens_seq_len = hidden_states.shape[1]
@@ -908,7 +938,7 @@ def forward(
             cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
             hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                hidden_states = self._gradient_checkpointing_func(
+                hidden_states, aux_loss = self._gradient_checkpointing_func(
                     block,
                     hidden_states,
                     hidden_states_masks,
@@ -917,14 +947,15 @@ def forward(
                     image_rotary_emb,
                 )
             else:
-                hidden_states = block(
+                hidden_states, aux_loss = block(
                     hidden_states=hidden_states,
                     hidden_states_masks=hidden_states_masks,
                     encoder_hidden_states=None,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
                 )
             hidden_states = hidden_states[:, :hidden_states_seq_len]
+            single_blocks_aux_losses.append(aux_loss)
             block_id += 1
 
         hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
@@ -938,5 +969,13 @@ def forward(
             unscale_lora_layers(self, lora_scale)
 
         if not return_dict:
-            return (output,)
-        return Transformer2DModelOutput(sample=output)
+            return_values = (output,)
+            if return_auxiliary_loss:
+                return_values += (double_blocks_aux_losses, single_blocks_aux_losses)
+            return return_values
+
+        return HiDreamImageModelOutput(
+            sample=output,
+            double_blocks_auxiliary_loss=double_blocks_aux_losses,
+            single_blocks_auxiliary_loss=single_blocks_aux_losses,
+        )
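
Usage sketch for the new return values (a hypothetical training fragment: `transformer`, `model_inputs`, and `target` are placeholder names, not part of this commit; only the return-value handling follows the code above). With return_dict=True the model now returns HiDreamImageModelOutput, which always carries the per-block auxiliary losses; with return_dict=False the extra tuple entries appear only when return_auxiliary_loss=True:

import torch.nn.functional as F

# Dataclass path: aux losses ride along on the output object.
out = transformer(**model_inputs, return_dict=True)
loss = F.mse_loss(out.sample, target)
# The MoE aux losses already reach the router gradients via AddAuxiliaryLoss
# during backward; collecting them here is only needed for logging.
aux_terms = [l for l in (out.double_blocks_auxiliary_loss or []) if l is not None]
aux_terms += [l for l in (out.single_blocks_auxiliary_loss or []) if l is not None]
loss.backward()

# Tuple path: opt in explicitly.
sample, double_aux, single_aux = transformer(
    **model_inputs, return_dict=False, return_auxiliary_loss=True
)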
