Skip to content

Commit 905c215

Browse files
committed
add
1 parent c26f42b commit 905c215

File tree

2 files changed

+189
-106
lines changed

2 files changed

+189
-106
lines changed

src/diffusers/models/transformers/transformer_hidream_image.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -731,7 +731,8 @@ def forward(
731731
self,
732732
hidden_states: torch.Tensor,
733733
timesteps: torch.LongTensor = None,
734-
encoder_hidden_states: torch.Tensor = None,
734+
t5_encoder_hidden_states: torch.Tensor = None,
735+
llama3_encoder_hidden_states: torch.Tensor = None,
735736
pooled_embeds: torch.Tensor = None,
736737
img_sizes: Optional[List[Tuple[int, int]]] = None,
737738
img_ids: Optional[torch.Tensor] = None,
@@ -791,9 +792,7 @@ def forward(
791792
)
792793
hidden_states = self.x_embedder(hidden_states)
793794

794-
T5_encoder_hidden_states = encoder_hidden_states[0]
795-
encoder_hidden_states = encoder_hidden_states[-1]
796-
encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers]
795+
encoder_hidden_states = [llama3_encoder_hidden_states[k] for k in self.llama_layers]
797796

798797
if self.caption_projection is not None:
799798
new_encoder_hidden_states = []
@@ -802,9 +801,9 @@ def forward(
802801
enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
803802
new_encoder_hidden_states.append(enc_hidden_state)
804803
encoder_hidden_states = new_encoder_hidden_states
805-
T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states)
806-
T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
807-
encoder_hidden_states.append(T5_encoder_hidden_states)
804+
t5_encoder_hidden_states = self.caption_projection[-1](t5_encoder_hidden_states)
805+
t5_encoder_hidden_states = t5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
806+
encoder_hidden_states.append(t5_encoder_hidden_states)
808807

809808
txt_ids = torch.zeros(
810809
batch_size,

0 commit comments

Comments
 (0)