Skip to content

Commit 419c99d

Browse files
committed
remove more rearrange
1 parent cceae4a commit 419c99d

File tree

1 file changed

+47
-14
lines changed

1 file changed

+47
-14
lines changed

src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage_refiner.py

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -142,27 +142,40 @@ def __init__(self, in_channels: int, out_channels: int, add_temporal_upsample: b
142142
self.add_temporal_upsample = add_temporal_upsample
143143
self.repeats = factor * out_channels // in_channels
144144

145+
@staticmethod
def _dcae_upsample_rearrange(tensor: torch.Tensor, r1: int = 1, r2: int = 2, r3: int = 2) -> torch.Tensor:
    """
    Unpack channel-packed upsampling factors into the frame/height/width axes.

    Convert (b, r1*r2*r3*c, f, h, w) -> (b, c, f*r1, h*r2, w*r3). This is the pure-torch equivalent of the einops
    pattern "b (r1 r2 r3 c) f h w -> b c (f r1) (h r2) (w r3)".

    Args:
        tensor: Input tensor of shape (b, r1*r2*r3*c, f, h, w)
        r1: temporal upsampling factor
        r2: height upsampling factor
        r3: width upsampling factor

    Returns:
        Tensor of shape (b, c, f*r1, h*r2, w*r3).
    """
    b, packed_c, f, h, w = tensor.shape
    factor = r1 * r2 * r3
    c = packed_c // factor

    # Split the packed channel axis into its (r1, r2, r3, c) components.
    tensor = tensor.view(b, r1, r2, r3, c, f, h, w)
    # Move each factor right after the axis it expands: (b, c, f, r1, h, r2, w, r3).
    tensor = tensor.permute(0, 4, 5, 1, 6, 2, 7, 3)
    # Merge every (axis, factor) pair; reshape copies because the permuted tensor is non-contiguous.
    return tensor.reshape(b, c, f * r1, h * r2, w * r3)
163+
145164
def forward(self, x: torch.Tensor):
    """
    Upsample `x` through the conv branch and add a parameter-free shortcut.

    The conv output and the shortcut are both unpacked from channels into the spatial (and, via `r1`, temporal)
    axes with `_dcae_upsample_rearrange`, then summed.

    Args:
        x: Input tensor; the rearrange helper unpacks it as 5D (b, c, f, h, w).

    Returns:
        The sum of the rearranged conv output and the rearranged shortcut.
    """
    r1 = 2 if self.add_temporal_upsample else 1
    h = self.conv(x)
    if self.add_temporal_upsample:
        # Only the spatial factors are unpacked here (r1=1), then the first half of the channels is kept.
        # NOTE(review): presumably mirrors the reference checkpoint layout -- confirm before changing.
        h = self._dcae_upsample_rearrange(h, r1=1, r2=2, r3=2)
        h = h[:, : h.shape[1] // 2]

        # shortcut computation
        shortcut = self._dcae_upsample_rearrange(x, r1=1, r2=2, r3=2)
        shortcut = shortcut.repeat_interleave(repeats=self.repeats // 2, dim=1)

    else:
        # r1 is 1 in this branch, because add_temporal_upsample is False.
        h = self._dcae_upsample_rearrange(h, r1=r1, r2=2, r3=2)
        # Repeat the input channels first so they can be unpacked the same way as the conv output.
        shortcut = x.repeat_interleave(repeats=self.repeats, dim=1)
        shortcut = self._dcae_upsample_rearrange(shortcut, r1=r1, r2=2, r3=2)
    return h + shortcut
167180

168181

@@ -177,20 +190,40 @@ def __init__(self, in_channels: int, out_channels: int, add_temporal_downsample:
177190
self.add_temporal_downsample = add_temporal_downsample
178191
self.group_size = factor * in_channels // out_channels
179192

193+
194+
@staticmethod
def _dcae_downsample_rearrange(tensor: torch.Tensor, r1: int = 1, r2: int = 2, r3: int = 2) -> torch.Tensor:
    """
    Pack frame/height/width downsampling factors into the channel axis.

    Convert (b, c, f*r1, h*r2, w*r3) -> (b, r1*r2*r3*c, f, h, w). This is the pure-torch equivalent of the einops
    pattern "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w" (opposite of `_dcae_upsample_rearrange`).

    Args:
        tensor: Input tensor of shape (b, c, f*r1, h*r2, w*r3)
        r1: temporal downsampling factor
        r2: height downsampling factor
        r3: width downsampling factor

    Returns:
        Tensor of shape (b, r1*r2*r3*c, f, h, w).
    """
    # This is a staticmethod, so it takes no `self`: call sites pass the tensor as the first argument.
    b, c, packed_f, packed_h, packed_w = tensor.shape
    f, h, w = packed_f // r1, packed_h // r2, packed_w // r3

    # In "(f r1)" the factor is the inner (fastest-varying) index, so each axis splits as (f, r1), not (r1, f).
    tensor = tensor.view(b, c, f, r1, h, r2, w, r3)
    # Bring the factors in front of the channels: (b, r1, r2, r3, c, f, h, w).
    tensor = tensor.permute(0, 3, 5, 7, 1, 2, 4, 6)
    # Merge (r1, r2, r3, c) into one channel axis; reshape copies because the permuted tensor is non-contiguous.
    return tensor.reshape(b, r1 * r2 * r3 * c, f, h, w)
207+
208+
180209
def forward(self, x: torch.Tensor):
181210
r1 = 2 if self.add_temporal_downsample else 1
182211
h = self.conv(x)
183212
if self.add_temporal_downsample:
184-
h = rearrange(h, "b c f (h r2) (w r3) -> b (r2 r3 c) f h w", r2=2, r3=2)
213+
# h = rearrange(h, "b c f (h r2) (w r3) -> b (r2 r3 c) f h w", r2=2, r3=2)
214+
h = self._dcae_downsample_rearrange(h, r1=1, r2=2, r3=2)
185215
h = torch.cat([h, h], dim=1)
186216

187217
# shortcut computation
188-
shortcut = rearrange(x, "b c f (h r2) (w r3) -> b (r2 r3 c) f h w", r2=2, r3=2)
218+
# shortcut = rearrange(x, "b c f (h r2) (w r3) -> b (r2 r3 c) f h w", r2=2, r3=2)
219+
shortcut = self._dcae_downsample_rearrange(x, r1=1, r2=2, r3=2)
189220
B, C, T, H, W = shortcut.shape
190221
shortcut = shortcut.view(B, h.shape[1], self.group_size // 2, T, H, W).mean(dim=2)
191222
else:
192-
h = rearrange(h, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2)
193-
shortcut = rearrange(x, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2)
223+
# h = rearrange(h, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2)
224+
h = self._dcae_downsample_rearrange(h, r1=r1, r2=2, r3=2)
225+
# shortcut = rearrange(x, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2)
226+
shortcut = self._dcae_downsample_rearrange(x, r1=r1, r2=2, r3=2)
194227
B, C, T, H, W = shortcut.shape
195228
shortcut = shortcut.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2)
196229

0 commit comments

Comments
 (0)