Commit c201880

remove framewise encoding/decoding
1 parent 5391ceb commit c201880

File tree

3 files changed: +24, -40 lines

src/diffusers/models/autoencoders/autoencoder_kl_ltx.py

Lines changed: 24 additions & 35 deletions
@@ -899,12 +899,12 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor:
             return self.tiled_encode(x)

         if self.use_framewise_encoding:
-            enc = []
-            for i in range(0, num_frames, self.num_sample_frames_batch_size):
-                x_intermediate = x[:, :, i : i + self.num_sample_frames_batch_size]
-                x_intermediate = self.encoder(x_intermediate)
-                enc.append(x_intermediate)
-            enc = torch.cat(enc, dim=2)
+            # TODO(aryan): requires investigation
+            raise NotImplementedError(
+                "Frame-wise encoding has not been implemented for AutoencoderKLLTX, at the moment, due to "
+                "quality issues caused by splitting inference across frame dimension. If you believe this "
+                "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
+            )
         else:
             enc = self.encoder(x)

@@ -946,12 +946,12 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
             return self.tiled_decode(z, return_dict=return_dict)

         if self.use_framewise_decoding:
-            dec = []
-            for i in range(0, num_frames, self.num_latent_frames_batch_size):
-                z_intermediate = z[:, :, i : i + self.num_latent_frames_batch_size]
-                z_intermediate = self.decoder(z_intermediate)
-                dec.append(z_intermediate)
-            dec = torch.cat(dec, dim=2)
+            # TODO(aryan): requires investigation
+            raise NotImplementedError(
+                "Frame-wise decoding has not been implemented for AutoencoderKLLTX, at the moment, due to "
+                "quality issues caused by splitting inference across frame dimension. If you believe this "
+                "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
+            )
         else:
             dec = self.decoder(z)
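For reference, the removed branches chunked the clip along the frame axis and concatenated the per-chunk outputs. With 3D convolutions that is not equivalent to a single full pass: the temporal receptive field crosses chunk boundaries, so every chunk sees padding where its neighbors' frames should be, which is consistent with the quality issues the new error message describes. A minimal, self-contained demonstration of the boundary mismatch (generic PyTorch, not the LTX model itself):

# Chunking a Conv3d along the frame axis and concatenating the results does not
# reproduce a single full pass: at every chunk edge the temporal kernel sees
# zero padding instead of the neighboring frames.
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv3d(4, 4, kernel_size=3, padding=1)  # temporal kernel spans 3 frames
video = torch.randn(1, 4, 16, 8, 8)               # (batch, channels, frames, height, width)

full = conv(video)
chunked = torch.cat([conv(video[:, :, i : i + 4]) for i in range(0, 16, 4)], dim=2)

print(torch.allclose(full, chunked))  # False: outputs differ at the chunk boundaries
print((full - chunked).abs().max())   # nonzero maximum boundary error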

@@ -1031,17 +1031,12 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
             row = []
             for j in range(0, width, self.tile_sample_stride_width):
                 if self.use_framewise_encoding:
-                    time = []
-                    for k in range(0, num_frames, self.num_sample_frames_batch_size):
-                        tile = x[
-                            :,
-                            :,
-                            k : k + self.num_sample_frames_batch_size,
-                            i : i + self.tile_sample_min_height,
-                            j : j + self.tile_sample_min_width,
-                        ]
-                        tile = self.encoder(tile)
-                        time.append(tile)
+                    # TODO(aryan): requires investigation
+                    raise NotImplementedError(
+                        "Frame-wise encoding has not been implemented for AutoencoderKLLTX, at the moment, due to "
+                        "quality issues caused by splitting inference across frame dimension. If you believe this "
+                        "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
+                    )
                 else:
                     time = self.encoder(
                         x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
@@ -1100,18 +1095,12 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
             row = []
             for j in range(0, width, tile_latent_stride_width):
                 if self.use_framewise_decoding:
-                    time = []
-                    for k in range(0, num_frames, self.num_latent_frames_batch_size):
-                        tile = z[
-                            :,
-                            :,
-                            k : k + self.num_latent_frames_batch_size,
-                            i : i + tile_latent_min_height,
-                            j : j + tile_latent_min_width,
-                        ]
-                        tile = self.decoder(tile)
-                        time.append(tile)
-                    time = torch.cat(time, dim=2)
+                    # TODO(aryan): requires investigation
+                    raise NotImplementedError(
+                        "Frame-wise decoding has not been implemented for AutoencoderKLLTX, at the moment, due to "
+                        "quality issues caused by splitting inference across frame dimension. If you believe this "
+                        "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
+                    )
                 else:
                     time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width])
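After this commit, enabling either frame-wise flag raises immediately instead of producing degraded output. A sketch of how a caller would hit the new error; the module path and flag name come from the diff above, while constructing the model from config defaults and the encode() entry point are assumptions:

import torch
from diffusers.models.autoencoders.autoencoder_kl_ltx import AutoencoderKLLTX  # path from the diff

vae = AutoencoderKLLTX()            # assumes the default config builds a usable model
vae.use_framewise_encoding = True   # flag checked in _encode above
try:
    vae.encode(torch.randn(1, 3, 9, 32, 32))  # (batch, channels, frames, height, width)
except NotImplementedError as err:
    print(err)  # directs contributors to https://github.com/huggingface/diffusers/pulls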

src/diffusers/models/normalization.py

Lines changed: 0 additions & 3 deletions
@@ -543,9 +543,6 @@ def forward(self, hidden_states):

         return hidden_states

-    def extra_repr(self) -> str:
-        return f"features={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}"
-

 class GlobalResponseNorm(nn.Module):
     # Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
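The deleted extra_repr is the standard torch.nn.Module hook whose return value is spliced into a module's printed representation, so removing it only changes how the layer renders under print(model), not its computation. A minimal illustration of the hook (hypothetical module, not the diffusers class):

import torch.nn as nn

class MyNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps

    def extra_repr(self) -> str:
        # This string is inlined between the parentheses of repr(module).
        return f"features={self.dim}, eps={self.eps}"

print(MyNorm(128))  # -> MyNorm(features=128, eps=1e-06)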

src/diffusers/pipelines/ltx/pipeline_ltx.py

Lines changed: 0 additions & 2 deletions
@@ -632,8 +632,6 @@ def __call__(
         )
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
-        print(self.scheduler.sigmas)
-        print(len(self.scheduler.sigmas))

         # 6. Prepare micro-conditions
         rope_interpolation_scale = (
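The two deleted lines were leftover debug prints of the scheduler's sigma schedule. If that diagnostic is still wanted during development, routing it through the library logger is the idiomatic alternative; a sketch using diffusers' standard logging helper (the debug call itself is not part of this commit):

from diffusers.utils import logging

logger = logging.get_logger(__name__)  # standard module-level logger in diffusers

# Inside __call__, in place of the removed prints:
# logger.debug("scheduler sigmas (%d): %s", len(self.scheduler.sigmas), self.scheduler.sigmas)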
