@@ -512,12 +512,26 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor:
         sample = self.temp_conv_in(sample)
         sample = sample + residual
 
-        # Down blocks
-        for down_block in self.down_blocks:
-            sample = down_block(sample)
+        if self.gradient_checkpointing:
+            def create_custom_forward(module):
+                def custom_forward(*inputs):
+                    return module(*inputs)
 
-        # Mid block
-        sample = self.mid_block(sample)
+                return custom_forward
+
+            # Down blocks
+            for down_block in self.down_blocks:
+                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
+
+            # Mid block
+            sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
+        else:
+            # Down blocks
+            for down_block in self.down_blocks:
+                sample = down_block(sample)
+
+            # Mid block
+            sample = self.mid_block(sample)
 
         # Post process
         sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
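For readers unfamiliar with the pattern in this hunk: wrapping each block in `create_custom_forward` hands `torch.utils.checkpoint.checkpoint` a plain callable, so the block's intermediate activations are dropped during the forward pass and recomputed during backward, trading compute for memory. A minimal standalone sketch of the same idiom (`TinyBlock` and the tensor shapes are hypothetical, purely for illustration; `use_reentrant=False` assumes PyTorch 1.11+):

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint


class TinyBlock(nn.Module):
    # Hypothetical stand-in for a down/mid block.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(8, 8, 3, padding=1)

    def forward(self, x):
        return torch.relu(self.conv(x))


def create_custom_forward(module):
    # checkpoint() expects a callable; this closure defers the module call.
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


block = TinyBlock()
sample = torch.randn(2, 8, 16, 16, requires_grad=True)

# Activations inside `block` are not stored; they are recomputed on backward.
out = torch.utils.checkpoint.checkpoint(create_custom_forward(block), sample, use_reentrant=False)
out.sum().backward()
```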
@@ -625,7 +639,6 @@ def __init__(
         self.temp_conv_out = nn.Conv3d(block_out_channels[0], block_out_channels[0], (3, 1, 1), padding=(1, 0, 0))
         self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
 
-        # TODO(aryan): implement gradient checkpointing
         self.gradient_checkpointing = False
 
     def forward(self, sample: torch.Tensor) -> torch.Tensor:
@@ -641,13 +654,35 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor:
 
         upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
 
-        # Mid block
-        sample = self.mid_block(sample)
-        sample = sample.to(upscale_dtype)
+        if self.gradient_checkpointing:
+            def create_custom_forward(module):
+                def custom_forward(*inputs):
+                    return module(*inputs)
+
+                return custom_forward
+
+            # Mid block
+            sample = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(self.mid_block),
+                sample
+            )
+            sample = sample.to(upscale_dtype)
+
+            # Up blocks
+            for up_block in self.up_blocks:
+                sample = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(up_block),
+                    sample
+                )
 
-        # Up blocks
-        for up_block in self.up_blocks:
-            sample = up_block(sample)
+        else:
+            # Mid block
+            sample = self.mid_block(sample)
+            sample = sample.to(upscale_dtype)
+
+            # Up blocks
+            for up_block in self.up_blocks:
+                sample = up_block(sample)
 
         # Post process
         sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
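Because checkpointing changes only what is stored, not what is computed, an easy sanity check for hunks like this one is to compare outputs and gradients between the checkpointed and plain branches. A hedged test sketch (the `nn.Sequential` block and shapes are placeholders, not the Allegro decoder; `use_reentrant=False` again assumes PyTorch 1.11+):

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint

torch.manual_seed(0)
block = nn.Sequential(nn.Conv2d(4, 4, 3, padding=1), nn.SiLU())
x = torch.randn(1, 4, 8, 8, requires_grad=True)

# Plain forward/backward.
ref = block(x)
ref.sum().backward()
ref_grad = x.grad.clone()

# Checkpointed forward/backward: same values, lower activation memory.
x.grad = None
ckpt = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
ckpt.sum().backward()

assert torch.allclose(ref, ckpt)
assert torch.allclose(ref_grad, x.grad)
```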
@@ -783,6 +818,10 @@ def __init__(
             self.sample_size - self.tile_overlap[1],
         )  # (16, 112, 192)
 
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, (AllegroEncoder3D, AllegroDecoder3D)):
+            module.gradient_checkpointing = value
+
     def encode(
         self, input_imgs: torch.Tensor, return_dict: bool = True, local_batch_size=1
     ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
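The new `_set_gradient_checkpointing` hook is what diffusers' `ModelMixin.enable_gradient_checkpointing()` invokes: the mixin walks every submodule and lets the hook flip the flag on the ones that support it, which here covers both `AllegroEncoder3D` and `AllegroDecoder3D` at once. A simplified sketch of that propagation (the mixin internals are paraphrased and `ToyModel` is hypothetical, not the library's actual class):

```python
from functools import partial

import torch.nn as nn


class ToyModel(nn.Module):
    # Hypothetical stand-in mirroring the autoencoder pattern above.
    def __init__(self):
        super().__init__()
        self.encoder = nn.Identity()
        self.encoder.gradient_checkpointing = False
        self.decoder = nn.Identity()
        self.decoder.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def enable_gradient_checkpointing(self):
        # Visit every submodule and let the hook decide
        # which ones actually carry the flag.
        self.apply(partial(self._set_gradient_checkpointing, value=True))


model = ToyModel()
model.enable_gradient_checkpointing()
assert model.encoder.gradient_checkpointing and model.decoder.gradient_checkpointing
```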