+from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...models.modeling_outputs import Transformer2DModelOutput
 from ...models.modeling_utils import ModelMixin
-from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import Attention
 from ..embeddings import TimestepEmbedding, Timesteps
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+@dataclass
+class HiDreamImageModelOutput(BaseOutput):
+    sample: torch.Tensor
+    double_blocks_auxiliary_loss: Optional[Tuple[torch.Tensor, ...]] = None
+    single_blocks_auxiliary_loss: Optional[Tuple[torch.Tensor, ...]] = None
+
+
+class AddAuxiliaryLoss(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, loss):
+        assert loss.numel() == 1
+        ctx.dtype = loss.dtype
+        ctx.required_aux_loss = loss.requires_grad
+        return x
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_loss = None
+        if ctx.required_aux_loss:
+            grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
+        return grad_output, grad_loss
+
+
 class HiDreamImageFeedForwardSwiGLU(nn.Module):
     def __init__(
         self,
@@ -332,7 +355,6 @@ def forward(self, hidden_states):
             else:
                 mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
                 ce = mask_ce.float().mean(0)
-
                 Pi = scores_for_aux.mean(0)
                 fi = ce * self.n_routed_experts
                 aux_loss = (Pi * fi).sum() * self.alpha
@@ -379,11 +401,11 @@ def forward(self, x):
                 y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype)
             y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
             y = y.view(*orig_shape).to(dtype=wtype)
-            # y = AddAuxiliaryLoss.apply(y, aux_loss)
+            y = AddAuxiliaryLoss.apply(y, aux_loss)
         else:
             y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
         y = y + self.shared_experts(identity)
-        return y
+        return y, aux_loss
 
     @torch.no_grad()
     def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
@@ -481,9 +503,10 @@ def forward(
         # 2. Feed-forward
         norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype)
         norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i
-        ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states.to(dtype=wtype))
+        ff_output_i, aux_loss = self.ff_i(norm_hidden_states.to(dtype=wtype))
+        ff_output_i = gate_mlp_i * ff_output_i
         hidden_states = ff_output_i + hidden_states
-        return hidden_states
+        return hidden_states, aux_loss
 
 
 @maybe_allow_in_graph
@@ -573,11 +596,12 @@ def forward(
         norm_encoder_hidden_states = self.norm3_t(encoder_hidden_states).to(dtype=wtype)
         norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + scale_mlp_t) + shift_mlp_t
 
-        ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states)
+        ff_output_i, aux_loss = self.ff_i(norm_hidden_states)
+        ff_output_i = gate_mlp_i * ff_output_i
         ff_output_t = gate_mlp_t * self.ff_t(norm_encoder_hidden_states)
         hidden_states = ff_output_i + hidden_states
         encoder_hidden_states = ff_output_t + encoder_hidden_states
-        return hidden_states, encoder_hidden_states
+        return hidden_states, encoder_hidden_states, aux_loss
 
 
 class HiDreamBlock(nn.Module):
@@ -785,6 +809,7 @@ def forward(
         hidden_states_masks: Optional[torch.Tensor] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
+        return_auxiliary_loss: bool = False,
         **kwargs,
     ):
         encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
@@ -866,15 +891,19 @@ def forward(
 
         # 2. Blocks
         block_id = 0
+        double_blocks_aux_losses = []
+        single_blocks_aux_losses = []
+
         initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
         initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
+
         for bid, block in enumerate(self.double_stream_blocks):
             cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
             cur_encoder_hidden_states = torch.cat(
                 [initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1
             )
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                hidden_states, initial_encoder_hidden_states = self._gradient_checkpointing_func(
+                hidden_states, initial_encoder_hidden_states, aux_loss = self._gradient_checkpointing_func(
                     block,
                     hidden_states,
                     hidden_states_masks,
@@ -883,14 +912,15 @@ def forward(
                     image_rotary_emb,
                 )
             else:
-                hidden_states, initial_encoder_hidden_states = block(
+                hidden_states, initial_encoder_hidden_states, aux_loss = block(
                     hidden_states=hidden_states,
                     hidden_states_masks=hidden_states_masks,
                     encoder_hidden_states=cur_encoder_hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
                 )
             initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
+            double_blocks_aux_losses.append(aux_loss)
             block_id += 1
 
         image_tokens_seq_len = hidden_states.shape[1]
@@ -908,7 +938,7 @@ def forward(
             cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
             hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                hidden_states = self._gradient_checkpointing_func(
+                hidden_states, aux_loss = self._gradient_checkpointing_func(
                     block,
                     hidden_states,
                     hidden_states_masks,
@@ -917,14 +947,15 @@ def forward(
                     image_rotary_emb,
                 )
             else:
-                hidden_states = block(
+                hidden_states, aux_loss = block(
                     hidden_states=hidden_states,
                     hidden_states_masks=hidden_states_masks,
                     encoder_hidden_states=None,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
                 )
             hidden_states = hidden_states[:, :hidden_states_seq_len]
+            single_blocks_aux_losses.append(aux_loss)
             block_id += 1
 
         hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
@@ -938,5 +969,13 @@ def forward(
             unscale_lora_layers(self, lora_scale)
 
         if not return_dict:
-            return (output,)
-        return Transformer2DModelOutput(sample=output)
+            return_values = (output,)
+            if return_auxiliary_loss:
+                return_values += (double_blocks_aux_losses, single_blocks_aux_losses)
+            return return_values
+
+        return HiDreamImageModelOutput(
+            sample=output,
+            double_blocks_auxiliary_loss=double_blocks_aux_losses,
+            single_blocks_auxiliary_loss=single_blocks_aux_losses,
+        )
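
For orientation, the sketch below shows one way the new output could be consumed after this change; it is a hypothetical illustration, not part of the diff. It assumes an already constructed HiDreamImageTransformer2DModel named `transformer` and a dict `model_inputs` holding its usual forward kwargs; both names are placeholders. Since AddAuxiliaryLoss.apply already feeds a unit gradient back into each gate's auxiliary loss during the main backward pass, the per-block values returned here are mostly convenient for logging or for applying an extra weight.

import torch

# Hypothetical usage sketch; `transformer` and `model_inputs` are placeholders.
output = transformer(**model_inputs)  # returns HiDreamImageModelOutput by default
noise_pred = output.sample

# Gate aux losses are None outside training (or when alpha == 0), so filter first.
double_aux = list(output.double_blocks_auxiliary_loss or [])
single_aux = list(output.single_blocks_auxiliary_loss or [])
per_block_losses = [loss for loss in double_aux + single_aux if loss is not None]
if per_block_losses:
    total_aux_loss = torch.stack(per_block_losses).sum()
    print(f"MoE auxiliary loss: {total_aux_loss.item():.6f}")

# Tuple path, if a plain tuple is preferred:
# sample, double_aux, single_aux = transformer(
#     **model_inputs, return_dict=False, return_auxiliary_loss=True
# )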