Commit d0c3aae

update + cleanup 🧹
1 parent 3f7aa53 commit d0c3aae

48 files changed (+246 / -1657 lines)

src/diffusers/models/autoencoders/autoencoder_kl.py

Lines changed: 0 additions & 4 deletions

@@ -138,10 +138,6 @@ def __init__(
         self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
         self.tile_overlap_factor = 0.25

-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (Encoder, Decoder)):
-            module.gradient_checkpointing = value
-
     def enable_tiling(self, use_tiling: bool = True):
         r"""
         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
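
Note: the per-model `_set_gradient_checkpointing` hook deleted above was how each autoencoder flipped `gradient_checkpointing` on its `Encoder`/`Decoder`. A minimal sketch of the kind of shared logic that can replace such hooks is below; the function name and wiring are illustrative assumptions, not the base-class code from this commit.

```python
from functools import partial

import torch
import torch.nn as nn
import torch.utils.checkpoint


def enable_gradient_checkpointing_generic(model: nn.Module, checkpoint_fn=None) -> None:
    """Hypothetical helper (not the diffusers implementation): walk all submodules,
    enable checkpointing wherever a block defines a `gradient_checkpointing` flag,
    and bind one shared checkpointing function to each such block."""
    if checkpoint_fn is None:
        # Non-reentrant checkpointing is the usual default on recent PyTorch.
        checkpoint_fn = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
    for module in model.modules():
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = True
            module._gradient_checkpointing_func = checkpoint_fn
```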

src/diffusers/models/autoencoders/autoencoder_kl_allegro.py

Lines changed: 4 additions & 22 deletions

@@ -507,19 +507,12 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor:
         sample = sample + residual

         if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    return module(*inputs)
-
-                return custom_forward
-
             # Down blocks
             for down_block in self.down_blocks:
-                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
+                sample = self._gradient_checkpointing_func(down_block, sample)

             # Mid block
-            sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
+            sample = self._gradient_checkpointing_func(self.mid_block, sample)
         else:
             # Down blocks
             for down_block in self.down_blocks:

@@ -647,19 +640,12 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor:
         upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

         if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    return module(*inputs)
-
-                return custom_forward
-
             # Mid block
-            sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
+            sample = self._gradient_checkpointing_func(self.mid_block, sample)

             # Up blocks
             for up_block in self.up_blocks:
-                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample)
+                sample = self._gradient_checkpointing_func(up_block, sample)

         else:
             # Mid block

@@ -809,10 +795,6 @@ def __init__(
             sample_size - self.tile_overlap_w,
         )

-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (AllegroEncoder3D, AllegroDecoder3D)):
-            module.gradient_checkpointing = value
-
     def enable_tiling(self) -> None:
         r"""
         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
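
The deleted `create_custom_forward` closures only forwarded `*inputs` to the module, and `torch.utils.checkpoint.checkpoint` already accepts any callable, including an `nn.Module`. A shared helper with the checkpoint kwargs fixed once can therefore be called exactly like the new `self._gradient_checkpointing_func(...)` call sites. The helper below is an assumption about what that attribute could resolve to, shown as a runnable sketch:

```python
from functools import partial

import torch
import torch.nn as nn
import torch.utils.checkpoint

# Assumed shape of the shared helper: plain non-reentrant checkpointing with the
# kwargs bound once, so call sites pass the block and its positional args directly.
_gradient_checkpointing_func = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)

# Toy stand-in for a down block with a plain tensor-in/tensor-out forward.
down_block = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.SiLU(), nn.Conv2d(8, 3, 3, padding=1))
sample = torch.randn(1, 3, 16, 16, requires_grad=True)

out = _gradient_checkpointing_func(down_block, sample)  # module passed directly, no wrapper closure
out.sum().backward()                                    # the block is re-run during the backward pass
```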

src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py

Lines changed: 14 additions & 53 deletions

@@ -421,15 +421,8 @@ def forward(
             conv_cache_key = f"resnet_{i}"

             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def create_forward(*inputs):
-                        return module(*inputs)
-
-                    return create_forward
-
-                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
+                hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
+                    resnet,
                     hidden_states,
                     temb,
                     zq,

@@ -523,15 +516,8 @@ def forward(
             conv_cache_key = f"resnet_{i}"

             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def create_forward(*inputs):
-                        return module(*inputs)
-
-                    return create_forward
-
-                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, zq, conv_cache.get(conv_cache_key)
+                hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
+                    resnet, hidden_states, temb, zq, conv_cache.get(conv_cache_key)
                 )
             else:
                 hidden_states, new_conv_cache[conv_cache_key] = resnet(

@@ -637,15 +623,8 @@ def forward(
             conv_cache_key = f"resnet_{i}"

             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def create_forward(*inputs):
-                        return module(*inputs)
-
-                    return create_forward
-
-                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
+                hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
+                    resnet,
                     hidden_states,
                     temb,
                     zq,

@@ -774,27 +753,20 @@ def forward(
         hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))

         if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    return module(*inputs)
-
-                return custom_forward
-
             # 1. Down
             for i, down_block in enumerate(self.down_blocks):
                 conv_cache_key = f"down_block_{i}"
-                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(down_block),
+                hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
+                    down_block,
                     hidden_states,
                     temb,
                     None,
                     conv_cache.get(conv_cache_key),
                 )

             # 2. Mid
-            hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block),
+            hidden_states, new_conv_cache["mid_block"] = self._gradient_checkpointing_func(
+                self.mid_block,
                 hidden_states,
                 temb,
                 None,

@@ -940,16 +912,9 @@ def forward(
         hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))

         if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    return module(*inputs)
-
-                return custom_forward
-
             # 1. Mid
-            hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block),
+            hidden_states, new_conv_cache["mid_block"] = self._gradient_checkpointing_func(
+                self.mid_block,
                 hidden_states,
                 temb,
                 sample,

@@ -959,8 +924,8 @@ def custom_forward(*inputs):
             # 2. Up
             for i, up_block in enumerate(self.up_blocks):
                 conv_cache_key = f"up_block_{i}"
-                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(up_block),
+                hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
+                    up_block,
                     hidden_states,
                     temb,
                     sample,

@@ -1122,10 +1087,6 @@ def __init__(
         self.tile_overlap_factor_height = 1 / 6
         self.tile_overlap_factor_width = 1 / 5

-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
-            module.gradient_checkpointing = value
-
     def enable_tiling(
         self,
         tile_sample_min_height: Optional[int] = None,
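
From the training side, the public switch is unchanged: `enable_gradient_checkpointing()` from `ModelMixin` turns the flag on, and the checkpointed branch only runs while gradients are enabled. A hedged usage sketch, using the small default `AutoencoderKL` config and toy shapes to keep it light:

```python
import torch
from diffusers import AutoencoderKL

# Usage sketch: the refactor changes internal call sites, not the public API.
vae = AutoencoderKL()                 # small default config, for illustration only
vae.enable_gradient_checkpointing()   # sets gradient_checkpointing = True on Encoder/Decoder
vae.train()

x = torch.randn(1, 3, 64, 64)
posterior = vae.encode(x).latent_dist
recon = vae.decode(posterior.sample()).sample
loss = torch.nn.functional.mse_loss(recon, x)
loss.backward()  # checkpointed blocks recompute activations here, trading compute for memory
```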

src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py

Lines changed: 10 additions & 90 deletions

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Optional, Tuple, Union

 import numpy as np
 import torch

@@ -21,7 +21,7 @@
 import torch.utils.checkpoint

 from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import is_torch_version, logging
+from ...utils import logging
 from ...utils.accelerate_utils import apply_forward_hook
 from ..activations import get_activation
 from ..attention_processor import Attention

@@ -252,21 +252,7 @@ def __init__(

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-            def create_custom_forward(module, return_dict=None):
-                def custom_forward(*inputs):
-                    if return_dict is not None:
-                        return module(*inputs, return_dict=return_dict)
-                    else:
-                        return module(*inputs)
-
-                return custom_forward
-
-            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-
-            hidden_states = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.resnets[0]), hidden_states, **ckpt_kwargs
-            )
+            hidden_states = self._gradient_checkpointing_func(self.resnets[0], hidden_states)

             for attn, resnet in zip(self.attentions, self.resnets[1:]):
                 if attn is not None:

@@ -278,9 +264,7 @@ def custom_forward(*inputs):
                     hidden_states = attn(hidden_states, attention_mask=attention_mask)
                     hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3)

-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, **ckpt_kwargs
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)

         else:
             hidden_states = self.resnets[0](hidden_states)

@@ -350,22 +334,8 @@ def __init__(

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-            def create_custom_forward(module, return_dict=None):
-                def custom_forward(*inputs):
-                    if return_dict is not None:
-                        return module(*inputs, return_dict=return_dict)
-                    else:
-                        return module(*inputs)
-
-                return custom_forward
-
-            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-
             for resnet in self.resnets:
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, **ckpt_kwargs
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
         else:
             for resnet in self.resnets:
                 hidden_states = resnet(hidden_states)

@@ -426,22 +396,8 @@ def __init__(

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-            def create_custom_forward(module, return_dict=None):
-                def custom_forward(*inputs):
-                    if return_dict is not None:
-                        return module(*inputs, return_dict=return_dict)
-                    else:
-                        return module(*inputs)
-
-                return custom_forward
-
-            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-
             for resnet in self.resnets:
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, **ckpt_kwargs
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)

         else:
             for resnet in self.resnets:

@@ -545,26 +501,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.conv_in(hidden_states)

         if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-            def create_custom_forward(module, return_dict=None):
-                def custom_forward(*inputs):
-                    if return_dict is not None:
-                        return module(*inputs, return_dict=return_dict)
-                    else:
-                        return module(*inputs)
-
-                return custom_forward
-
-            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-
             for down_block in self.down_blocks:
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(down_block), hidden_states, **ckpt_kwargs
-                )
+                hidden_states = self._gradient_checkpointing_func(down_block, hidden_states)

-            hidden_states = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block), hidden_states, **ckpt_kwargs
-            )
+            hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
         else:
             for down_block in self.down_blocks:
                 hidden_states = down_block(hidden_states)

@@ -667,26 +607,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.conv_in(hidden_states)

         if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-            def create_custom_forward(module, return_dict=None):
-                def custom_forward(*inputs):
-                    if return_dict is not None:
-                        return module(*inputs, return_dict=return_dict)
-                    else:
-                        return module(*inputs)
-
-                return custom_forward
-
-            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-
-            hidden_states = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block), hidden_states, **ckpt_kwargs
-            )
+            hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)

             for up_block in self.up_blocks:
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(up_block), hidden_states, **ckpt_kwargs
-                )
+                hidden_states = self._gradient_checkpointing_func(up_block, hidden_states)
         else:
             hidden_states = self.mid_block(hidden_states)


@@ -800,10 +724,6 @@ def __init__(
         self.tile_sample_stride_width = 192
         self.tile_sample_stride_num_frames = 12

-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (HunyuanVideoEncoder3D, HunyuanVideoDecoder3D)):
-            module.gradient_checkpointing = value
-
     def enable_tiling(
         self,
         tile_sample_min_height: Optional[int] = None,
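
The removed `ckpt_kwargs` gate selected `use_reentrant=False` on torch >= 1.11, and the simplified call sites drop that per-call plumbing. A quick toy check (not from the repo's test suite) that non-reentrant checkpointing is behavior-preserving for a plain tensor-in/tensor-out block:

```python
from functools import partial

import torch
import torch.nn as nn
import torch.utils.checkpoint

# Non-reentrant checkpointing, as the removed version gate selected on recent PyTorch.
checkpoint_fn = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)

torch.manual_seed(0)
block = nn.Sequential(nn.Linear(16, 32), nn.SiLU(), nn.Linear(32, 16))
x = torch.randn(4, 16)

# Direct call.
out_direct = block(x)
out_direct.sum().backward()
grads_direct = [p.grad.clone() for p in block.parameters()]
block.zero_grad()

# Checkpointed call, in the style of the new helper-based call sites.
out_ckpt = checkpoint_fn(block, x)
out_ckpt.sum().backward()
grads_ckpt = [p.grad.clone() for p in block.parameters()]

assert torch.allclose(out_direct, out_ckpt)
assert all(torch.allclose(a, b) for a, b in zip(grads_direct, grads_ckpt))
print("direct and checkpointed paths match")
```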
