Skip to content

Commit a7461b7

Browse files
authored
Merge branch 'main' into patch-1
2 parents 5836b22 + 8d6f6d6 commit a7461b7

File tree

77 files changed

+885
-23
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

77 files changed

+885
-23
lines changed

.github/workflows/nightly_tests.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ jobs:
265265
266266
- name: Run PyTorch CUDA tests
267267
env:
268-
HF_TOKEN: ${{ secrets.HF_TOKEN }}
268+
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
269269
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
270270
CUBLAS_WORKSPACE_CONFIG: :16:8
271271
run: |
@@ -505,7 +505,7 @@ jobs:
505505
# shell: arch -arch arm64 bash {0}
506506
# env:
507507
# HF_HOME: /System/Volumes/Data/mnt/cache
508-
# HF_TOKEN: ${{ secrets.HF_TOKEN }}
508+
# HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
509509
# run: |
510510
# ${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps \
511511
# --report-log=tests_torch_mps.log \
@@ -561,7 +561,7 @@ jobs:
561561
# shell: arch -arch arm64 bash {0}
562562
# env:
563563
# HF_HOME: /System/Volumes/Data/mnt/cache
564-
# HF_TOKEN: ${{ secrets.HF_TOKEN }}
564+
# HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
565565
# run: |
566566
# ${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps \
567567
# --report-log=tests_torch_mps.log \

.github/workflows/push_tests.yml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ jobs:
187187
188188
- name: Run Flax TPU tests
189189
env:
190-
HF_TOKEN: ${{ secrets.HF_TOKEN }}
190+
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
191191
run: |
192192
python -m pytest -n 0 \
193193
-s -v -k "Flax" \
@@ -235,7 +235,7 @@ jobs:
235235
236236
- name: Run ONNXRuntime CUDA tests
237237
env:
238-
HF_TOKEN: ${{ secrets.HF_TOKEN }}
238+
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
239239
run: |
240240
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
241241
-s -v -k "Onnx" \
@@ -283,7 +283,7 @@ jobs:
283283
python utils/print_env.py
284284
- name: Run example tests on GPU
285285
env:
286-
HF_TOKEN: ${{ secrets.HF_TOKEN }}
286+
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
287287
RUN_COMPILE: yes
288288
run: |
289289
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
@@ -326,7 +326,7 @@ jobs:
326326
python utils/print_env.py
327327
- name: Run example tests on GPU
328328
env:
329-
HF_TOKEN: ${{ secrets.HF_TOKEN }}
329+
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
330330
run: |
331331
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
332332
- name: Failure short reports
@@ -372,7 +372,7 @@ jobs:
372372
373373
- name: Run example tests on GPU
374374
env:
375-
HF_TOKEN: ${{ secrets.HF_TOKEN }}
375+
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
376376
run: |
377377
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
378378
python -m uv pip install timm

.github/workflows/release_tests_fast.yml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ jobs:
8181
python utils/print_env.py
8282
- name: Slow PyTorch CUDA checkpoint tests on Ubuntu
8383
env:
84-
HF_TOKEN: ${{ secrets.HF_TOKEN }}
84+
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
8585
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
8686
CUBLAS_WORKSPACE_CONFIG: :16:8
8787
run: |
@@ -135,7 +135,7 @@ jobs:
135135
136136
- name: Run PyTorch CUDA tests
137137
env:
138-
HF_TOKEN: ${{ secrets.HF_TOKEN }}
138+
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
139139
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
140140
CUBLAS_WORKSPACE_CONFIG: :16:8
141141
run: |
@@ -186,7 +186,7 @@ jobs:
186186
187187
- name: Run PyTorch CUDA tests
188188
env:
189-
HF_TOKEN: ${{ secrets.HF_TOKEN }}
189+
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
190190
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
191191
CUBLAS_WORKSPACE_CONFIG: :16:8
192192
run: |
@@ -241,7 +241,7 @@ jobs:
241241
242242
- name: Run slow Flax TPU tests
243243
env:
244-
HF_TOKEN: ${{ secrets.HF_TOKEN }}
244+
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
245245
run: |
246246
python -m pytest -n 0 \
247247
-s -v -k "Flax" \
@@ -289,7 +289,7 @@ jobs:
289289
290290
- name: Run slow ONNXRuntime CUDA tests
291291
env:
292-
HF_TOKEN: ${{ secrets.HF_TOKEN }}
292+
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
293293
run: |
294294
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
295295
-s -v -k "Onnx" \
@@ -337,7 +337,7 @@ jobs:
337337
python utils/print_env.py
338338
- name: Run example tests on GPU
339339
env:
340-
HF_TOKEN: ${{ secrets.HF_TOKEN }}
340+
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
341341
RUN_COMPILE: yes
342342
run: |
343343
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
@@ -380,7 +380,7 @@ jobs:
380380
python utils/print_env.py
381381
- name: Run example tests on GPU
382382
env:
383-
HF_TOKEN: ${{ secrets.HF_TOKEN }}
383+
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
384384
run: |
385385
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
386386
- name: Failure short reports
@@ -426,7 +426,7 @@ jobs:
426426
427427
- name: Run example tests on GPU
428428
env:
429-
HF_TOKEN: ${{ secrets.HF_TOKEN }}
429+
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
430430
run: |
431431
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
432432
python -m uv pip install timm

