
Commit e3568d1

FreeNoise: change vae_batch_size to decode_chunk_size (#9110)
* update
* update

1 parent: f6df224, commit: e3568d1

File tree: 4 files changed, +30 -29 lines

src/diffusers/pipelines/animatediff/pipeline_animatediff.py
src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py
src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py
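For downstream users this is a drop-in keyword rename at the call site. A minimal before/after sketch (the model checkpoints and prompt are illustrative assumptions, not part of this commit):

import torch
from diffusers import AnimateDiffPipeline, MotionAdapter

# Illustrative checkpoints; any AnimateDiff-compatible base model works.
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
pipe = AnimateDiffPipeline.from_pretrained(
    "emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")

# Before this commit:
#   output = pipe(prompt="a rocket launching", num_frames=16, vae_batch_size=8)
# After this commit:
output = pipe(prompt="a rocket launching", num_frames=16, decode_chunk_size=8)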

src/diffusers/pipelines/animatediff/pipeline_animatediff.py

Lines changed: 6 additions & 6 deletions
@@ -396,15 +396,15 @@ def prepare_ip_adapter_image_embeds(
 
         return ip_adapter_image_embeds
 
-    def decode_latents(self, latents, vae_batch_size: int = 16):
+    def decode_latents(self, latents, decode_chunk_size: int = 16):
         latents = 1 / self.vae.config.scaling_factor * latents
 
         batch_size, channels, num_frames, height, width = latents.shape
         latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
 
         video = []
-        for i in range(0, latents.shape[0], vae_batch_size):
-            batch_latents = latents[i : i + vae_batch_size]
+        for i in range(0, latents.shape[0], decode_chunk_size):
+            batch_latents = latents[i : i + decode_chunk_size]
             batch_latents = self.vae.decode(batch_latents).sample
             video.append(batch_latents)
 
@@ -582,7 +582,7 @@ def __call__(
         clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
-        vae_batch_size: int = 16,
+        decode_chunk_size: int = 16,
         **kwargs,
     ):
         r"""
@@ -651,7 +651,7 @@ def __call__(
                 The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeline class.
-            vae_batch_size (`int`, defaults to `16`):
+            decode_chunk_size (`int`, defaults to `16`):
                 The number of frames to decode at a time when calling `decode_latents` method.
 
         Examples:
@@ -824,7 +824,7 @@ def __call__(
         if output_type == "latent":
             video = latents
         else:
-            video_tensor = self.decode_latents(latents, vae_batch_size)
+            video_tensor = self.decode_latents(latents, decode_chunk_size)
             video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
 
         # 10. Offload all models
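The renamed loop implements chunked VAE decoding: the flattened (batch * num_frames) latents are decoded at most decode_chunk_size frames at a time to bound peak memory. A standalone sketch of the same pattern, assuming a vae with the .decode(...).sample interface shown above:

import torch

def decode_in_chunks(vae, latents: torch.Tensor, decode_chunk_size: int = 16) -> torch.Tensor:
    # latents: (batch_size * num_frames, channels, height, width), already
    # rescaled by 1 / vae.config.scaling_factor as in decode_latents above
    frames = []
    for i in range(0, latents.shape[0], decode_chunk_size):
        chunk = latents[i : i + decode_chunk_size]
        frames.append(vae.decode(chunk).sample)  # decode a bounded number of frames at once
    return torch.cat(frames)

Smaller chunk sizes lower peak VRAM at the cost of more decode calls; the default of 16 matches the old vae_batch_size default.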

src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py

Lines changed: 5 additions & 5 deletions
@@ -435,15 +435,15 @@ def prepare_ip_adapter_image_embeds(
         return ip_adapter_image_embeds
 
     # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
-    def decode_latents(self, latents, vae_batch_size: int = 16):
+    def decode_latents(self, latents, decode_chunk_size: int = 16):
         latents = 1 / self.vae.config.scaling_factor * latents
 
         batch_size, channels, num_frames, height, width = latents.shape
         latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
 
         video = []
-        for i in range(0, latents.shape[0], vae_batch_size):
-            batch_latents = latents[i : i + vae_batch_size]
+        for i in range(0, latents.shape[0], decode_chunk_size):
+            batch_latents = latents[i : i + decode_chunk_size]
             batch_latents = self.vae.decode(batch_latents).sample
             video.append(batch_latents)
 
@@ -728,7 +728,7 @@ def __call__(
         clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
-        vae_batch_size: int = 16,
+        decode_chunk_size: int = 16,
     ):
         r"""
         The call function to the pipeline for generation.
@@ -1064,7 +1064,7 @@ def __call__(
         if output_type == "latent":
            video = latents
        else:
-            video_tensor = self.decode_latents(latents, vae_batch_size)
+            video_tensor = self.decode_latents(latents, decode_chunk_size)
            video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
 
         # 10. Offload all models

src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py

Lines changed: 14 additions & 13 deletions
@@ -500,24 +500,24 @@ def prepare_ip_adapter_image_embeds(
 
         return ip_adapter_image_embeds
 
-    def encode_video(self, video, generator, vae_batch_size: int = 16) -> torch.Tensor:
+    def encode_video(self, video, generator, decode_chunk_size: int = 16) -> torch.Tensor:
         latents = []
-        for i in range(0, len(video), vae_batch_size):
-            batch_video = video[i : i + vae_batch_size]
+        for i in range(0, len(video), decode_chunk_size):
+            batch_video = video[i : i + decode_chunk_size]
             batch_video = retrieve_latents(self.vae.encode(batch_video), generator=generator)
             latents.append(batch_video)
         return torch.cat(latents)
 
     # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
-    def decode_latents(self, latents, vae_batch_size: int = 16):
+    def decode_latents(self, latents, decode_chunk_size: int = 16):
         latents = 1 / self.vae.config.scaling_factor * latents
 
         batch_size, channels, num_frames, height, width = latents.shape
         latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
 
         video = []
-        for i in range(0, latents.shape[0], vae_batch_size):
-            batch_latents = latents[i : i + vae_batch_size]
+        for i in range(0, latents.shape[0], decode_chunk_size):
+            batch_latents = latents[i : i + decode_chunk_size]
             batch_latents = self.vae.decode(batch_latents).sample
             video.append(batch_latents)
 
@@ -638,7 +638,7 @@ def prepare_latents(
         device,
         generator,
         latents=None,
-        vae_batch_size: int = 16,
+        decode_chunk_size: int = 16,
     ):
         if latents is None:
             num_frames = video.shape[1]
@@ -673,10 +673,11 @@ def prepare_latents(
             )
 
             init_latents = [
-                self.encode_video(video[i], generator[i], vae_batch_size).unsqueeze(0) for i in range(batch_size)
+                self.encode_video(video[i], generator[i], decode_chunk_size).unsqueeze(0)
+                for i in range(batch_size)
             ]
         else:
-            init_latents = [self.encode_video(vid, generator, vae_batch_size).unsqueeze(0) for vid in video]
+            init_latents = [self.encode_video(vid, generator, decode_chunk_size).unsqueeze(0) for vid in video]
 
         init_latents = torch.cat(init_latents, dim=0)
 
@@ -761,7 +762,7 @@ def __call__(
         clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
-        vae_batch_size: int = 16,
+        decode_chunk_size: int = 16,
     ):
         r"""
         The call function to the pipeline for generation.
@@ -837,7 +838,7 @@ def __call__(
                 The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeline class.
-            vae_batch_size (`int`, defaults to `16`):
+            decode_chunk_size (`int`, defaults to `16`):
                 The number of frames to decode at a time when calling `decode_latents` method.
 
         Examples:
@@ -940,7 +941,7 @@ def __call__(
             device=device,
             generator=generator,
             latents=latents,
-            vae_batch_size=vae_batch_size,
+            decode_chunk_size=decode_chunk_size,
         )
 
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
@@ -1008,7 +1009,7 @@ def __call__(
         if output_type == "latent":
             video = latents
         else:
-            video_tensor = self.decode_latents(latents, vae_batch_size)
+            video_tensor = self.decode_latents(latents, decode_chunk_size)
             video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
 
         # 10. Offload all models
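In the video2video pipeline the same parameter also chunks VAE encoding of the input video in encode_video, despite the "decode" in its new name. A sketch of that encode path; the latent_dist.sample() call stands in for the pipeline's retrieve_latents helper and is an assumption about the VAE output type:

import torch

def encode_in_chunks(vae, video: torch.Tensor, decode_chunk_size: int = 16) -> torch.Tensor:
    # video: (num_frames, channels, height, width) pixel-space frames
    latents = []
    for i in range(0, video.shape[0], decode_chunk_size):
        chunk = video[i : i + decode_chunk_size]
        # assumed equivalent of retrieve_latents(vae.encode(chunk)):
        # sample latents from the encoder's posterior distribution
        latents.append(vae.encode(chunk).latent_dist.sample())
    return torch.cat(latents)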

src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py

Lines changed: 5 additions & 5 deletions
@@ -407,15 +407,15 @@ def prepare_ip_adapter_image_embeds(
         return ip_adapter_image_embeds
 
     # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
-    def decode_latents(self, latents, vae_batch_size: int = 16):
+    def decode_latents(self, latents, decode_chunk_size: int = 16):
         latents = 1 / self.vae.config.scaling_factor * latents
 
         batch_size, channels, num_frames, height, width = latents.shape
         latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
 
         video = []
-        for i in range(0, latents.shape[0], vae_batch_size):
-            batch_latents = latents[i : i + vae_batch_size]
+        for i in range(0, latents.shape[0], decode_chunk_size):
+            batch_latents = latents[i : i + decode_chunk_size]
             batch_latents = self.vae.decode(batch_latents).sample
             video.append(batch_latents)
 
@@ -588,7 +588,7 @@ def __call__(
         clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
-        vae_batch_size: int = 16,
+        decode_chunk_size: int = 16,
         pag_scale: float = 3.0,
         pag_adaptive_scale: float = 0.0,
     ):
@@ -847,7 +847,7 @@ def __call__(
         if output_type == "latent":
             video = latents
         else:
-            video_tensor = self.decode_latents(latents, vae_batch_size)
+            video_tensor = self.decode_latents(latents, decode_chunk_size)
             video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
 
         # 10. Offload all models
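After upgrading, code that still passes vae_batch_size no longer matches a named parameter (in AnimateDiffPipeline it would be absorbed by **kwargs instead of taking effect). A quick sanity check, as a sketch, that the installed diffusers exposes the renamed keyword:

import inspect
from diffusers import AnimateDiffPipeline

# The renamed parameter should appear in the __call__ signature,
# and the old one should be gone.
params = inspect.signature(AnimateDiffPipeline.__call__).parameters
assert "decode_chunk_size" in params
assert "vae_batch_size" not in params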
