
Commit 07c670e

move reshape logic
1 parent 7c4eced · commit 07c670e

2 files changed: +15 −15 lines

src/diffusers/models/transformers/transformer_hidream_image.py

Lines changed: 15 additions & 0 deletions
@@ -795,6 +795,21 @@ def forward(
         batch_size = hidden_states.shape[0]
         hidden_states_type = hidden_states.dtype
 
+        if hidden_states.shape[-2] != hidden_states.shape[-1]:
+            B, C, H, W = hidden_states.shape
+            patch_size = self.config.patch_size
+            pH, pW = H // patch_size, W // patch_size
+            out = torch.zeros(
+                (B, C, self.max_seq, patch_size * patch_size),
+                dtype=hidden_states.dtype,
+                device=hidden_states.device,
+            )
+            hidden_states = hidden_states.reshape(B, C, pH, patch_size, pW, patch_size)
+            hidden_states = hidden_states.permute(0, 1, 2, 4, 3, 5)
+            hidden_states = hidden_states.reshape(B, C, pH * pW, patch_size * patch_size)
+            out[:, :, 0 : pH * pW] = hidden_states
+            hidden_states = out
+
         # 0. time
         timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device)
         timesteps = self.t_embedder(timesteps, hidden_states_type)
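
The added block only fires for non-square inputs (H != W): it patchifies the (B, C, H, W) latents into a fixed-length token layout by flattening each patch_size × patch_size patch, laying the pH * pW patches out as a sequence, and zero-padding that sequence up to self.max_seq. A minimal standalone sketch of the same reshape; the helper name, tensor sizes, and max_seq value below are illustrative assumptions, not taken from the model config:

import torch

def pad_and_patchify(x: torch.Tensor, patch_size: int, max_seq: int) -> torch.Tensor:
    # Replicates the block added to forward():
    # (B, C, H, W) -> (B, C, max_seq, patch_size**2),
    # with the trailing max_seq - pH*pW sequence positions left as zeros.
    B, C, H, W = x.shape
    pH, pW = H // patch_size, W // patch_size
    out = torch.zeros((B, C, max_seq, patch_size * patch_size), dtype=x.dtype, device=x.device)
    x = x.reshape(B, C, pH, patch_size, pW, patch_size)
    x = x.permute(0, 1, 2, 4, 3, 5)
    x = x.reshape(B, C, pH * pW, patch_size * patch_size)
    out[:, :, 0 : pH * pW] = x
    return out

# Illustrative sizes (not from the HiDream config): 16 latent channels,
# a 128x64 non-square latent grid, patch_size 2, max_seq 4096.
latents = torch.randn(2, 16, 128, 64)
tokens = pad_and_patchify(latents, patch_size=2, max_seq=4096)
assert tokens.shape == (2, 16, 4096, 4)  # 128/2 * 64/2 = 2048 real patches, rest zero padding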

src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py

Lines changed: 0 additions & 15 deletions
@@ -666,21 +666,6 @@ def __call__(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0])
 
-                if latent_model_input.shape[-2] != latent_model_input.shape[-1]:
-                    B, C, H, W = latent_model_input.shape
-                    patch_size = self.transformer.config.patch_size
-                    pH, pW = H // patch_size, W // patch_size
-                    out = torch.zeros(
-                        (B, C, self.transformer.max_seq, patch_size * patch_size),
-                        dtype=latent_model_input.dtype,
-                        device=latent_model_input.device,
-                    )
-                    latent_model_input = latent_model_input.reshape(B, C, pH, patch_size, pW, patch_size)
-                    latent_model_input = latent_model_input.permute(0, 1, 2, 4, 3, 5)
-                    latent_model_input = latent_model_input.reshape(B, C, pH * pW, patch_size * patch_size)
-                    out[:, :, 0 : pH * pW] = latent_model_input
-                    latent_model_input = out
-
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
                     timesteps=timestep,
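
The removed block is the same reshape, line for line, as the one added to the transformer's forward above, so after this commit the denoising loop hands raw (B, C, H, W) latents, square or not, straight to self.transformer. A self-contained round-trip sketch showing that the padded patch layout is lossless; the sizes and max_seq value are illustrative assumptions, and the inverse below is only a check, not the transformer's actual output path:

import torch

B, C, H, W, p, max_seq = 2, 16, 128, 64, 2, 4096
pH, pW = H // p, W // p

latents = torch.randn(B, C, H, W)

# pipeline-side reshape that this commit removes here (now done inside the transformer)
tokens = torch.zeros((B, C, max_seq, p * p))
tokens[:, :, 0 : pH * pW] = (
    latents.reshape(B, C, pH, p, pW, p).permute(0, 1, 2, 4, 3, 5).reshape(B, C, pH * pW, p * p)
)

# invert it: drop the zero padding and fold the patches back into the 2D latent grid
recovered = (
    tokens[:, :, 0 : pH * pW].reshape(B, C, pH, pW, p, p).permute(0, 1, 2, 4, 3, 5).reshape(B, C, H, W)
)
assert torch.equal(recovered, latents)  # the reshape drops no information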
