Skip to content

Commit b8aa38d

Browse files
committed
fix -einops
1 parent f94d68e commit b8aa38d

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

src/diffusers/models/transformers/transformer_hidream_image.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -849,8 +849,11 @@ def forward(
849849
img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device)
850850
img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None]
851851
img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :]
852-
# repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
853-
img_ids = img_ids.reshape(img_ids.shape[0], img_ids.shape[1] * img_ids.shape[2]).unsqueeze(0)
852+
img_ids = (
853+
img_ids.reshape(img_ids.shape[0] * img_ids.shape[1], img_ids.shape[2])
854+
.unsqueeze(0)
855+
.repeat(batch_size, 1, 1)
856+
)
854857
hidden_states = self.x_embedder(hidden_states)
855858

856859
T5_encoder_hidden_states = encoder_hidden_states[0]

0 commit comments

Comments
 (0)