
Commit 3d41281

make style

1 parent 4f59d56 commit 3d41281

2 files changed: +29 -40 lines changed

src/diffusers/models/autoencoders/autoencoder_kl_allegro.py

Lines changed: 28 additions & 36 deletions
@@ -90,7 +90,7 @@ def __init__(
             nn.SiLU(),
             nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_h)),
         )
-
+
     @staticmethod
     def _pad_temporal_dim(hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = torch.cat((hidden_states[:, :, 0:1], hidden_states), dim=2)
@@ -118,10 +118,10 @@ def forward(self, hidden_states: torch.Tensor, batch_size: int) -> torch.Tensor:

         hidden_states = self._pad_temporal_dim(hidden_states)
         hidden_states = self.conv2(hidden_states)
-
+
         hidden_states = self._pad_temporal_dim(hidden_states)
         hidden_states = self.conv3(hidden_states)
-
+
         hidden_states = self._pad_temporal_dim(hidden_states)
         hidden_states = self.conv4(hidden_states)

@@ -200,7 +200,7 @@ def __init__(

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size = hidden_states.shape[0]
-
+
         hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)

         for resnet, temp_conv in zip(self.resnets, self.temp_convs):
@@ -213,7 +213,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
                 hidden_states = downsampler(hidden_states)
-
+
         hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
         return hidden_states

@@ -282,7 +282,7 @@ def __init__(

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size = hidden_states.shape[0]
-
+
         hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)

         for resnet, temp_conv in zip(self.resnets, self.temp_convs):
@@ -295,7 +295,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
                 hidden_states = upsampler(hidden_states)
-
+
         hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
         return hidden_states

@@ -399,7 +399,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

         hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
         hidden_states = self.resnets[0](hidden_states, temb=None)
-
+
         hidden_states = self.temp_convs[0](hidden_states, batch_size=batch_size)

         for attn, resnet, temp_conv in zip(self.attentions, self.resnets[1:], self.temp_convs[1:]):
@@ -532,15 +532,15 @@ def custom_forward(*inputs):
         sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
         sample = self.conv_norm_out(sample)
         sample = self.conv_act(sample)
-
+
         sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
         residual = sample
         sample = self.temp_conv_out(sample)
         sample = sample + residual
-
+
         sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
         sample = self.conv_out(sample)
-
+
         sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
         return sample

@@ -674,15 +674,15 @@ def custom_forward(*inputs):
         sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
         sample = self.conv_norm_out(sample)
         sample = self.conv_act(sample)
-
+
         sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
         residual = sample
         sample = self.temp_conv_out(sample)
         sample = sample + residual

         sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
         sample = self.conv_out(sample)
-
+
         sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
         return sample

@@ -804,7 +804,7 @@ def __init__(
         chunk_len = 24
         t_over = 8
         tile_overlap = (120, 80)
-
+
         self.latent_chunk_len = chunk_len // 4
         self.latent_t_over = t_over // 4
         self.kernel = (chunk_len, sample_size, sample_size)  # (24, 256, 256)
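A quick aside on the constants in this hunk: the `tiled_encode` hunks further down divide the frame count by 4 and the spatial dimensions by 8 (`N // 4`, `H // 8`, `W // 8`), so the pixel-space tile settings here map to latent-space sizes as sketched below. This is illustrative arithmetic only, assuming nothing beyond the numbers visible in this diff.

# Illustrative only: pixel-space tile constants from this hunk, mapped to latent
# space via the 4x temporal / 8x spatial compression implied by tiled_encode.
chunk_len, t_over, sample_size = 24, 8, 256

latent_chunk_len = chunk_len // 4  # 6 latent frames per tile
latent_t_over = t_over // 4        # 2 latent frames of temporal overlap
kernel = (chunk_len, sample_size, sample_size)                    # (24, 256, 256) in pixel space
latent_kernel = (kernel[0] // 4, kernel[1] // 8, kernel[2] // 8)  # (6, 32, 32) in latent space

print(latent_chunk_len, latent_t_over, latent_kernel)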
@@ -817,7 +817,7 @@ def __init__(
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (AllegroEncoder3D, AllegroDecoder3D)):
             module.gradient_checkpointing = value
-
+
     def enable_tiling(
         self,
         # tile_sample_min_height: Optional[int] = None,
@@ -876,17 +876,19 @@ def disable_slicing(self) -> None:
         decoding in one step.
         """
         self.use_slicing = False
-
+
     def _encode(self, x: torch.Tensor) -> torch.Tensor:
         # TODO(aryan)
         # if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
         if self.use_tiling:
             return self.tiled_encode(x)
-
+
         raise NotImplementedError("Encoding without tiling has not been implemented yet.")
-
+
     @apply_forward_hook
-    def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+    def encode(
+        self, x: torch.Tensor, return_dict: bool = True
+    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
         r"""
         Encode a batch of videos into latents.

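For orientation, the reflowed `encode` above is the public entry point of this VAE. A minimal usage sketch follows: the checkpoint id and subfolder are assumptions, not part of this diff, and tiling is enabled explicitly because `_encode` raises `NotImplementedError` without it at this point in the PR. The call pattern follows the usual diffusers VAE convention of `encode(...).latent_dist` being a `DiagonalGaussianDistribution`.

# Minimal sketch, not part of this commit. Checkpoint id/subfolder are assumed.
import torch
from diffusers import AutoencoderKLAllegro

vae = AutoencoderKLAllegro.from_pretrained("rhymes-ai/Allegro", subfolder="vae")
vae.enable_tiling()  # _encode currently raises NotImplementedError without tiling

# Dummy video: (batch, channels, frames, height, width), sized to the (24, 256, 256) tile kernel.
video = torch.randn(1, 3, 24, 256, 256)

with torch.no_grad():
    posterior = vae.encode(video).latent_dist
    latents = posterior.sample()

# Expected latent shape is roughly (1, latent_channels, 24 // 4, 256 // 8, 256 // 8).
print(latents.shape)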
@@ -919,7 +921,7 @@ def _decode(self, z: torch.Tensor) -> torch.Tensor:
             return self.tiled_decode(z)

         raise NotImplementedError("Decoding without tiling has not been implemented yet.")
-
+
     @apply_forward_hook
     def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
         """
@@ -946,12 +948,10 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp
             return (decoded,)
         return DecoderOutput(sample=decoded)

-    def tiled_encode(
-        self, x: torch.Tensor
-    ) -> torch.Tensor:
+    def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
         # TODO(aryan): parameterize this in enable_tiling
         local_batch_size = 1
-
+
         # TODO(aryan): rewrite to encode and tiled_encode
         KERNEL = self.kernel
         STRIDE = self.stride
@@ -972,9 +972,7 @@ def tiled_encode(
             device=x.device,
             dtype=x.dtype,
         )
-        vae_batch_input = torch.zeros(
-            (LOCAL_BS, C, KERNEL[0], KERNEL[1], KERNEL[2]), device=x.device, dtype=x.dtype
-        )
+        vae_batch_input = torch.zeros((LOCAL_BS, C, KERNEL[0], KERNEL[1], KERNEL[2]), device=x.device, dtype=x.dtype)

         for i in range(out_n):
             for j in range(out_h):
@@ -1002,9 +1000,7 @@ def tiled_encode(
         ## flatten the batched out latent to videos and supress the overlapped parts
         B, C, N, H, W = x.shape

-        out_video_cube = torch.zeros(
-            (B, OUT_C, N // 4, H // 8, W // 8), device=x.device, dtype=x.dtype
-        )
+        out_video_cube = torch.zeros((B, OUT_C, N // 4, H // 8, W // 8), device=x.device, dtype=x.dtype)
         OUT_KERNEL = KERNEL[0] // 4, KERNEL[1] // 8, KERNEL[2] // 8
         OUT_STRIDE = STRIDE[0] // 4, STRIDE[1] // 8, STRIDE[2] // 8
         OVERLAP = OUT_KERNEL[0] - OUT_STRIDE[0], OUT_KERNEL[1] - OUT_STRIDE[1], OUT_KERNEL[2] - OUT_STRIDE[2]
@@ -1030,9 +1026,7 @@ def tiled_encode(

         return out_video_cube

-    def tiled_decode(
-        self, z: torch.Tensor
-    ) -> torch.Tensor:
+    def tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
         # TODO(aryan): parameterize this in enable_tiling
         local_batch_size = 1

@@ -1092,9 +1086,7 @@ def tiled_decode(
                 num += 1
         B, C, N, H, W = z.shape

-        out_video = torch.zeros(
-            (B, OUT_C, N * 4, H * 8, W * 8), device=z.device, dtype=z.dtype
-        )
+        out_video = torch.zeros((B, OUT_C, N * 4, H * 8, W * 8), device=z.device, dtype=z.dtype)
         OVERLAP = KERNEL[0] - STRIDE[0], KERNEL[1] - STRIDE[1], KERNEL[2] - STRIDE[2]
         for i in range(out_n):
             n_start, n_end = i * STRIDE[0], i * STRIDE[0] + KERNEL[0]
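The last context line above shows the sliding-window indexing used along the temporal axis. A small self-contained sketch of that kernel/stride pattern is below; it is illustrative only, not the library's code, and the stride value of 16 is an assumption taken as chunk_len - t_over rather than something this hunk states.

# Illustrative sketch of the windowing visible above
# (n_start, n_end = i * STRIDE[0], i * STRIDE[0] + KERNEL[0]).
def tile_bounds(length: int, kernel: int, stride: int) -> list[tuple[int, int]]:
    """Return (start, end) pairs of overlapping windows of size `kernel` over `length`."""
    bounds = []
    i = 0
    while i * stride + kernel <= length:
        bounds.append((i * stride, i * stride + kernel))
        i += 1
    return bounds

# A 56-frame clip with a 24-frame kernel and an assumed stride of 16 yields three tiles;
# consecutive tiles share kernel - stride = 8 frames, which the overlap handling suppresses.
print(tile_bounds(56, kernel=24, stride=16))  # [(0, 24), (16, 40), (32, 56)]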

src/diffusers/models/normalization.py

Lines changed: 1 addition & 4 deletions
@@ -22,10 +22,7 @@

 from ..utils import is_torch_version
 from .activations import get_activation
-from .embeddings import (
-    CombinedTimestepLabelEmbeddings,
-    PixArtAlphaCombinedTimestepSizeEmbeddings
-)
+from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings


 class AdaLayerNorm(nn.Module):
