
Commit f2fb771

Merge branch 'main' into fix_lumina_pipe
2 parents: a5be828 + 9a147b8

50 files changed: 1405 additions & 26 deletions


docs/source/en/api/utilities.md

Lines changed: 4 additions & 0 deletions
@@ -45,3 +45,7 @@ Utility and helper functions for working with 🤗 Diffusers.
 ## apply_layerwise_casting
 
 [[autodoc]] hooks.layerwise_casting.apply_layerwise_casting
+
+## apply_group_offloading
+
+[[autodoc]] hooks.group_offloading.apply_group_offloading

docs/source/en/optimization/memory.md

Lines changed: 40 additions & 0 deletions
@@ -158,6 +158,46 @@ In order to properly offload models after they're called, it is required to run
 
 </Tip>
 
+## Group offloading
+
+Group offloading is the middle ground between sequential and model offloading. It works by offloading groups of internal layers (either `torch.nn.ModuleList` or `torch.nn.Sequential`), which uses less memory than model-level offloading. It is also faster than sequential-level offloading because the number of device synchronizations is reduced.
+
+To enable group offloading, call the [`~ModelMixin.enable_group_offload`] method on the model if it is a Diffusers model implementation. For any other model implementation, use [`~hooks.group_offloading.apply_group_offloading`]:
+
+```python
+import torch
+from diffusers import CogVideoXPipeline
+from diffusers.hooks import apply_group_offloading
+from diffusers.utils import export_to_video
+
+# Load the pipeline
+onload_device = torch.device("cuda")
+offload_device = torch.device("cpu")
+pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
+
+# We can utilize the enable_group_offload method for Diffusers model implementations
+pipe.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True)
+
+# For any other model implementations, the apply_group_offloading function can be used
+apply_group_offloading(pipe.text_encoder, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=2)
+apply_group_offloading(pipe.vae, onload_device=onload_device, offload_type="leaf_level")
+
+prompt = (
+    "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
+    "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
+    "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
+    "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
+    "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
+    "atmosphere of this unique musical performance."
+)
+video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
+# This utilized about 14.79 GB. It can be further reduced by using tiling and using leaf_level offloading throughout the pipeline.
+print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
+export_to_video(video, "output.mp4", fps=8)
+```
+
+Group offloading (for CUDA devices with support for asynchronous data transfer streams) overlaps data transfer and computation to reduce the overall execution time compared to sequential offloading. This is enabled using layer prefetching with CUDA streams. The next layer to be executed is loaded onto the accelerator device while the current layer is being executed - this increases the memory requirements slightly. Group offloading also supports leaf-level offloading (equivalent to sequential CPU offloading) but can be made much faster when using streams.
+
 ## FP8 layerwise weight-casting
 
 PyTorch supports `torch.float8_e4m3fn` and `torch.float8_e5m2` as weight storage dtypes, but they can't be used for computation in many different tensor operations due to unimplemented kernel support. However, you can use these dtypes to store model weights in fp8 precision and upcast them on-the-fly when the layers are used in the forward pass. This is known as layerwise weight-casting.
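
The code comment in the diff above notes that memory can be reduced further by using tiling and leaf-level offloading throughout the pipeline, and the closing paragraph describes stream-based prefetching. The sketch below combines those ideas; it is a hedged illustration only, the prompt is shortened, and the availability of `enable_tiling` on this particular VAE (and any resulting savings) are assumptions rather than measurements from this commit.

```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.hooks import apply_group_offloading

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
onload_device = torch.device("cuda")
offload_device = torch.device("cpu")

# Leaf-level offloading with stream-based prefetching on the Diffusers transformer
pipe.transformer.enable_group_offload(
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="leaf_level",
    use_stream=True,
)

# Leaf-level offloading on the remaining components via the standalone helper
apply_group_offloading(pipe.text_encoder, onload_device=onload_device, offload_type="leaf_level")
apply_group_offloading(pipe.vae, onload_device=onload_device, offload_type="leaf_level")

# Tiling trades compute for a lower peak during VAE decode (assumed to be available on this VAE)
pipe.vae.enable_tiling()

video = pipe(prompt="A panda playing a tiny guitar", guidance_scale=6, num_inference_steps=50).frames[0]
print(f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
```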

docs/source/en/tutorials/using_peft_for_inference.md

Lines changed: 4 additions & 0 deletions
@@ -221,3 +221,7 @@ pipe.delete_adapters("toy")
 pipe.get_active_adapters()
 ["pixel"]
 ```
+
+## PeftInputAutocastDisableHook
+
+[[autodoc]] hooks.layerwise_casting.PeftInputAutocastDisableHook

src/diffusers/hooks/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@
 
 
 if is_torch_available():
+    from .group_offloading import apply_group_offloading
     from .hooks import HookRegistry, ModelHook
     from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
     from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast

src/diffusers/hooks/group_offloading.py

Lines changed: 678 additions & 0 deletions
Large diffs are not rendered by default.
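
The new module's 678 lines are collapsed above, but its public entry point is exercised elsewhere in this commit. Below is a minimal sketch of calling it on a plain `torch.nn.Module`; the full keyword list is assumed from how `ModelMixin.enable_group_offload` (further down in this diff) forwards its arguments, and the tiny model is purely illustrative.

```python
import torch
from diffusers.hooks import apply_group_offloading


class TinyBlockModel(torch.nn.Module):
    """Illustrative module whose layers live in an nn.ModuleList, the structure block_level offloading groups."""

    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(4)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


model = TinyBlockModel()
apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",  # or "leaf_level"
    num_blocks_per_group=2,      # only meaningful for block_level
    non_blocking=False,
    use_stream=False,
)
```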

src/diffusers/hooks/layerwise_casting.py

Lines changed: 56 additions & 2 deletions
@@ -17,14 +17,16 @@
 
 import torch
 
-from ..utils import get_logger
+from ..utils import get_logger, is_peft_available, is_peft_version
 from .hooks import HookRegistry, ModelHook
 
 
 logger = get_logger(__name__)  # pylint: disable=invalid-name
 
 
 # fmt: off
+_LAYERWISE_CASTING_HOOK = "layerwise_casting"
+_PEFT_AUTOCAST_DISABLE_HOOK = "peft_autocast_disable"
 SUPPORTED_PYTORCH_LAYERS = (
     torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
     torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
@@ -34,6 +36,11 @@
 DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm", "^proj_in$", "^proj_out$")
 # fmt: on
 
+_SHOULD_DISABLE_PEFT_INPUT_AUTOCAST = is_peft_available() and is_peft_version(">", "0.14.0")
+if _SHOULD_DISABLE_PEFT_INPUT_AUTOCAST:
+    from peft.helpers import disable_input_dtype_casting
+    from peft.tuners.tuners_utils import BaseTunerLayer
+
 
 class LayerwiseCastingHook(ModelHook):
     r"""
@@ -70,6 +77,32 @@ def post_forward(self, module: torch.nn.Module, output):
         return output
 
 
+class PeftInputAutocastDisableHook(ModelHook):
+    r"""
+    A hook that disables the casting of inputs to the module weight dtype during the forward pass. By default, PEFT
+    casts the inputs to the weight dtype of the module, which can lead to precision loss.
+
+    The reasons for needing this are:
+    - If we don't add PEFT layers' weight names to `skip_modules_pattern` when applying layerwise casting, the
+      inputs will be cast to the, possibly lower precision, storage dtype. Reference:
+      https://github.com/huggingface/peft/blob/0facdebf6208139cbd8f3586875acb378813dd97/src/peft/tuners/lora/layer.py#L706
+    - We can, on our end, use something like accelerate's `send_to_device` but for dtypes. This way, we can ensure
+      that the inputs are always cast to the computation dtype correctly. However, there are two goals we are
+      hoping to achieve:
+      1. Making forward implementations independent of device/dtype casting operations as much as possible.
+      2. Performing inference without losing information from casting to different precisions. With the current
+         PEFT implementation (as linked in the reference above), and assuming running layerwise casting inference
+         with storage_dtype=torch.float8_e4m3fn and compute_dtype=torch.bfloat16, inputs are cast to
+         torch.float8_e4m3fn in the lora layer. We will then upcast back to torch.bfloat16 when we continue the
+         forward pass in PEFT linear forward or Diffusers layer forward, with a `send_to_dtype` operation from
+         LayerwiseCastingHook. This will be a lossy operation and result in poorer generation quality.
+    """
+
+    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+        with disable_input_dtype_casting(module):
+            return self.fn_ref.original_forward(*args, **kwargs)
+
+
 def apply_layerwise_casting(
     module: torch.nn.Module,
     storage_dtype: torch.dtype,
@@ -134,6 +167,7 @@ def apply_layerwise_casting(
         skip_modules_classes,
         non_blocking,
     )
+    _disable_peft_input_autocast(module)
 
 
 def _apply_layerwise_casting(
@@ -188,4 +222,24 @@ def apply_layerwise_casting_hook(
     """
     registry = HookRegistry.check_if_exists_or_initialize(module)
    hook = LayerwiseCastingHook(storage_dtype, compute_dtype, non_blocking)
-    registry.register_hook(hook, "layerwise_casting")
+    registry.register_hook(hook, _LAYERWISE_CASTING_HOOK)
+
+
+def _is_layerwise_casting_active(module: torch.nn.Module) -> bool:
+    for submodule in module.modules():
+        if (
+            hasattr(submodule, "_diffusers_hook")
+            and submodule._diffusers_hook.get_hook(_LAYERWISE_CASTING_HOOK) is not None
+        ):
+            return True
+    return False
+
+
+def _disable_peft_input_autocast(module: torch.nn.Module) -> None:
+    if not _SHOULD_DISABLE_PEFT_INPUT_AUTOCAST:
+        return
+    for submodule in module.modules():
+        if isinstance(submodule, BaseTunerLayer) and _is_layerwise_casting_active(submodule):
+            registry = HookRegistry.check_if_exists_or_initialize(submodule)
+            hook = PeftInputAutocastDisableHook()
+            registry.register_hook(hook, _PEFT_AUTOCAST_DISABLE_HOOK)
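
To make the hook's purpose concrete, here is a hedged sketch of layerwise casting on a transformer that carries PEFT LoRA layers. Because `apply_layerwise_casting` now ends with `_disable_peft_input_autocast`, the `PeftInputAutocastDisableHook` is registered automatically on any `BaseTunerLayer` where casting is active (with PEFT > 0.14.0 installed), so no extra call is needed. The checkpoint and LoRA repository names are placeholders.

```python
import torch
from diffusers import FluxPipeline

# Placeholder identifiers, for illustration only.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("some-user/some-flux-lora")  # injects PEFT BaseTunerLayer modules

# fp8 storage with bf16 compute; inputs to the LoRA layers stay in bf16 instead of
# being down-cast to float8 inside the PEFT forward.
pipe.transformer.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
)
```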

src/diffusers/models/autoencoders/autoencoder_oobleck.py

Lines changed: 1 addition & 0 deletions
@@ -317,6 +317,7 @@ class AutoencoderOobleck(ModelMixin, ConfigMixin):
     """
 
     _supports_gradient_checkpointing = False
+    _supports_group_offloading = False
 
     @register_to_config
     def __init__(

src/diffusers/models/autoencoders/consistency_decoder_vae.py

Lines changed: 2 additions & 0 deletions
@@ -68,6 +68,8 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
     ```
     """
 
+    _supports_group_offloading = False
+
     @register_to_config
     def __init__(
         self,

src/diffusers/models/autoencoders/vq_model.py

Lines changed: 1 addition & 0 deletions
@@ -72,6 +72,7 @@ class VQModel(ModelMixin, ConfigMixin):
     """
 
     _skip_layerwise_casting_patterns = ["quantize"]
+    _supports_group_offloading = False
 
     @register_to_config
     def __init__(
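
The three opt-outs above pair with the new guard in `ModelMixin.enable_group_offload` (shown later in this diff): calling it on a model whose class sets `_supports_group_offloading = False` raises a `ValueError`. A hedged sketch, with an illustrative checkpoint name:

```python
import torch
from diffusers import VQModel

# Illustrative checkpoint; VQModel opts out via _supports_group_offloading = False.
vq = VQModel.from_pretrained("CompVis/ldm-celebahq-256", subfolder="vqvae")

try:
    vq.enable_group_offload(onload_device=torch.device("cuda"))
except ValueError as err:
    print(err)  # explains that VQModel does not support group offloading
```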

src/diffusers/models/modeling_utils.py

Lines changed: 91 additions & 1 deletion
@@ -34,7 +34,7 @@
 from typing_extensions import Self
 
 from .. import __version__
-from ..hooks import apply_layerwise_casting
+from ..hooks import apply_group_offloading, apply_layerwise_casting
 from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
 from ..quantizers.quantization_config import QuantizationMethod
 from ..utils import (
@@ -87,7 +87,17 @@
 
 
 def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
+    from ..hooks.group_offloading import _get_group_onload_device
+
+    try:
+        # Try to get the onload device from the group offloading hook
+        return _get_group_onload_device(parameter)
+    except ValueError:
+        pass
+
     try:
+        # If the onload device is not available due to no group offloading hooks, try to get the device
+        # from the first parameter or buffer
         parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
         return next(parameters_and_buffers).device
     except StopIteration:
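
A hedged illustration of what the `get_parameter_device` change means in practice: once group offloading is enabled, the reported device is the onload device even while the weights currently sit on the offload device. This assumes `ModelMixin.device` resolves through `get_parameter_device`, and the checkpoint name is illustrative.

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
vae.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
)

# Expected to report the onload device ("cuda") rather than "cpu", because
# get_parameter_device now consults the group offloading hook first.
print(vae.device)
```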
@@ -166,6 +176,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     _no_split_modules = None
     _keep_in_fp32_modules = None
     _skip_layerwise_casting_patterns = None
+    _supports_group_offloading = True
 
     def __init__(self):
         super().__init__()
@@ -437,6 +448,55 @@ def enable_layerwise_casting(
             self, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes, non_blocking
         )
 
+    def enable_group_offload(
+        self,
+        onload_device: torch.device,
+        offload_device: torch.device = torch.device("cpu"),
+        offload_type: str = "block_level",
+        num_blocks_per_group: Optional[int] = None,
+        non_blocking: bool = False,
+        use_stream: bool = False,
+    ) -> None:
+        r"""
+        Activates group offloading for the current model.
+
+        See [`~hooks.group_offloading.apply_group_offloading`] for more information.
+
+        Example:
+
+        ```python
+        >>> from diffusers import CogVideoXTransformer3DModel
+
+        >>> transformer = CogVideoXTransformer3DModel.from_pretrained(
+        ...     "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
+        ... )
+
+        >>> transformer.enable_group_offload(
+        ...     onload_device=torch.device("cuda"),
+        ...     offload_device=torch.device("cpu"),
+        ...     offload_type="leaf_level",
+        ...     use_stream=True,
+        ... )
+        ```
+        """
+        if getattr(self, "enable_tiling", None) is not None and getattr(self, "use_tiling", False) and use_stream:
+            msg = (
+                "Applying group offloading on autoencoders, with CUDA streams, may not work as expected if the first "
+                "forward pass is executed with tiling enabled. Please make sure to either:\n"
+                "1. Run a forward pass with small input shapes.\n"
+                "2. Or, run a forward pass with tiling disabled (can still use small dummy inputs)."
+            )
+            logger.warning(msg)
+        if not self._supports_group_offloading:
+            raise ValueError(
+                f"{self.__class__.__name__} does not support group offloading. Please make sure to set the boolean attribute "
+                f"`_supports_group_offloading` to `True` in the class definition. If you believe this is a mistake, please "
+                f"open an issue at https://github.com/huggingface/diffusers/issues."
+            )
+        apply_group_offloading(
+            self, onload_device, offload_device, offload_type, num_blocks_per_group, non_blocking, use_stream
+        )
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
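
The tiling warning above recommends a warm-up forward pass before combining autoencoder tiling with stream-based group offloading. A hedged sketch of that workflow follows; the checkpoint name, the dummy latent shape, and the availability of `decode`/`enable_tiling` on this particular VAE are assumptions.

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
vae.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
)

# Warm-up pass with tiling disabled and a small dummy latent, per the warning's recommendation.
with torch.no_grad():
    _ = vae.decode(torch.randn(1, 4, 32, 32, device="cuda", dtype=torch.float16)).sample

# Only enable tiling afterwards, for the real workload.
vae.enable_tiling()
```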
@@ -1170,6 +1230,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
     # Adapted from `transformers`.
     @wraps(torch.nn.Module.cuda)
     def cuda(self, *args, **kwargs):
+        from ..hooks.group_offloading import _is_group_offload_enabled
+
         # Checks if the model has been loaded in 4-bit or 8-bit with BNB
         if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
             if getattr(self, "is_loaded_in_8bit", False):
@@ -1182,13 +1244,34 @@ def cuda(self, *args, **kwargs):
                     "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
                     f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
                 )
+
+        # Checks if group offloading is enabled
+        if _is_group_offload_enabled(self):
+            logger.warning(
+                f"The module '{self.__class__.__name__}' is group offloaded and moving it using `.cuda()` is not supported."
+            )
+            return self
+
         return super().cuda(*args, **kwargs)
 
     # Adapted from `transformers`.
     @wraps(torch.nn.Module.to)
     def to(self, *args, **kwargs):
+        from ..hooks.group_offloading import _is_group_offload_enabled
+
+        device_arg_or_kwarg_present = any(isinstance(arg, torch.device) for arg in args) or "device" in kwargs
         dtype_present_in_args = "dtype" in kwargs
 
+        # Try converting arguments to torch.device in case they are passed as strings
+        for arg in args:
+            if not isinstance(arg, str):
+                continue
+            try:
+                torch.device(arg)
+                device_arg_or_kwarg_present = True
+            except RuntimeError:
+                pass
+
         if not dtype_present_in_args:
             for arg in args:
                 if isinstance(arg, torch.dtype):
@@ -1213,6 +1296,13 @@ def to(self, *args, **kwargs):
                     "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
                     f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
                 )
+
+        if _is_group_offload_enabled(self) and device_arg_or_kwarg_present:
+            logger.warning(
+                f"The module '{self.__class__.__name__}' is group offloaded and moving it using `.to()` is not supported."
+            )
+            return self
+
         return super().to(*args, **kwargs)
 
     # Taken from `transformers`.
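
A hedged sketch of the resulting `.to()` semantics for a group-offloaded model, as implied by the guard above: a device move is refused with a warning and the model is returned unchanged, while a dtype-only call has no device argument and therefore still falls through to `torch.nn.Module.to` (whether such a cast is advisable while offloading is active is not established by this commit). The checkpoint name is illustrative.

```python
import torch
from diffusers import AutoencoderKL

model = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float32)
model.enable_group_offload(onload_device=torch.device("cuda"), offload_type="leaf_level")

moved = model.to("cuda")  # logs a warning and returns the model unchanged
assert moved is model

model.to(dtype=torch.float16)  # no device argument, so the call proceeds as usual
```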
