Commit a40c95f

-einops

1 parent 3b5e03b

2 files changed: +25, -20 lines changed

src/diffusers/models/transformers/transformer_hidream_image.py

Lines changed: 22 additions & 16 deletions
@@ -1,7 +1,6 @@
 import math
 from typing import Any, Dict, List, Optional, Tuple
 
-import einops
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -756,21 +755,26 @@ def expand_timesteps(self, timesteps, batch_size, device):
 
     def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
         if is_training:
-            x = einops.rearrange(
-                x, "B S (p1 p2 C) -> B C S (p1 p2)", p1=self.config.patch_size, p2=self.config.patch_size
+            B, S, F = x.shape
+            C = F // (self.config.patch_size * self.config.patch_size)
+            x = (
+                x.reshape(B, S, self.config.patch_size, self.config.patch_size, C)
+                .permute(0, 4, 1, 2, 3)
+                .reshape(B, C, S, self.config.patch_size * self.config.patch_size)
             )
         else:
             x_arr = []
+            p1 = self.config.patch_size
+            p2 = self.config.patch_size
             for i, img_size in enumerate(img_sizes):
                 pH, pW = img_size
-                x_arr.append(
-                    einops.rearrange(
-                        x[i, : pH * pW].reshape(1, pH, pW, -1),
-                        "B H W (p1 p2 C) -> B C (H p1) (W p2)",
-                        p1=self.config.patch_size,
-                        p2=self.config.patch_size,
-                    )
-                )
+                t = x[i, : pH * pW].reshape(1, pH, pW, -1)
+                F_token = t.shape[-1]
+                C = F_token // (p1 * p2)
+                t = t.reshape(1, pH, pW, p1, p2, C)
+                t = t.permute(0, 5, 1, 3, 2, 4)
+                t = t.reshape(1, C, pH * p1, pW * p2)
+                x_arr.append(t)
             x = torch.cat(x_arr, dim=0)
         return x
 
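The two rewrites above can be sanity-checked against the einops patterns they replace. A minimal sketch with made-up sizes (B, S, p, C, pH, pW are hypothetical; einops is imported here only to produce the reference result):

import einops
import torch

B, S, p, C = 2, 16, 2, 4          # hypothetical sizes for the check
x = torch.randn(B, S, p * p * C)

# Training branch: "B S (p1 p2 C) -> B C S (p1 p2)"
out = x.reshape(B, S, p, p, C).permute(0, 4, 1, 2, 3).reshape(B, C, S, p * p)
ref = einops.rearrange(x, "B S (p1 p2 C) -> B C S (p1 p2)", p1=p, p2=p)
assert torch.equal(out, ref)

# Inference branch: "B H W (p1 p2 C) -> B C (H p1) (W p2)"
pH, pW = 4, 4
t = torch.randn(1, pH, pW, p * p * C)
out = t.reshape(1, pH, pW, p, p, C).permute(0, 5, 1, 3, 2, 4).reshape(1, C, pH * p, pW * p)
ref = einops.rearrange(t, "B H W (p1 p2 C) -> B C (H p1) (W p2)", p1=p, p2=p)
assert torch.equal(out, ref)

Both rely on reshape splitting the packed axis in the same row-major order einops uses for "(p1 p2 C)", so the permute indices are the only part that has to be derived by hand.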
@@ -789,12 +793,14 @@ def patchify(self, x, max_seq, img_sizes=None):
         if img_sizes is not None:
             for i, img_size in enumerate(img_sizes):
                 x_masks[i, 0 : img_size[0] * img_size[1]] = 1
-            x = einops.rearrange(x, "B C S p -> B S (p C)", p=pz2, C=C)
+            B, C, S, _ = x.shape
+            x = x.permute(0, 2, 3, 1).reshape(B, S, pz2 * C)
         elif isinstance(x, torch.Tensor):
-            pH, pW = x.shape[-2] // self.config.patch_size, x.shape[-1] // self.config.patch_size
-            x = einops.rearrange(
-                x, "B C (H p1) (W p2) -> B (H W) (p1 p2 C)", p1=self.config.patch_size, p2=self.config.patch_size, C=C
-            )
+            B, C, Hp1, Wp2 = x.shape
+            pH, pW = Hp1 // self.config.patch_size, Wp2 // self.config.patch_size
+            x = x.reshape(B, C, pH, self.config.patch_size, pW, self.config.patch_size)
+            x = x.permute(0, 2, 4, 3, 5, 1)
+            x = x.reshape(B, pH * pW, self.config.patch_size * self.config.patch_size * C)
             img_sizes = [[pH, pW]] * B
             x_masks = None
         else:
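As with unpatchify, both patchify paths can be checked against the patterns they replace. The sketch below uses hypothetical sizes, with einops imported only for the reference:

import einops
import torch

B, C, p = 2, 4, 2                 # hypothetical sizes for the check
pH, pW = 4, 4
pz2 = p * p

# Pre-patchified branch: "B C S p -> B S (p C)"
x = torch.randn(B, C, pH * pW, pz2)
out = x.permute(0, 2, 3, 1).reshape(B, pH * pW, pz2 * C)
ref = einops.rearrange(x, "B C S p -> B S (p C)", p=pz2, C=C)
assert torch.equal(out, ref)

# Image-tensor branch: "B C (H p1) (W p2) -> B (H W) (p1 p2 C)"
img = torch.randn(B, C, pH * p, pW * p)
out = img.reshape(B, C, pH, p, pW, p).permute(0, 2, 4, 3, 5, 1).reshape(B, pH * pW, p * p * C)
ref = einops.rearrange(img, "B C (H p1) (W p2) -> B (H W) (p1 p2 C)", p1=p, p2=p, C=C)
assert torch.equal(out, ref)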
src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py

Lines changed: 3 additions & 4 deletions
@@ -2,7 +2,6 @@
 import math
 from typing import Any, Callable, Dict, List, Optional, Union
 
-import einops
 import torch
 from transformers import (
     CLIPTextModelWithProjection,
@@ -679,9 +678,9 @@ def __call__(
                         dtype=latent_model_input.dtype,
                         device=latent_model_input.device,
                     )
-                    latent_model_input = einops.rearrange(
-                        latent_model_input, "B C (H p1) (W p2) -> B C (H W) (p1 p2)", p1=patch_size, p2=patch_size
-                    )
+                    latent_model_input = latent_model_input.reshape(B, C, pH, patch_size, pW, patch_size)
+                    latent_model_input = latent_model_input.permute(0, 1, 2, 4, 3, 5)
+                    latent_model_input = latent_model_input.reshape(B, C, pH * pW, patch_size * patch_size)
                     out[:, :, 0 : pH * pW] = latent_model_input
                     latent_model_input = out
 
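The same kind of check covers the pipeline's pattern; sizes below are hypothetical and einops appears only as the reference:

import einops
import torch

B, C, patch_size = 2, 16, 2       # hypothetical sizes for the check
pH, pW = 4, 4
lat = torch.randn(B, C, pH * patch_size, pW * patch_size)

# Pipeline pattern: "B C (H p1) (W p2) -> B C (H W) (p1 p2)"
out = lat.reshape(B, C, pH, patch_size, pW, patch_size)
out = out.permute(0, 1, 2, 4, 3, 5)
out = out.reshape(B, C, pH * pW, patch_size * patch_size)
ref = einops.rearrange(lat, "B C (H p1) (W p2) -> B C (H W) (p1 p2)", p1=patch_size, p2=patch_size)
assert torch.equal(out, ref)

Since every rearrange in this commit is a pure axis reshuffle, the reshape/permute/reshape form reproduces it exactly and drops the einops dependency without changing numerics.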