docs/source/en/api/utilities.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,7 @@ Utility and helper functions for working with 🤗 Diffusers.
4141
## randn_tensor
4242

4343
[[autodoc]] utils.torch_utils.randn_tensor
44+
45+
## apply_layerwise_casting
46+
47+
[[autodoc]] hooks.layerwise_casting.apply_layerwise_casting

docs/source/en/optimization/memory.md

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,43 @@ In order to properly offload models after they're called, it is required to run
158158

159159
</Tip>
160160

161+
## FP8 layerwise weight-casting
162+
163+
PyTorch supports `torch.float8_e4m3fn` and `torch.float8_e5m2` as weight storage dtypes, but they can't be used for computation in many different tensor operations due to unimplemented kernel support. However, you can use these dtypes to store model weights in fp8 precision and upcast them on-the-fly when the layers are used in the forward pass. This is known as layerwise weight-casting.
164+
165+
Typically, inference on most models is done with `torch.float16` or `torch.bfloat16` weight/computation precision. Layerwise weight-casting cuts down the memory footprint of the model weights by approximately half.
166+
167+
```python
168+
import torch
169+
from diffusers import CogVideoXPipeline, CogVideoXTransformer3DModel
170+
from diffusers.utils import export_to_video
171+
172+
model_id = "THUDM/CogVideoX-5b"
173+
174+
# Load the model in bfloat16 and enable layerwise casting
175+
transformer = CogVideoXTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
176+
transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
177+
178+
# Load the pipeline
179+
pipe = CogVideoXPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
180+
pipe.to("cuda")
181+
182+
prompt = (
183+
"A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
184+
"The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
185+
"pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
186+
"casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
187+
"The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
188+
"atmosphere of this unique musical performance."
189+
)
190+
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
191+
export_to_video(video, "output.mp4", fps=8)
192+
```
193+
194+
In the above example, layerwise casting is enabled on the transformer component of the pipeline. By default, certain layers are skipped from the FP8 weight casting because it can lead to significant degradation of generation quality. The normalization and modulation related weight parameters are also skipped by default.
195+
196+
However, you gain more control and flexibility by directly utilizing the [`~hooks.layerwise_casting.apply_layerwise_casting`] function instead of [`~ModelMixin.enable_layerwise_casting`].
197+
161198
## Channels-last memory format
162199

163200
The channels-last memory format is an alternative way of ordering NCHW tensors in memory to preserve dimension ordering. Channels-last tensors are ordered in such a way that the channels become the densest dimension (storing images pixel-per-pixel). Since not all operators currently support the channels-last format, it may result in worst performance but you should still try and see if it works for your model.

src/diffusers/hooks/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from ..utils import is_torch_available
2+
3+
4+
if is_torch_available():
5+
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook

