
Commit a12240f

dg845 authored and committed (with sayakpaul, VincentNeemie, patrickvonplaten, and DN6)
Enable Gradient Checkpointing for UNet2DModel (New) (huggingface#7201)
* Port UNet2DModel gradient checkpointing code from huggingface#6718.

---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Vincent Neemie <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Dhruv Nair <[email protected]>
Co-authored-by: hlky <[email protected]>
1 parent 47cd8c6 commit a12240f

7 files changed: +154 −14 lines changed
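For context, a minimal usage sketch of what this commit enables (not part of the diff; the model configuration below is arbitrary and chosen only to keep the example small): build a UNet2DModel, turn on checkpointing via ModelMixin's enable_gradient_checkpointing(), and run one forward/backward pass.

# Usage sketch; shapes and channel counts are illustrative, not from this commit.
import torch
from diffusers import UNet2DModel

model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
model.enable_gradient_checkpointing()  # now supported: _supports_gradient_checkpointing = True
model.train()

sample = torch.randn(1, 3, 32, 32)
timestep = torch.tensor([10])
loss = model(sample, timestep).sample.mean()
loss.backward()  # activations in checkpointed blocks are recomputed here instead of stored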

src/diffusers/models/unets/unet_2d.py

Lines changed: 6 additions & 0 deletions
@@ -93,6 +93,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
             The number of discrete classes. The number of class types per each class is provided in tuple-like array.
     """
 
+    _supports_gradient_checkpointing = True
+
     @register_to_config
     def __init__(
         self,
@@ -267,6 +269,10 @@ def __init__(
         self.conv_act = nn.SiLU()
         self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
 
+    def _set_gradient_checkpointing(self, module, value=False):
+        if hasattr(module, "gradient_checkpointing"):
+            module.gradient_checkpointing = value
+
     def forward(
         self,
         sample: torch.Tensor,
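The new _set_gradient_checkpointing hook only flips a flag on submodules that expose one. Below is a toy, self-contained sketch of how such a hook is typically driven (assuming, as in diffusers' ModelMixin, that the base class walks submodules with nn.Module.apply; the Tiny* classes are hypothetical, not diffusers code).

# Toy illustration of the flag-flipping pattern added above.
from functools import partial

import torch.nn as nn


class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.gradient_checkpointing = False  # same flag the blocks in this commit gain


class TinyUNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = TinyBlock()

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def enable_gradient_checkpointing(self):
        # visit every submodule (including self) and set the flag where it exists
        self.apply(partial(self._set_gradient_checkpointing, value=True))


model = TinyUNet()
model.enable_gradient_checkpointing()
assert model.block.gradient_checkpointing is True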

src/diffusers/models/unets/unet_2d_blocks.py

Lines changed: 75 additions & 8 deletions
@@ -731,12 +731,35 @@ def __init__(
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)
 
+        self.gradient_checkpointing = False
+
     def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if attn is not None:
-                hidden_states = attn(hidden_states, temb=temb)
-            hidden_states = resnet(hidden_states, temb)
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                if attn is not None:
+                    hidden_states = attn(hidden_states, temb=temb)
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+            else:
+                if attn is not None:
+                    hidden_states = attn(hidden_states, temb=temb)
+                hidden_states = resnet(hidden_states, temb)
 
         return hidden_states
 
@@ -1116,6 +1139,8 @@ def __init__(
         else:
             self.downsamplers = None
 
+        self.gradient_checkpointing = False
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -1130,9 +1155,30 @@ def forward(
         output_states = ()
 
         for resnet, attn in zip(self.resnets, self.attentions):
-            hidden_states = resnet(hidden_states, temb)
-            hidden_states = attn(hidden_states, **cross_attention_kwargs)
-            output_states = output_states + (hidden_states,)
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = attn(hidden_states, **cross_attention_kwargs)
+                output_states = output_states + (hidden_states,)
+            else:
+                hidden_states = resnet(hidden_states, temb)
+                hidden_states = attn(hidden_states, **cross_attention_kwargs)
+                output_states = output_states + (hidden_states,)
 
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
@@ -2354,6 +2400,7 @@ def __init__(
         else:
             self.upsamplers = None
 
+        self.gradient_checkpointing = False
         self.resolution_idx = resolution_idx
 
     def forward(
@@ -2375,8 +2422,28 @@ def forward(
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
-            hidden_states = resnet(hidden_states, temb)
-            hidden_states = attn(hidden_states)
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = attn(hidden_states)
+            else:
+                hidden_states = resnet(hidden_states, temb)
+                hidden_states = attn(hidden_states)
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
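Each new branch above follows the same recipe: wrap the resnet call in a closure and route it through torch.utils.checkpoint.checkpoint, so the resnet's intermediate activations are recomputed during backward instead of being kept in memory. A stripped-down, self-contained sketch of that recipe (the module and tensor shapes are made up for illustration):

# Standalone sketch of the checkpointing pattern used in the new branches.
import torch
import torch.nn as nn
import torch.utils.checkpoint

resnet = nn.Sequential(nn.Conv2d(4, 4, kernel_size=3, padding=1), nn.SiLU())


def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


hidden_states = torch.randn(1, 4, 8, 8, requires_grad=True)
hidden_states = torch.utils.checkpoint.checkpoint(
    create_custom_forward(resnet),
    hidden_states,
    use_reentrant=False,  # mirrors the ckpt_kwargs chosen for torch >= 1.11
)
hidden_states.sum().backward()  # conv/activation outputs are recomputed here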

src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py

Lines changed: 26 additions & 3 deletions
@@ -2223,12 +2223,35 @@ def __init__(
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)
 
+        self.gradient_checkpointing = False
+
     def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if attn is not None:
-                hidden_states = attn(hidden_states, temb=temb)
-            hidden_states = resnet(hidden_states, temb)
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                if attn is not None:
+                    hidden_states = attn(hidden_states, temb=temb)
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+            else:
+                if attn is not None:
+                    hidden_states = attn(hidden_states, temb=temb)
+                hidden_states = resnet(hidden_states, temb)
 
         return hidden_states
 

tests/models/autoencoders/test_models_autoencoder_kl.py

Lines changed: 1 addition & 1 deletion
@@ -146,7 +146,7 @@ def test_enable_disable_slicing(self):
         )
 
     def test_gradient_checkpointing_is_applied(self):
-        expected_set = {"Decoder", "Encoder"}
+        expected_set = {"Decoder", "Encoder", "UNetMidBlock2D"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
 
     def test_from_pretrained_hub(self):

tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ def prepare_init_args_and_inputs_for_common(self):
         return init_dict, inputs_dict
 
     def test_gradient_checkpointing_is_applied(self):
-        expected_set = {"Encoder", "TemporalDecoder"}
+        expected_set = {"Encoder", "TemporalDecoder", "UNetMidBlock2D"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
 
     @unittest.skip("Test unsupported.")

tests/models/test_modeling_common.py

Lines changed: 3 additions & 1 deletion
@@ -803,7 +803,7 @@ def test_enable_disable_gradient_checkpointing(self):
         self.assertFalse(model.is_gradient_checkpointing)
 
     @require_torch_accelerator_with_training
-    def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5):
+    def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip: set[str] = {}):
         if not self.model_class._supports_gradient_checkpointing:
             return  # Skip test if model does not support gradient checkpointing
 
@@ -850,6 +850,8 @@ def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_
         for name, param in named_params.items():
             if "post_quant_conv" in name:
                 continue
+            if name in skip:
+                continue
             self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol))
 
     @unittest.skipIf(torch_device == "mps", "This test is not supported for MPS devices.")
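The new skip parameter lets a subclass exempt specific parameter names from the gradient comparison. In spirit, the test runs the same forward/backward with checkpointing off and then on, and asserts per-parameter gradients match within param_grad_tol. A hypothetical miniature of that check (toy model, not the real test harness):

# Hypothetical miniature of the equivalence check performed by this test.
import torch
import torch.nn as nn
import torch.utils.checkpoint


def param_grads(model, x, use_ckpt):
    model.zero_grad()
    h = x
    for layer in model:
        if use_ckpt:
            h = torch.utils.checkpoint.checkpoint(layer, h, use_reentrant=False)
        else:
            h = layer(h)
    h.sum().backward()
    return {name: p.grad.clone() for name, p in model.named_parameters()}


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 8), nn.SiLU(), nn.Linear(8, 8))
x = torch.randn(2, 8, requires_grad=True)

grads_plain = param_grads(model, x, use_ckpt=False)
grads_ckpt = param_grads(model, x, use_ckpt=True)

skip = set()  # e.g. {"time_proj.weight"} in the NCSNpp test below
for name in grads_plain:
    if name in skip:
        continue
    assert torch.allclose(grads_plain[name], grads_ckpt[name], atol=5e-5)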

tests/models/unets/test_models_unet_2d.py

Lines changed: 42 additions & 0 deletions
@@ -105,6 +105,23 @@ def test_mid_block_attn_groups(self):
         expected_shape = inputs_dict["sample"].shape
         self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
 
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {
+            "AttnUpBlock2D",
+            "AttnDownBlock2D",
+            "UNetMidBlock2D",
+            "UpBlock2D",
+            "DownBlock2D",
+        }
+
+        # NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
+        attention_head_dim = 8
+        block_out_channels = (16, 32)
+
+        super().test_gradient_checkpointing_is_applied(
+            expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
+        )
+
 
 class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = UNet2DModel
@@ -220,6 +237,17 @@ def test_output_pretrained(self):
 
         self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-3))
 
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"DownBlock2D", "UNetMidBlock2D", "UpBlock2D"}
+
+        # NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
+        attention_head_dim = 32
+        block_out_channels = (32, 64)
+
+        super().test_gradient_checkpointing_is_applied(
+            expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
+        )
+
 
 class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = UNet2DModel
@@ -329,3 +357,17 @@ def test_output_pretrained_ve_large(self):
     def test_forward_with_norm_groups(self):
         # not required for this model
         pass
+
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {
+            "UNetMidBlock2D",
+        }
+
+        block_out_channels = (32, 64, 64, 64)
+
+        super().test_gradient_checkpointing_is_applied(
+            expected_set=expected_set, block_out_channels=block_out_channels
+        )
+
+    def test_effective_gradient_checkpointing(self):
+        super().test_effective_gradient_checkpointing(skip={"time_proj.weight"})
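The expected_set values above name the block classes whose gradient_checkpointing flag should end up set once checkpointing is enabled. A rough, illustrative way to inspect that yourself (the configuration values below are arbitrary, not the test's init dict):

# Illustrative check: which module classes end up with gradient_checkpointing=True.
from diffusers import UNet2DModel

model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
    attention_head_dim=8,
)
model.enable_gradient_checkpointing()

checkpointed = {
    module.__class__.__name__
    for module in model.modules()
    if getattr(module, "gradient_checkpointing", False)
}
print(checkpointed)
# expected to cover blocks such as DownBlock2D, AttnDownBlock2D, UNetMidBlock2D, AttnUpBlock2D, UpBlock2D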
