
Commit b79928d

Commit message: update

1 parent 328e0d2 commit b79928d

2 files changed (+61, -13 lines)


src/diffusers/models/modeling_utils.py

Lines changed: 58 additions & 4 deletions
@@ -27,6 +27,7 @@

 import safetensors
 import torch
+import torch.utils.checkpoint
 from huggingface_hub import DDUFEntry, create_repo, split_torch_state_dict_into_shards
 from huggingface_hub.utils import validate_hf_hub_args
 from torch import Tensor, nn
@@ -154,6 +155,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     def __init__(self):
         super().__init__()

+        self._gradient_checkpointing_func = None
+
     def __getattr__(self, name: str) -> Any:
         """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
         config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
@@ -179,22 +182,55 @@ def is_gradient_checkpointing(self) -> bool:
         """
         return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())

-    def enable_gradient_checkpointing(self) -> None:
+    def enable_gradient_checkpointing(
+        self,
+        gradient_checkpointing_func: Optional[Callable] = None,
+        gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> None:
         """
         Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
         *checkpoint activations* in other frameworks).
         """
         if not self._supports_gradient_checkpointing:
-            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
-        self.apply(partial(self._set_gradient_checkpointing, value=True))
+            raise ValueError(
+                f"{self.__class__.__name__} does not support gradient checkpointing. Please make sure to set the boolean attribute "
+                f"`_supports_gradient_checkpointing` to `True` in the class definition."
+            )
+
+        user_provided_gradient_checkpointing_func = gradient_checkpointing_func is not None
+        if gradient_checkpointing_func is None:
+
+            def _gradient_checkpointing_func(module, *args):
+                ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                return torch.utils.checkpoint.checkpoint(
+                    module.__call__,
+                    *args,
+                    **ckpt_kwargs,
+                )
+
+            gradient_checkpointing_func = _gradient_checkpointing_func
+
+        if gradient_checkpointing_kwargs is None:
+            gradient_checkpointing_kwargs = {}
+
+        if (
+            not user_provided_gradient_checkpointing_func
+            and is_torch_version(">=", "1.11.0")
+            and inspect.signature(gradient_checkpointing_func).parameters.get("use_reentrant") is not None
+        ):
+            gradient_checkpointing_kwargs["use_reentrant"] = False
+
+        gradient_checkpointing_func = partial(gradient_checkpointing_func, **gradient_checkpointing_kwargs)
+
+        self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)

     def disable_gradient_checkpointing(self) -> None:
         """
         Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
         *checkpoint activations* in other frameworks).
         """
         if self._supports_gradient_checkpointing:
-            self.apply(partial(self._set_gradient_checkpointing, value=False))
+            self._set_gradient_checkpointing(enable=False)

     def set_use_npu_flash_attention(self, valid: bool) -> None:
         r"""
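The hunk above turns `enable_gradient_checkpointing` into a pluggable hook: it accepts an arbitrary checkpointing callable plus keyword arguments and falls back to non-reentrant `torch.utils.checkpoint.checkpoint` on PyTorch >= 1.11. A minimal usage sketch, not part of this commit (the LTX model class and checkpoint id are only illustrative assumptions):

import torch
import torch.utils.checkpoint

from diffusers import LTXVideoTransformer3DModel

transformer = LTXVideoTransformer3DModel.from_pretrained(
    "Lightricks/LTX-Video", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Default behaviour: every checkpointed block goes through
# torch.utils.checkpoint.checkpoint(block.__call__, *args, use_reentrant=False) on torch >= 1.11.
transformer.enable_gradient_checkpointing()

# Custom behaviour: the callable receives (module, *args); any
# gradient_checkpointing_kwargs are partial-applied onto it.
def fast_checkpoint(module, *args, preserve_rng_state=False):
    # Skip RNG-state capture, which is safe when the checkpointed blocks contain no dropout.
    return torch.utils.checkpoint.checkpoint(
        module.__call__, *args, use_reentrant=False, preserve_rng_state=preserve_rng_state
    )

transformer.enable_gradient_checkpointing(gradient_checkpointing_func=fast_checkpoint)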
@@ -1354,6 +1390,24 @@ def get_memory_footprint(self, return_buffers=True):
             mem = mem + mem_bufs
         return mem

+    def _set_gradient_checkpointing(
+        self, enable: bool = True, gradient_checkpointing_func: Callable = torch.utils.checkpoint.checkpoint
+    ) -> None:
+        is_gradient_checkpointing_set = False
+
+        for name, module in self.named_modules():
+            if hasattr(module, "gradient_checkpointing"):
+                logger.debug(f"Setting `gradient_checkpointing={enable}` for '{name}'")
+                module._gradient_checkpointing_func = gradient_checkpointing_func
+                module.gradient_checkpointing = enable
+                is_gradient_checkpointing_set = True
+
+        if not is_gradient_checkpointing_set:
+            raise ValueError(
+                f"The module {self.__class__.__name__} does not support gradient checkpointing. Please make sure to use a module that supports gradient checkpointing "
+                f"by creating a boolean attribute `gradient_checkpointing` in the module and setting it to `True`."
+            )
+
     def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
         deprecated_attention_block_paths = []
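With this commit, `_set_gradient_checkpointing` lives on `ModelMixin` itself: it walks `named_modules()`, flips the `gradient_checkpointing` flag wherever it exists, and attaches the checkpointing callable as `_gradient_checkpointing_func`. A minimal sketch of the contract a model now has to satisfy; `ToyBlock`/`ToyModel` are hypothetical classes for illustration only, and this assumes a diffusers build that contains this commit:

import torch

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class ToyBlock(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gradient_checkpointing = False  # flipped by ModelMixin._set_gradient_checkpointing
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, x):
        return torch.nn.functional.gelu(self.proj(x))


class ToyModel(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True  # opt-in flag checked by enable_gradient_checkpointing

    @register_to_config
    def __init__(self, dim: int = 32, num_layers: int = 2):
        super().__init__()
        self.blocks = torch.nn.ModuleList(ToyBlock(dim) for _ in range(num_layers))
        self.gradient_checkpointing = False

    def forward(self, x):
        for block in self.blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # installed by enable_gradient_checkpointing() via _set_gradient_checkpointing()
                x = self._gradient_checkpointing_func(block, x)
            else:
                x = block(x)
        return x


model = ToyModel()
model.enable_gradient_checkpointing()   # walks named_modules() and flips every flag it finds
print(model.is_gradient_checkpointing)  # True
out = model(torch.randn(2, 32, requires_grad=True))
out.sum().backward()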

src/diffusers/models/transformers/transformer_ltx.py

Lines changed: 3 additions & 9 deletions
@@ -22,7 +22,7 @@

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
 from ..attention_processor import Attention
@@ -360,10 +360,6 @@ def __init__(

         self.gradient_checkpointing = False

-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -426,15 +422,13 @@ def custom_forward(*inputs):

                     return custom_forward

-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     encoder_hidden_states,
                     temb,
                     image_rotary_emb,
                     encoder_attention_mask,
-                    **ckpt_kwargs,
                 )
             else:
                 hidden_states = block(
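With the per-model `_set_gradient_checkpointing` override and the `ckpt_kwargs`/`create_custom_forward` plumbing removed, the LTX transformer simply calls whatever callable `ModelMixin` installed as `self._gradient_checkpointing_func`. That also makes it easy to route the blocks through a different checkpoint implementation. The sketch below is not part of this commit; it assumes DeepSpeed is installed and uses its activation checkpointing purely as an example of a swapped-in backend:

import deepspeed

from diffusers import LTXVideoTransformer3DModel

transformer = LTXVideoTransformer3DModel.from_pretrained(
    "Lightricks/LTX-Video", subfolder="transformer"
)

# DeepSpeed's activation checkpointing follows the same (function, *args) calling
# convention as the new hook, so it can be passed in directly.
# Depending on the setup, deepspeed.checkpointing.configure(...) may need to run first.
transformer.enable_gradient_checkpointing(
    gradient_checkpointing_func=deepspeed.checkpointing.checkpoint
)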

0 commit comments
