 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders.single_file_model import FromOriginalModelMixin
+from ...utils import deprecate
 from ...utils.accelerate_utils import apply_forward_hook
 from ..attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
@@ -245,6 +246,18 @@ def set_default_attn_processor(self):
 
         self.set_attn_processor(processor)
 
+    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+        batch_size, num_channels, height, width = x.shape
+
+        if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
+            return self._tiled_encode(x)
+
+        enc = self.encoder(x)
+        if self.quant_conv is not None:
+            enc = self.quant_conv(enc)
+
+        return enc
+
     @apply_forward_hook
     def encode(
         self, x: torch.Tensor, return_dict: bool = True
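
Note the behavioral consequence of this helper for the `encode` changes in the next hunk: with slicing enabled, each single-sample slice is now routed through `_encode`, so the tiling threshold and `quant_conv` apply per slice rather than being bypassed on the slicing path. A minimal usage sketch (the checkpoint name and the 512 default tile size are illustrative assumptions, not part of this diff):

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")  # illustrative checkpoint
vae.enable_tiling()   # sets use_tiling = True
vae.enable_slicing()  # sets use_slicing = True

# Each 1024x1024 slice exceeds the (assumed) 512 tile_sample_min_size,
# so every slice is encoded tile by tile inside `_encode`.
x = torch.randn(2, 3, 1024, 1024)
posterior = vae.encode(x).latent_dist
```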
@@ -261,21 +274,13 @@ def encode(
             The latent representations of the encoded images. If `return_dict` is True, a
             [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
         """
-        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
-            return self.tiled_encode(x, return_dict=return_dict)
-
         if self.use_slicing and x.shape[0] > 1:
-            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
             h = torch.cat(encoded_slices)
         else:
-            h = self.encoder(x)
-
-        if self.quant_conv is not None:
-            moments = self.quant_conv(h)
-        else:
-            moments = h
+            h = self._encode(x)
 
-        posterior = DiagonalGaussianDistribution(moments)
+        posterior = DiagonalGaussianDistribution(h)
 
         if not return_dict:
             return (posterior,)
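
Since `_encode` already applies `quant_conv`, `h` here holds the raw moments (mean and log-variance stacked along the channel dimension), which is exactly what `DiagonalGaussianDistribution` expects. A sketch of what that wrapper does internally (equivalent up to the log-variance clamping it also performs):

```python
import torch

h = torch.randn(1, 8, 64, 64)            # e.g. 4 mean + 4 logvar channels
mean, logvar = torch.chunk(h, 2, dim=1)  # the distribution splits the moments like this
std = torch.exp(0.5 * logvar)
z = mean + std * torch.randn_like(mean)  # posterior.sample() draws a latent this way
```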
@@ -337,6 +342,54 @@ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
             b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
         return b
 
+    def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+        r"""Encode a batch of images using a tiled encoder.
+
+        When tiling is enabled, the VAE splits the input tensor into tiles and encodes them in several steps. This
+        keeps memory use constant regardless of image size. The result of tiled encoding differs from non-tiled
+        encoding because each tile is encoded separately, so the encoder only sees a limited view of the image. To
+        avoid tiling artifacts, the tiles overlap and are blended together to form a smooth output. You may still
+        see tile-sized changes in the output, but they should be much less noticeable.
+
+        Args:
+            x (`torch.Tensor`): Input batch of images.
+
+        Returns:
+            `torch.Tensor`:
+                The latent representation of the encoded images.
+        """
+
+        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
+        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
+        row_limit = self.tile_latent_min_size - blend_extent
+
+        # Split the image into overlapping tiles of size `tile_sample_min_size` and encode them separately.
+        rows = []
+        for i in range(0, x.shape[2], overlap_size):
+            row = []
+            for j in range(0, x.shape[3], overlap_size):
+                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
+                tile = self.encoder(tile)
+                if self.config.use_quant_conv:
+                    tile = self.quant_conv(tile)
+                row.append(tile)
+            rows.append(row)
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # Blend the tile above and the tile to the left into the current tile,
+                # then crop the blended tile and add it to the result row.
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_extent)
+                result_row.append(tile[:, :, :row_limit, :row_limit])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        enc = torch.cat(result_rows, dim=2)
+        return enc
+
     def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
         r"""Encode a batch of images using a tiled encoder.
 
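
For intuition about the tile geometry used by `_tiled_encode` above, here is the arithmetic with commonly used values (tile_sample_min_size = 512, tile_latent_min_size = 64, tile_overlap_factor = 0.25; these defaults are an assumption, not read from this diff):

```python
tile_sample_min_size = 512
tile_latent_min_size = 64
tile_overlap_factor = 0.25

overlap_size = int(tile_sample_min_size * (1 - tile_overlap_factor))  # 384: stride between tile origins in pixels
blend_extent = int(tile_latent_min_size * tile_overlap_factor)        # 16: latent rows/cols blended at each seam
row_limit = tile_latent_min_size - blend_extent                       # 48: latent rows/cols kept from each tile

assert (overlap_size, blend_extent, row_limit) == (384, 16, 48)
```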
@@ -356,6 +409,13 @@ def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
             If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
             `tuple` is returned.
         """
+        deprecation_message = (
+            "The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, the "
+            "implementation of this method will be replaced with that of `_tiled_encode` and you will no longer be able "
+            "to pass `return_dict`. You will also have to create a `DiagonalGaussianDistribution()` from the returned value."
+        )
+        deprecate("tiled_encode", "1.0.0", deprecation_message, standard_warn=False)
+
         overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
         blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
         row_limit = self.tile_latent_min_size - blend_extent
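
Per the deprecation message, downstream code that calls `tiled_encode` directly would migrate roughly as follows (a sketch; the import path for `DiagonalGaussianDistribution` varies across diffusers versions):

```python
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution

# Before (deprecated): the method wraps the posterior for you.
posterior = vae.tiled_encode(x, return_dict=True).latent_dist

# After: take the raw moments from `_tiled_encode` and build the
# distribution yourself, as the deprecation message suggests.
moments = vae._tiled_encode(x)
posterior = DiagonalGaussianDistribution(moments)
```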