@@ -731,7 +731,8 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         timesteps: torch.LongTensor = None,
-        encoder_hidden_states: torch.Tensor = None,
+        t5_encoder_hidden_states: torch.Tensor = None,
+        llama3_encoder_hidden_states: torch.Tensor = None,
         pooled_embeds: torch.Tensor = None,
         img_sizes: Optional[List[Tuple[int, int]]] = None,
         img_ids: Optional[torch.Tensor] = None,
@@ -791,9 +792,7 @@ def forward(
             )
         hidden_states = self.x_embedder(hidden_states)
 
-        T5_encoder_hidden_states = encoder_hidden_states[0]
-        encoder_hidden_states = encoder_hidden_states[-1]
-        encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers]
+        encoder_hidden_states = [llama3_encoder_hidden_states[k] for k in self.llama_layers]
 
         if self.caption_projection is not None:
             new_encoder_hidden_states = []
@@ -802,9 +801,9 @@ def forward(
                 enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
                 new_encoder_hidden_states.append(enc_hidden_state)
             encoder_hidden_states = new_encoder_hidden_states
-            T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states)
-            T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
-            encoder_hidden_states.append(T5_encoder_hidden_states)
+            t5_encoder_hidden_states = self.caption_projection[-1](t5_encoder_hidden_states)
+            t5_encoder_hidden_states = t5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+            encoder_hidden_states.append(t5_encoder_hidden_states)
 
         txt_ids = torch.zeros(
             batch_size,
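
For reference, a minimal before/after sketch of the calling convention implied by this diff. `transformer`, `latents`, `t5_embeds`, and `llama3_hidden_states` are illustrative names, not identifiers from this change: `t5_embeds` stands for the T5 prompt embeddings and `llama3_hidden_states` for the stack of per-layer Llama3 hidden states that the model indexes with `self.llama_layers`.

# Hypothetical usage sketch; only the keyword-argument names come from
# the diff above, everything else is assumed for illustration.

# Before: both text encoders were packed into a single argument; the
# forward pass read the T5 embeddings from index 0 and the Llama3
# per-layer hidden states from index -1.
output = transformer(
    hidden_states=latents,
    timesteps=timesteps,
    encoder_hidden_states=[t5_embeds, llama3_hidden_states],
    pooled_embeds=pooled_embeds,
)

# After: each text encoder's output is passed explicitly, so the
# unpacking step is no longer needed inside forward().
output = transformer(
    hidden_states=latents,
    timesteps=timesteps,
    t5_encoder_hidden_states=t5_embeds,
    llama3_encoder_hidden_states=llama3_hidden_states,
    pooled_embeds=pooled_embeds,
)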