Commit 58edc80

refactor tiling; remove einops dependency
1 parent 96ccfb5 commit 58edc80

File tree

5 files changed: +214 -143 lines changed


src/diffusers/models/autoencoders/autoencoder_kl_magvit.py

Lines changed: 109 additions & 53 deletions
@@ -568,11 +568,7 @@ def __init__(
         super().__init__()

         # 1. Input convolution
-        self.conv_in = EasyAnimateCausalConv3d(
-            in_channels,
-            block_out_channels[-1],
-            kernel_size=3,
-        )
+        self.conv_in = EasyAnimateCausalConv3d(in_channels, block_out_channels[-1], kernel_size=3)

         # 2. Middle block
         self.mid_block = EasyAnimateMidBlock3d(
@@ -734,21 +730,36 @@ def __init__(
         self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
         self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)

-        # Assign mini-batch sizes for encoder and decoder
-        self.mini_batch_encoder = 4
-        self.mini_batch_decoder = 1
+        self.spatial_compression_ratio = 2 ** (len(block_out_channels) - 1)
+        self.temporal_compression_ratio = 2 ** (len(block_out_channels) - 2)

-        # Initialize tiling and slicing flags
+        # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
+        # to perform decoding of a single video latent at a time.
         self.use_slicing = False
+
+        # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
+        # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
+        # intermediate tiles together, the memory requirement can be lowered.
         self.use_tiling = False

-        # Set parameters for tiling if used
-        tile_overlap_factor = 0.25
-        self.tile_sample_min_size = 384
-        self.tile_overlap_factor = tile_overlap_factor
-        self.tile_latent_min_size = int(self.tile_sample_min_size / (2 ** (len(block_out_channels) - 1)))
-        # Assign the scaling factor for latent space
-        self.scaling_factor = scaling_factor
+        # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames
+        # at a fixed frame batch size (based on `self.num_latent_frames_batch_size`), the memory requirement can be lowered.
+        self.use_framewise_encoding = False
+        self.use_framewise_decoding = False
+
+        # Assign mini-batch sizes for encoder and decoder
+        self.num_sample_frames_batch_size = 4
+        self.num_latent_frames_batch_size = 1
+
+        # The minimal tile height and width for spatial tiling to be used
+        self.tile_sample_min_height = 512
+        self.tile_sample_min_width = 512
+        self.tile_sample_min_num_frames = 4
+
+        # The minimal distance between two spatial tiles
+        self.tile_sample_stride_height = 448
+        self.tile_sample_stride_width = 448
+        self.tile_sample_stride_num_frames = 8

     def _clear_conv_cache(self):
         # Clear cache for convolutional layers if needed
@@ -760,13 +771,39 @@ def _clear_conv_cache(self):

     def enable_tiling(
         self,
+        tile_sample_min_height: Optional[int] = None,
+        tile_sample_min_width: Optional[int] = None,
+        tile_sample_min_num_frames: Optional[int] = None,
+        tile_sample_stride_height: Optional[float] = None,
+        tile_sample_stride_width: Optional[float] = None,
+        tile_sample_stride_num_frames: Optional[float] = None,
     ) -> None:
         r"""
         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
         compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
         processing larger images.
+
+        Args:
+            tile_sample_min_height (`int`, *optional*):
+                The minimum height required for a sample to be separated into tiles across the height dimension.
+            tile_sample_min_width (`int`, *optional*):
+                The minimum width required for a sample to be separated into tiles across the width dimension.
+            tile_sample_stride_height (`int`, *optional*):
+                The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
+                no tiling artifacts produced across the height dimension.
+            tile_sample_stride_width (`int`, *optional*):
+                The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
+                artifacts produced across the width dimension.
         """
         self.use_tiling = True
+        self.use_framewise_decoding = True
+        self.use_framewise_encoding = True
+        self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+        self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+        self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames
+        self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+        self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+        self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames

     def disable_tiling(self) -> None:
         r"""
@@ -805,14 +842,13 @@ def _encode(
                 The latent representations of the encoded images. If `return_dict` is True, a
                 [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
         """
-        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
-            x = self.tiled_encode(x, return_dict=return_dict)
-            return x
+        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_height or x.shape[-2] > self.tile_sample_min_width):
+            return self.tiled_encode(x, return_dict=return_dict)

-        first_frames = self.encoder(x[:, :, 0:1, :, :])
+        first_frames = self.encoder(x[:, :, :1, :, :])
         h = [first_frames]
-        for i in range(1, x.shape[2], self.mini_batch_encoder):
-            next_frames = self.encoder(x[:, :, i : i + self.mini_batch_encoder, :, :])
+        for i in range(1, x.shape[2], self.num_sample_frames_batch_size):
+            next_frames = self.encoder(x[:, :, i : i + self.num_sample_frames_batch_size, :, :])
             h.append(next_frames)
         h = torch.cat(h, dim=2)
         moments = self.quant_conv(h)
@@ -849,18 +885,22 @@ def encode(
         return AutoencoderKLOutput(latent_dist=posterior)

     def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
-        if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
+        batch_size, num_channels, num_frames, height, width = z.shape
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        if self.use_tiling and (z.shape[-1] > tile_latent_min_height or z.shape[-2] > tile_latent_min_width):
             return self.tiled_decode(z, return_dict=return_dict)

         z = self.post_quant_conv(z)

         # Process the first frame and save the result
-        first_frames = self.decoder(z[:, :, 0:1, :, :])
+        first_frames = self.decoder(z[:, :, :1, :, :])
         # Initialize the list to store the processed frames, starting with the first frame
         dec = [first_frames]
         # Process the remaining frames, with the number of frames processed at a time determined by mini_batch_decoder
-        for i in range(1, z.shape[2], self.mini_batch_decoder):
-            next_frames = self.decoder(z[:, :, i : i + self.mini_batch_decoder, :, :])
+        for i in range(1, z.shape[2], self.num_latent_frames_batch_size):
+            next_frames = self.decoder(z[:, :, i : i + self.num_latent_frames_batch_size, :, :])
             dec.append(next_frames)
         # Concatenate all processed frames along the channel dimension
         dec = torch.cat(dec, dim=2)
@@ -913,27 +953,35 @@ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
         return b

     def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
-        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
-        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
-        row_limit = self.tile_latent_min_size - blend_extent
+        batch_size, num_channels, num_frames, height, width = x.shape
+        latent_height = height // self.spatial_compression_ratio
+        latent_width = width // self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = tile_latent_min_height - tile_latent_stride_height
+        blend_width = tile_latent_min_width - tile_latent_stride_width

         # Split the image into 512x512 tiles and encode them separately.
         rows = []
-        for i in range(0, x.shape[3], overlap_size):
+        for i in range(0, height, self.tile_sample_stride_height):
             row = []
-            for j in range(0, x.shape[4], overlap_size):
+            for j in range(0, width, self.tile_sample_stride_width):
                 tile = x[
                     :,
                     :,
                     :,
-                    i : i + self.tile_sample_min_size,
-                    j : j + self.tile_sample_min_size,
+                    i : i + self.tile_sample_min_height,
+                    j : j + self.tile_sample_min_width,
                 ]

                 first_frames = self.encoder(tile[:, :, 0:1, :, :])
                 tile_h = [first_frames]
-                for frame_index in range(1, tile.shape[2], self.mini_batch_encoder):
-                    next_frames = self.encoder(tile[:, :, frame_index : frame_index + self.mini_batch_encoder, :, :])
+                for k in range(1, num_frames, self.num_sample_frames_batch_size):
+                    next_frames = self.encoder(tile[:, :, k : k + self.num_sample_frames_batch_size, :, :])
                     tile_h.append(next_frames)
                 tile = torch.cat(tile_h, dim=2)
                 tile = self.quant_conv(tile)
@@ -947,42 +995,50 @@ def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
                 # blend the above tile and the left tile
                 # to the current tile and add the current tile to the result row
                 if i > 0:
-                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
                 if j > 0:
-                    tile = self.blend_h(row[j - 1], tile, blend_extent)
-                result_row.append(tile[:, :, :, :row_limit, :row_limit])
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, :latent_height, :latent_width])
             result_rows.append(torch.cat(result_row, dim=4))

-        moments = torch.cat(result_rows, dim=3)
+        moments = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
         return moments

     def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
-        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
-        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
-        row_limit = self.tile_sample_min_size - blend_extent
+        batch_size, num_channels, num_frames, height, width = z.shape
+        sample_height = height * self.spatial_compression_ratio
+        sample_width = width * self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width

         # Split z into overlapping 64x64 tiles and decode them separately.
         # The tiles have an overlap to avoid seams between tiles.
         rows = []
-        for i in range(0, z.shape[3], overlap_size):
+        for i in range(0, height, tile_latent_stride_height):
             row = []
-            for j in range(0, z.shape[4], overlap_size):
+            for j in range(0, width, tile_latent_stride_width):
                 tile = z[
                     :,
                     :,
                     :,
-                    i : i + self.tile_latent_min_size,
-                    j : j + self.tile_latent_min_size,
+                    i : i + tile_latent_min_height,
+                    j : j + tile_latent_min_width,
                 ]
                 tile = self.post_quant_conv(tile)

                 # Process the first frame and save the result
-                first_frames = self.decoder(tile[:, :, 0:1, :, :])
+                first_frames = self.decoder(tile[:, :, :1, :, :])
                 # Initialize the list to store the processed frames, starting with the first frame
                 tile_dec = [first_frames]
                 # Process the remaining frames, with the number of frames processed at a time determined by mini_batch_decoder
-                for frame_index in range(1, tile.shape[2], self.mini_batch_decoder):
-                    next_frames = self.decoder(tile[:, :, frame_index : frame_index + self.mini_batch_decoder, :, :])
+                for k in range(1, num_frames, self.num_latent_frames_batch_size):
+                    next_frames = self.decoder(tile[:, :, k : k + self.num_latent_frames_batch_size, :, :])
                     tile_dec.append(next_frames)
                 # Concatenate all processed frames along the channel dimension
                 decoded = torch.cat(tile_dec, dim=2)
@@ -996,13 +1052,13 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
                 # blend the above tile and the left tile
                 # to the current tile and add the current tile to the result row
                 if i > 0:
-                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
                 if j > 0:
-                    tile = self.blend_h(row[j - 1], tile, blend_extent)
-                result_row.append(tile[:, :, :, :row_limit, :row_limit])
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
             result_rows.append(torch.cat(result_row, dim=4))

-        dec = torch.cat(result_rows, dim=3)
+        dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]

         if not return_dict:
             return (dec,)
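
With this refactor, tiling is parameterized by explicit tile sizes and strides instead of a single `tile_overlap_factor`: tiles of `tile_sample_min_height x tile_sample_min_width` pixels are taken at a stride of `tile_sample_stride_height x tile_sample_stride_width`, so neighbouring tiles share `min - stride` pixels (512 - 448 = 64 with the new defaults) that `blend_v`/`blend_h` feather together. A minimal usage sketch of the new API follows; the checkpoint path and tensor sizes are hypothetical and only illustrate the `enable_tiling` signature introduced above.

import torch
from diffusers.models import AutoencoderKLMagvit

# Hypothetical checkpoint path; substitute a real EasyAnimate VAE checkpoint.
vae = AutoencoderKLMagvit.from_pretrained("path/to/easyanimate-vae", torch_dtype=torch.float16).to("cuda")

# Enable spatial tiling with the defaults introduced in this commit:
# 512 px tiles taken at a 448 px stride, i.e. 64 px of blended overlap.
vae.enable_tiling(
    tile_sample_min_height=512,
    tile_sample_min_width=512,
    tile_sample_stride_height=448,
    tile_sample_stride_width=448,
)

# Encode an RGB video of shape [batch, channels, frames, height, width]; after the
# first frame, frames are encoded in chunks of `num_sample_frames_batch_size`.
video = torch.randn(1, 3, 9, 1024, 1024, dtype=torch.float16, device="cuda")
with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()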

src/diffusers/models/transformers/transformer_easyanimate.py

Lines changed: 2 additions & 3 deletions
@@ -13,12 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
 from torch import nn
-from typing import Any, Dict, List, Optional, Tuple, Union

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import logging
@@ -74,7 +73,7 @@ def __init__(self, patch_size: int, rope_dim: List[int]) -> None:

         self.patch_size = patch_size
         self.rope_dim = rope_dim
-        
+
     def get_resize_crop_region_for_grid(self, src, tgt_width, tgt_height):
         tw = tgt_width
         th = tgt_height

src/diffusers/pipelines/easyanimate/pipeline_easyanimate.py

Lines changed: 16 additions & 21 deletions
@@ -28,7 +28,6 @@

 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel
-from ...models.embeddings import get_2d_rotary_pos_embed, get_3d_rotary_pos_embed
 from ...pipelines.pipeline_utils import DiffusionPipeline
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import is_torch_xla_available, logging, replace_example_docstring
@@ -236,8 +235,13 @@ def __init__(
             transformer=transformer,
             scheduler=scheduler,
         )
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
-        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.vae_spatial_compression_ratio = (
+            self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 8
+        )
+        self.vae_temporal_compression_ratio = (
+            self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 4
+        )
+        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)

     def encode_prompt(
         self,
@@ -607,18 +611,18 @@ def check_inputs(
                     f" {negative_prompt_embeds_2.shape}."
                 )

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
     def prepare_latents(
         self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
     ):
-        mini_batch_encoder = self.vae.mini_batch_encoder
-        mini_batch_decoder = self.vae.mini_batch_decoder
+        if latents is not None:
+            return latents.to(device=device, dtype=dtype)
+
         shape = (
             batch_size,
             num_channels_latents,
-            int((num_frames - 1) // mini_batch_encoder * mini_batch_decoder + 1) if num_frames != 1 else 1,
-            height // self.vae_scale_factor,
-            width // self.vae_scale_factor,
+            (num_frames - 1) // self.vae_temporal_compression_ratio + 1,
+            height // self.vae_spatial_compression_ratio,
+            width // self.vae_spatial_compression_ratio,
         )

         if isinstance(generator, list) and len(generator) != batch_size:
@@ -627,21 +631,12 @@ def prepare_latents(
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )

-        if latents is None:
-            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-        else:
-            latents = latents.to(device)
-
+        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         # scale the initial noise by the standard deviation required by the scheduler
         if hasattr(self.scheduler, "init_noise_sigma"):
             latents = latents * self.scheduler.init_noise_sigma
         return latents

-    def decode_latents(self, latents):
-        latents = 1 / self.vae.config.scaling_factor * latents
-        video = self.vae.decode(latents).sample
-        return video
-
     @property
     def guidance_scale(self):
         return self._guidance_scale
@@ -953,9 +948,9 @@ def __call__(
                 if XLA_AVAILABLE:
                     xm.mark_step()

-        # Convert to tensor
         if not output_type == "latent":
-            video = self.decode_latents(latents)
+            latents = 1 / self.vae.config.scaling_factor * latents
+            video = self.vae.decode(latents, return_dict=False)[0]
             video = self.video_processor.postprocess_video(video=video, output_type=output_type)
         else:
             video = latents
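
The latent shape in `prepare_latents` is now derived directly from the VAE compression ratios rather than from the encoder/decoder mini-batch sizes. A small sketch of that arithmetic, assuming the fallback ratios from the `__init__` hunk above (spatial 8, temporal 4) and an illustrative latent channel count:

# Sketch of the latent-shape arithmetic used by the refactored prepare_latents,
# assuming vae_spatial_compression_ratio=8 and vae_temporal_compression_ratio=4
# (the fallbacks in the pipeline __init__ above). The channel count is illustrative.
def latent_shape(batch_size, num_channels_latents, num_frames, height, width,
                 spatial_ratio=8, temporal_ratio=4):
    return (
        batch_size,
        num_channels_latents,
        (num_frames - 1) // temporal_ratio + 1,  # 49 sample frames -> 13 latent frames
        height // spatial_ratio,                 # 512 px -> 64 latent rows
        width // spatial_ratio,                  # 512 px -> 64 latent columns
    )

print(latent_shape(batch_size=1, num_channels_latents=16, num_frames=49, height=512, width=512))
# (1, 16, 13, 64, 64)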