src/diffusers/hooks/hooks.py

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
# Copyright 2024 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import functools
16+
from typing import Any, Dict, Optional, Tuple
17+
18+
import torch
19+
20+
from ..utils.logging import get_logger
21+
22+
23+
logger = get_logger(__name__) # pylint: disable=invalid-name
24+
25+
26+
class ModelHook:
27+
r"""
28+
A hook that contains callbacks to be executed just before and after the forward method of a model.
29+
"""
30+
31+
_is_stateful = False
32+
33+
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
34+
r"""
35+
Hook that is executed when a model is initialized.
36+
37+
Args:
38+
module (`torch.nn.Module`):
39+
The module attached to this hook.
40+
"""
41+
return module
42+
43+
def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
44+
r"""
45+
Hook that is executed when a model is deinitalized.
46+
47+
Args:
48+
module (`torch.nn.Module`):
49+
The module attached to this hook.
50+
"""
51+
module.forward = module._old_forward
52+
del module._old_forward
53+
return module
54+
55+
def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
56+
r"""
57+
Hook that is executed just before the forward method of the model.
58+
59+
Args:
60+
module (`torch.nn.Module`):
61+
The module whose forward pass will be executed just after this event.
62+
args (`Tuple[Any]`):
63+
The positional arguments passed to the module.
64+
kwargs (`Dict[Str, Any]`):
65+
The keyword arguments passed to the module.
66+
Returns:
67+
`Tuple[Tuple[Any], Dict[Str, Any]]`:
68+
A tuple with the treated `args` and `kwargs`.
69+
"""
70+
return args, kwargs
71+
72+
def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
73+
r"""
74+
Hook that is executed just after the forward method of the model.
75+
76+
Args:
77+
module (`torch.nn.Module`):
78+
The module whose forward pass been executed just before this event.
79+
output (`Any`):
80+
The output of the module.
81+
Returns:
82+
`Any`: The processed `output`.
83+
"""
84+
return output
85+
86+
def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
87+
r"""
88+
Hook that is executed when the hook is detached from a module.
89+
90+
Args:
91+
module (`torch.nn.Module`):
92+
The module detached from this hook.
93+
"""
94+
return module
95+
96+
def reset_state(self, module: torch.nn.Module):
97+
if self._is_stateful:
98+
raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
99+
return module
100+
101+
102+
class HookRegistry:
103+
def __init__(self, module_ref: torch.nn.Module) -> None:
104+
super().__init__()
105+
106+
self.hooks: Dict[str, ModelHook] = {}
107+
108+
self._module_ref = module_ref
109+
self._hook_order = []
110+
111+
def register_hook(self, hook: ModelHook, name: str) -> None:
112+
if name in self.hooks.keys():
113+
logger.warning(f"Hook with name {name} already exists, replacing it.")
114+
115+
if hasattr(self._module_ref, "_old_forward"):
116+
old_forward = self._module_ref._old_forward
117+
else:
118+
old_forward = self._module_ref.forward
119+
self._module_ref._old_forward = self._module_ref.forward
120+
121+
self._module_ref = hook.initialize_hook(self._module_ref)
122+
123+
if hasattr(hook, "new_forward"):
124+
rewritten_forward = hook.new_forward
125+
126+
def new_forward(module, *args, **kwargs):
127+
args, kwargs = hook.pre_forward(module, *args, **kwargs)
128+
output = rewritten_forward(module, *args, **kwargs)
129+
return hook.post_forward(module, output)
130+
else:
131+
132+
def new_forward(module, *args, **kwargs):
133+
args, kwargs = hook.pre_forward(module, *args, **kwargs)
134+
output = old_forward(*args, **kwargs)
135+
return hook.post_forward(module, output)
136+
137+
self._module_ref.forward = functools.update_wrapper(
138+
functools.partial(new_forward, self._module_ref), old_forward
139+
)
140+
141+
self.hooks[name] = hook
142+
self._hook_order.append(name)
143+
144+
def get_hook(self, name: str) -> Optional[ModelHook]:
145+
if name not in self.hooks.keys():
146+
return None
147+
return self.hooks[name]
148+
149+
def remove_hook(self, name: str, recurse: bool = True) -> None:
150+
if name in self.hooks.keys():
151+
hook = self.hooks[name]
152+
self._module_ref = hook.deinitalize_hook(self._module_ref)
153+
del self.hooks[name]
154+
self._hook_order.remove(name)
155+
156+
if recurse:
157+
for module_name, module in self._module_ref.named_modules():
158+
if module_name == "":
159+
continue
160+
if hasattr(module, "_diffusers_hook"):
161+
module._diffusers_hook.remove_hook(name, recurse=False)
162+
163+
def reset_stateful_hooks(self, recurse: bool = True) -> None:
164+
for hook_name in self._hook_order:
165+
hook = self.hooks[hook_name]
166+
if hook._is_stateful:
167+
hook.reset_state(self._module_ref)
168+
169+
if recurse:
170+
for module_name, module in self._module_ref.named_modules():
171+
if module_name == "":
172+
continue
173+
if hasattr(module, "_diffusers_hook"):
174+
module._diffusers_hook.reset_stateful_hooks(recurse=False)
175+
176+
@classmethod
177+
def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry":
178+
if not hasattr(module, "_diffusers_hook"):
179+
module._diffusers_hook = cls(module)
180+
return module._diffusers_hook
181+
182+
def __repr__(self) -> str:
183+
hook_repr = ""
184+
for i, hook_name in enumerate(self._hook_order):
185+
hook_repr += f" ({i}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})"
186+
if i < len(self._hook_order) - 1:
187+
hook_repr += "\n"
188+
return f"HookRegistry(\n{hook_repr}\n)"

0 commit comments

Comments
 (0)