Commit 066465e

more cleanup 🧹
1 parent b6be0ba commit 066465e

2 files changed: +7, -92 lines
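Both files get the same simplification: the hand-rolled create_custom_forward wrapper, the is_torch_version gate that picked use_reentrant=False, and the explicit torch.utils.checkpoint.checkpoint(...) call collapse into a single self._gradient_checkpointing_func(...) call, and the now-redundant _set_gradient_checkpointing overrides are removed along with the unused is_torch_version imports. A self-contained sketch of what the one-liner stands in for follows the second file's diff.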

examples/community/matryoshka.py

Lines changed: 5 additions & 74 deletions
@@ -80,7 +80,6 @@
     USE_PEFT_BACKEND,
     BaseOutput,
     deprecate,
-    is_torch_version,
     is_torch_xla_available,
     logging,
     replace_example_docstring,
@@ -869,23 +868,7 @@ def forward(

         for i, (resnet, attn) in enumerate(blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -1030,17 +1013,6 @@ def forward(
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -1049,12 +1021,7 @@ def custom_forward(*inputs):
                     encoder_attention_mask=encoder_attention_mask,
                     return_dict=False,
                 )[0]
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
             else:
                 hidden_states = attn(
                     hidden_states,
@@ -1192,23 +1159,7 @@ def forward(
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -1282,10 +1233,6 @@ def __init__(
             ]
         )

-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -1365,27 +1312,15 @@ def forward(
         # Blocks
         for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
                     encoder_attention_mask,
                     timestep,
                     cross_attention_kwargs,
                     class_labels,
-                    **ckpt_kwargs,
                 )
             else:
                 hidden_states = block(
@@ -2724,10 +2659,6 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i
         for module in self.children():
             fn_recursive_set_attention_slice(module, reversed_slice_size)

-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
         r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.

examples/research_projects/pixart/controlnet_pixart_alpha.py

Lines changed: 2 additions & 18 deletions
@@ -8,7 +8,6 @@
 from diffusers.models.attention import BasicTransformerBlock
 from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from diffusers.models.modeling_utils import ModelMixin
-from diffusers.utils.torch_utils import is_torch_version


 class PixArtControlNetAdapterBlock(nn.Module):
@@ -151,10 +150,6 @@ def __init__(
         self.transformer = transformer
         self.controlnet = controlnet

-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -220,26 +215,15 @@ def forward(
                 print("Gradient checkpointing is not supported for the controlnet transformer model, yet.")
                 exit(1)

-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
                     encoder_attention_mask,
                     timestep,
                     cross_attention_kwargs,
                     None,
-                    **ckpt_kwargs,
                 )
             else:
                 # the control nets are only used for the blocks 1 to self.blocks_num
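For readers unfamiliar with the helper, here is a minimal, self-contained sketch of roughly what a _gradient_checkpointing_func attribute stands in for: the behaviour the deleted per-call-site code built by hand (checkpointing the wrapped module with use_reentrant=False). This is an illustration only; TinyBlock is a made-up module, and in diffusers the attribute is supplied by the model base class when gradient checkpointing is enabled, not defined on each block as shown here.

from functools import partial

import torch
import torch.nn as nn
import torch.utils.checkpoint


class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.resnet = nn.Linear(16, 16)
        self.gradient_checkpointing = False
        # Roughly what the deleted create_custom_forward/ckpt_kwargs code did:
        # checkpoint the wrapped module with the non-reentrant implementation.
        self._gradient_checkpointing_func = partial(
            torch.utils.checkpoint.checkpoint, use_reentrant=False
        )

    def forward(self, hidden_states):
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            # One-line call site, as in the hunks above.
            hidden_states = self._gradient_checkpointing_func(self.resnet, hidden_states)
        else:
            hidden_states = self.resnet(hidden_states)
        return hidden_states


block = TinyBlock()
block.gradient_checkpointing = True
out = block(torch.randn(2, 16))
out.sum().backward()  # self.resnet's forward is recomputed here instead of storing activations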
