Commit 015cc78
1 parent c76dc5a

gradient checkpointing

File tree: 2 files changed (+80, -22 lines)

2 files changed

+80
-22
lines changed

src/diffusers/models/autoencoders/autoencoder_kl_allegro.py

Lines changed: 50 additions & 12 deletions
@@ -512,12 +512,26 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor:
         sample = self.temp_conv_in(sample)
         sample = sample + residual

-        # Down blocks
-        for down_block in self.down_blocks:
-            sample = down_block(sample)
+        if self.gradient_checkpointing:
+            def create_custom_forward(module):
+                def custom_forward(*inputs):
+                    return module(*inputs)

-        # Mid block
-        sample = self.mid_block(sample)
+                return custom_forward
+
+            # Down blocks
+            for down_block in self.down_blocks:
+                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
+
+            # Mid block
+            sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
+        else:
+            # Down blocks
+            for down_block in self.down_blocks:
+                sample = down_block(sample)
+
+            # Mid block
+            sample = self.mid_block(sample)

         # Post process
         sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
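
In the encoder, the down and mid blocks are now routed through torch.utils.checkpoint.checkpoint: activations inside each wrapped block are dropped during the forward pass and recomputed during backward, trading extra compute for lower peak memory. Below is a minimal, self-contained sketch of the same wrapper idiom using toy modules (the `blocks` list and layer sizes are illustrative, not part of the Allegro code):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Toy stand-ins for the encoder's down/mid blocks (illustrative sizes, not Allegro's).
blocks = nn.ModuleList(nn.Sequential(nn.Linear(64, 64), nn.GELU()) for _ in range(4))


def create_custom_forward(module):
    # Hand checkpoint() a plain callable it can re-invoke during the backward pass.
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


sample = torch.randn(8, 64, requires_grad=True)
for block in blocks:
    # Activations inside `block` are not kept; they are recomputed when backward runs.
    # use_reentrant=False needs torch >= 1.11; the VAE hunks above rely on the default instead.
    sample = checkpoint(create_custom_forward(block), sample, use_reentrant=False)
sample.sum().backward()
```

The `create_custom_forward` indirection simply gives `checkpoint` a plain callable to replay, which is the same pattern the hunk above uses for the real down and mid blocks.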
@@ -625,7 +639,6 @@ def __init__(
         self.temp_conv_out = nn.Conv3d(block_out_channels[0], block_out_channels[0], (3, 1, 1), padding=(1, 0, 0))
         self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)

-        # TODO(aryan): implement gradient checkpointing
         self.gradient_checkpointing = False

     def forward(self, sample: torch.Tensor) -> torch.Tensor:
@@ -641,13 +654,34 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor:

         upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

-        # Mid block
-        sample = self.mid_block(sample)
-        sample = sample.to(upscale_dtype)
+        if self.gradient_checkpointing:
+            def create_custom_forward(module):
+                def custom_forward(*inputs):
+                    return module(*inputs)
+
+                return custom_forward
+
+            # Mid block
+            sample = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(self.mid_block),
+                sample
+            )
+
+            # Up blocks
+            for up_block in self.up_blocks:
+                sample = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(up_block),
+                    sample
+                )

-        # Up blocks
-        for up_block in self.up_blocks:
-            sample = up_block(sample)
+        else:
+            # Mid block
+            sample = self.mid_block(sample)
+            sample = sample.to(upscale_dtype)
+
+            # Up blocks
+            for up_block in self.up_blocks:
+                sample = up_block(sample)

         # Post process
         sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
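
The decoder applies the identical pattern to its mid and up blocks. Since checkpointing only changes when activations are materialized, not the math, outputs and gradients should match the plain path for deterministic blocks; a quick parity check with a toy module (illustrative, not the Allegro decoder) could look like this:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = nn.Sequential(nn.Conv3d(4, 4, 3, padding=1), nn.SiLU())  # toy stand-in for an up block
x = torch.randn(1, 4, 2, 8, 8)

# Plain forward/backward.
a = x.clone().requires_grad_(True)
block(a).sum().backward()

# Checkpointed forward/backward: same math, activations recomputed during backward.
b = x.clone().requires_grad_(True)
checkpoint(lambda inp: block(inp), b, use_reentrant=False).sum().backward()

print(torch.allclose(a.grad, b.grad))  # expected: True for deterministic blocks
```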
@@ -783,6 +817,10 @@ def __init__(
             self.sample_size - self.tile_overlap[1],
         )  # (16, 112, 192)

+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, (AllegroEncoder3D, AllegroDecoder3D)):
+            module.gradient_checkpointing = value
+
     def encode(
         self, input_imgs: torch.Tensor, return_dict: bool = True, local_batch_size=1
     ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
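
The new `_set_gradient_checkpointing` hook gives the top-level `AutoencoderKLAllegro` a way to flip the `gradient_checkpointing` flag on its encoder and decoder. In diffusers, `ModelMixin.enable_gradient_checkpointing()` typically drives hooks like this by applying them across all submodules; the standalone sketch below only imitates that mechanism with hypothetical `TinyEncoder`/`TinyVAE` classes and is not the library code:

```python
from functools import partial

import torch.nn as nn


class TinyEncoder(nn.Module):  # hypothetical stand-in for AllegroEncoder3D
    def __init__(self):
        super().__init__()
        self.gradient_checkpointing = False


class TinyVAE(nn.Module):  # hypothetical stand-in for AutoencoderKLAllegro
    def __init__(self):
        super().__init__()
        self.encoder = TinyEncoder()

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, TinyEncoder):
            module.gradient_checkpointing = value

    def enable_gradient_checkpointing(self):
        # Apply the hook to every submodule, roughly what ModelMixin does.
        self.apply(partial(self._set_gradient_checkpointing, value=True))


vae = TinyVAE()
vae.enable_gradient_checkpointing()
print(vae.encoder.gradient_checkpointing)  # True
```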

src/diffusers/models/transformers/transformer_allegro.py

Lines changed: 30 additions & 10 deletions
@@ -13,14 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Optional, Tuple
+from typing import Any, Dict, Optional, Tuple

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

 from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import logging
+from ...utils import is_torch_version, logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
 from ..attention_processor import AllegroAttnProcessor2_0, Attention
@@ -335,14 +335,34 @@ def forward(

         for i, block in enumerate(self.transformer_blocks):
             # TODO(aryan): Implement gradient checkpointing
-            hidden_states = block(
-                hidden_states=hidden_states,
-                encoder_hidden_states=encoder_hidden_states,
-                temb=timestep,
-                attention_mask=attention_mask,
-                encoder_attention_mask=encoder_attention_mask,
-                image_rotary_emb=image_rotary_emb,
-            )
+            if self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    encoder_hidden_states,
+                    timestep,
+                    attention_mask,
+                    encoder_attention_mask,
+                    image_rotary_emb,
+                    **ckpt_kwargs,
+                )
+            else:
+                hidden_states = block(
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=timestep,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    image_rotary_emb=image_rotary_emb,
+                )

         # 3. Output
         shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
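
In the transformer, the checkpointed call additionally passes `use_reentrant=False`, but only when the installed PyTorch is at least 1.11, the release that introduced the non-reentrant checkpoint implementation; on older versions `ckpt_kwargs` stays empty and the default reentrant path is used. A standalone sketch of the same guard, using `packaging` in place of diffusers' `is_torch_version` helper, might look like this:

```python
from typing import Any, Dict

import torch
import torch.nn as nn
from packaging import version
from torch.utils.checkpoint import checkpoint

# Only pass use_reentrant on torch >= 1.11, where the non-reentrant path exists.
ckpt_kwargs: Dict[str, Any] = (
    {"use_reentrant": False} if version.parse(torch.__version__).release >= (1, 11) else {}
)

block = nn.Linear(16, 16)  # toy stand-in for an Allegro transformer block
hidden_states = torch.randn(2, 16, requires_grad=True)

# checkpoint() re-plays the wrapped callable with these positional inputs during backward.
hidden_states = checkpoint(lambda *inputs: block(*inputs), hidden_states, **ckpt_kwargs)
hidden_states.sum().backward()
```

Note that the reentrant checkpoint does not forward keyword arguments to the wrapped function, which is presumably why the checkpointed branch in the hunk passes the block's inputs positionally instead of by name.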
