Commit a9364bd

layerwise_upcasting -> layerwise_casting

1 parent: 1d306b8

File tree: 10 files changed (+65, -65 lines)

docs/source/en/api/utilities.md

Lines changed: 2 additions & 2 deletions
@@ -42,6 +42,6 @@ Utility and helper functions for working with 🤗 Diffusers.
 
 [[autodoc]] utils.torch_utils.randn_tensor
 
-## apply_layerwise_upcasting
+## apply_layerwise_casting
 
-[[autodoc]] hooks.layerwise_upcasting.apply_layerwise_upcasting
+[[autodoc]] hooks.layerwise_casting.apply_layerwise_casting

docs/source/en/optimization/memory.md

Lines changed: 4 additions & 4 deletions
@@ -171,9 +171,9 @@ from diffusers.utils import export_to_video
 
 model_id = "THUDM/CogVideoX-5b"
 
-# Load the model in bfloat16 and enable layerwise upcasting
+# Load the model in bfloat16 and enable layerwise casting
 transformer = CogVideoXTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
-transformer.enable_layerwise_upcasting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
+transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
 
 # Load the pipeline
 pipe = CogVideoXPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
@@ -191,9 +191,9 @@ video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
 export_to_video(video, "output.mp4", fps=8)
 ```
 
-In the above example, layerwise upcasting is enabled on the transformer component of the pipeline. By default, certain layers are skipped from the FP8 weight casting because it can lead to significant degradation of generation quality. The normalization and modulation related weight parameters are also skipped by default.
+In the above example, layerwise casting is enabled on the transformer component of the pipeline. By default, certain layers are skipped from the FP8 weight casting because it can lead to significant degradation of generation quality. The normalization and modulation related weight parameters are also skipped by default.
 
-However, you gain more control and flexibility by directly utilizing the [`~hooks.layerwise_upcasting.apply_layerwise_upcasting`] function instead of [`~ModelMixin.enable_layerwise_upcasting`].
+However, you gain more control and flexibility by directly utilizing the [`~hooks.layerwise_casting.apply_layerwise_casting`] function instead of [`~ModelMixin.enable_layerwise_casting`].
 
 ## Channels-last memory format
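The paragraph changed above points readers at [`~hooks.layerwise_casting.apply_layerwise_casting`] for finer-grained control. Below is a minimal sketch (not part of this commit) of what that direct call can look like with the renamed API; the skip patterns are illustrative placeholders, not the library defaults.

```python
import torch
from diffusers import CogVideoXTransformer3DModel
from diffusers.hooks import apply_layerwise_casting

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Store weights in FP8 between forward passes, but compute in bfloat16.
# Illustrative skip patterns: pick patterns that match the precision-critical
# layers of your model (the defaults already cover normalization/modulation weights).
apply_layerwise_casting(
    transformer,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
    skip_modules_pattern=("norm", "patch_embed"),
)
```

Unlike [`~ModelMixin.enable_layerwise_casting`], only the patterns and classes you pass here are skipped.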

src/diffusers/hooks/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -2,4 +2,4 @@
 
 
 if is_torch_available():
-    from .layerwise_upcasting import apply_layerwise_upcasting, apply_layerwise_upcasting_hook
+    from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook

src/diffusers/hooks/layerwise_upcasting.py renamed to src/diffusers/hooks/layerwise_casting.py

Lines changed: 19 additions & 19 deletions
@@ -35,7 +35,7 @@
 # fmt: on
 
 
-class LayerwiseUpcastingHook(ModelHook):
+class LayerwiseCastingHook(ModelHook):
     r"""
     A hook that casts the weights of a module to a high precision dtype for computation, and to a low precision dtype
     for storage. This process may lead to quality loss in the output, but can significantly reduce the memory
@@ -55,7 +55,7 @@ def initialize_hook(self, module: torch.nn.Module):
 
     def deinitalize_hook(self, module: torch.nn.Module):
         raise NotImplementedError(
-            "LayerwiseUpcastingHook does not support deinitalization. A model once enabled with layerwise upcasting will "
+            "LayerwiseCastingHook does not support deinitalization. A model once enabled with layerwise casting will "
             "have casted its weights to a lower precision dtype for storage. Casting this back to the original dtype "
             "will lead to precision loss, which might have an impact on the model's generation quality. The model should "
             "be re-initialized and loaded in the original dtype."
@@ -70,7 +70,7 @@ def post_forward(self, module: torch.nn.Module, output):
         return output
 
 
-def apply_layerwise_upcasting(
+def apply_layerwise_casting(
     module: torch.nn.Module,
     storage_dtype: torch.dtype,
     compute_dtype: torch.dtype,
@@ -79,7 +79,7 @@ def apply_layerwise_upcasting(
     non_blocking: bool = False,
 ) -> None:
     r"""
-    Applies layerwise upcasting to a given module. The module expected here is a Diffusers ModelMixin but it can be any
+    Applies layerwise casting to a given module. The module expected here is a Diffusers ModelMixin but it can be any
     nn.Module using diffusers layers or pytorch primitives.
 
     Example:
@@ -92,7 +92,7 @@ def apply_layerwise_upcasting(
     ...     model_id, subfolder="transformer", torch_dtype=torch.bfloat16
     ... )
 
-    >>> apply_layerwise_upcasting(
+    >>> apply_layerwise_casting(
     ...     transformer,
     ...     storage_dtype=torch.float8_e4m3fn,
     ...     compute_dtype=torch.bfloat16,
@@ -110,23 +110,23 @@ def apply_layerwise_upcasting(
         compute_dtype (`torch.dtype`):
             The dtype to cast the module to during the forward pass for computation.
         skip_modules_pattern (`Tuple[str, ...]`, defaults to `"default"`):
-            A list of patterns to match the names of the modules to skip during the layerwise upcasting process. If set
+            A list of patterns to match the names of the modules to skip during the layerwise casting process. If set
             to `"default"`, the default patterns are used. If set to `None`, no modules are skipped. If set to `None`
-            alongside `skip_modules_classes` being `None`, the layerwise upcasting is applied directly to the module
+            alongside `skip_modules_classes` being `None`, the layerwise casting is applied directly to the module
             instead of its internal submodules.
         skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, defaults to `None`):
-            A list of module classes to skip during the layerwise upcasting process.
+            A list of module classes to skip during the layerwise casting process.
         non_blocking (`bool`, defaults to `False`):
            If `True`, the weight casting operations are non-blocking.
     """
     if skip_modules_pattern == "default":
         skip_modules_pattern = DEFAULT_SKIP_MODULES_PATTERN
 
     if skip_modules_classes is None and skip_modules_pattern is None:
-        apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking)
+        apply_layerwise_casting_hook(module, storage_dtype, compute_dtype, non_blocking)
         return
 
-    _apply_layerwise_upcasting(
+    _apply_layerwise_casting(
         module,
         storage_dtype,
         compute_dtype,
@@ -136,7 +136,7 @@ def apply_layerwise_upcasting(
     )
 
 
-def _apply_layerwise_upcasting(
+def _apply_layerwise_casting(
     module: torch.nn.Module,
     storage_dtype: torch.dtype,
     compute_dtype: torch.dtype,
@@ -149,17 +149,17 @@ def _apply_layerwise_upcasting(
         skip_modules_pattern is not None and any(re.search(pattern, _prefix) for pattern in skip_modules_pattern)
     )
     if should_skip:
-        logger.debug(f'Skipping layerwise upcasting for layer "{_prefix}"')
+        logger.debug(f'Skipping layerwise casting for layer "{_prefix}"')
         return
 
     if isinstance(module, SUPPORTED_PYTORCH_LAYERS):
-        logger.debug(f'Applying layerwise upcasting to layer "{_prefix}"')
-        apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking)
+        logger.debug(f'Applying layerwise casting to layer "{_prefix}"')
+        apply_layerwise_casting_hook(module, storage_dtype, compute_dtype, non_blocking)
         return
 
     for name, submodule in module.named_children():
         layer_name = f"{_prefix}.{name}" if _prefix else name
-        _apply_layerwise_upcasting(
+        _apply_layerwise_casting(
             submodule,
             storage_dtype,
             compute_dtype,
@@ -170,11 +170,11 @@ def _apply_layerwise_upcasting(
         )
 
 
-def apply_layerwise_upcasting_hook(
+def apply_layerwise_casting_hook(
     module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool
 ) -> None:
     r"""
-    Applies a `LayerwiseUpcastingHook` to a given module.
+    Applies a `LayerwiseCastingHook` to a given module.
 
     Args:
         module (`torch.nn.Module`):
@@ -187,5 +187,5 @@ def apply_layerwise_upcasting_hook(
             If `True`, the weight casting operations are non-blocking.
     """
     registry = HookRegistry.check_if_exists_or_initialize(module)
-    hook = LayerwiseUpcastingHook(storage_dtype, compute_dtype, non_blocking)
-    registry.register_hook(hook, "layerwise_upcasting")
+    hook = LayerwiseCastingHook(storage_dtype, compute_dtype, non_blocking)
+    registry.register_hook(hook, "layerwise_casting")
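For context (not part of the diff): the renamed `apply_layerwise_casting_hook` registers a `LayerwiseCastingHook` on a single module under the `"layerwise_casting"` key. A minimal sketch on a bare linear layer, assuming a PyTorch build with `torch.float8_e4m3fn` support:

```python
import torch
from diffusers.hooks import apply_layerwise_casting_hook

linear = torch.nn.Linear(64, 64).to(torch.bfloat16)

# After this call the hook keeps the weights in FP8 for storage and casts them
# to bfloat16 around each forward pass.
apply_layerwise_casting_hook(
    linear, storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16, non_blocking=False
)

with torch.no_grad():
    out = linear(torch.randn(2, 64, dtype=torch.bfloat16))

print(linear.weight.dtype)  # expected: torch.float8_e4m3fn (storage dtype)
print(out.dtype)            # expected: torch.bfloat16 (compute dtype)
```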

src/diffusers/models/modeling_utils.py

Lines changed: 17 additions & 17 deletions
@@ -32,7 +32,7 @@
 from torch import Tensor, nn
 
 from .. import __version__
-from ..hooks import apply_layerwise_upcasting
+from ..hooks import apply_layerwise_casting
 from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
 from ..quantizers.quantization_config import QuantizationMethod
 from ..utils import (
@@ -104,13 +104,13 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
     """
     Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
     """
-    # 1. Check if we have attached any dtype modifying hooks (eg. layerwise upcasting)
+    # 1. Check if we have attached any dtype modifying hooks (eg. layerwise casting)
     if isinstance(parameter, nn.Module):
         for name, submodule in parameter.named_modules():
             if not hasattr(submodule, "_diffusers_hook"):
                 continue
             registry = submodule._diffusers_hook
-            hook = registry.get_hook("layerwise_upcasting")
+            hook = registry.get_hook("layerwise_casting")
             if hook is not None:
                 return hook.compute_dtype
 
@@ -328,7 +328,7 @@ def disable_xformers_memory_efficient_attention(self) -> None:
         """
         self.set_use_memory_efficient_attention_xformers(False)
 
-    def enable_layerwise_upcasting(
+    def enable_layerwise_casting(
         self,
         storage_dtype: torch.dtype = torch.float8_e4m3fn,
         compute_dtype: Optional[torch.dtype] = None,
@@ -337,9 +337,9 @@ def enable_layerwise_upcasting(
         non_blocking: bool = False,
     ) -> None:
         r"""
-        Activates layerwise upcasting for the current model.
+        Activates layerwise casting for the current model.
 
-        Layerwise upcasting is a technique that casts the model weights to a lower precision dtype for storage but
+        Layerwise casting is a technique that casts the model weights to a lower precision dtype for storage but
         upcasts them on-the-fly to a higher precision dtype for computation. This process can significantly reduce the
         memory footprint from model weights, but may lead to some quality degradation in the outputs. Most degradations
         are negligible, mostly stemming from weight casting in normalization and modulation layers.
@@ -348,10 +348,10 @@ def enable_layerwise_upcasting(
         embedding, positional embedding and normalization layers. This is because these layers are most likely
         precision-critical for quality. If you wish to change this behavior, you can set the
         `_skip_layerwise_casting_patterns` attribute to `None`, or call
-        [`~hooks.layerwise_upcasting.apply_layerwise_upcasting`] with custom arguments.
+        [`~hooks.layerwise_casting.apply_layerwise_casting`] with custom arguments.
 
         Example:
-            Using [`~models.ModelMixin.enable_layerwise_upcasting`]:
+            Using [`~models.ModelMixin.enable_layerwise_casting`]:
 
             ```python
             >>> from diffusers import CogVideoXTransformer3DModel
@@ -360,8 +360,8 @@ def enable_layerwise_upcasting(
             ...     "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
             ... )
 
-            >>> # Enable layerwise upcasting via the model, which ignores certain modules by default
-            >>> transformer.enable_layerwise_upcasting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
+            >>> # Enable layerwise casting via the model, which ignores certain modules by default
+            >>> transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
             ```
 
         Args:
@@ -370,18 +370,18 @@ def enable_layerwise_upcasting(
             compute_dtype (`torch.dtype`):
                 The dtype to which the model weights should be cast during the forward pass.
             skip_modules_pattern (`Tuple[str, ...]`, *optional*):
-                A list of patterns to match the names of the modules to skip during the layerwise upcasting process. If
+                A list of patterns to match the names of the modules to skip during the layerwise casting process. If
                 set to `None`, default skip patterns are used to ignore certain internal layers of modules and PEFT
                 layers.
             skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, *optional*):
-                A list of module classes to skip during the layerwise upcasting process.
+                A list of module classes to skip during the layerwise casting process.
             non_blocking (`bool`, *optional*, defaults to `False`):
                 If `True`, the weight casting operations are non-blocking.
         """
 
         user_provided_patterns = True
         if skip_modules_pattern is None:
-            from ..hooks.layerwise_upcasting import DEFAULT_SKIP_MODULES_PATTERN
+            from ..hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN
 
             skip_modules_pattern = DEFAULT_SKIP_MODULES_PATTERN
             user_provided_patterns = False
@@ -393,8 +393,8 @@ def enable_layerwise_upcasting(
 
         if is_peft_available() and not user_provided_patterns:
             # By default, we want to skip all peft layers because they have a very low memory footprint.
-            # If users want to apply layerwise upcasting on peft layers as well, they can utilize the
-            # `~diffusers.hooks.layerwise_upcasting.apply_layerwise_upcasting` function which provides
+            # If users want to apply layerwise casting on peft layers as well, they can utilize the
+            # `~diffusers.hooks.layerwise_casting.apply_layerwise_casting` function which provides
             # them with more flexibility and control.
 
             from peft.tuners.loha.layer import LoHaLayer
@@ -405,10 +405,10 @@ def enable_layerwise_upcasting(
                 skip_modules_pattern += tuple(layer.adapter_layer_names)
 
         if compute_dtype is None:
-            logger.info("`compute_dtype` not provided when enabling layerwise upcasting. Using dtype of the model.")
+            logger.info("`compute_dtype` not provided when enabling layerwise casting. Using dtype of the model.")
            compute_dtype = self.dtype
 
-        apply_layerwise_upcasting(
+        apply_layerwise_casting(
            self, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes, non_blocking
        )
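For context (not part of the diff): the `get_parameter_dtype` and `enable_layerwise_casting` hunks above interact: the hook registered under the `"layerwise_casting"` key is what lets the model keep reporting its compute dtype even though the weights are stored in FP8. A rough sketch of that behaviour with the renamed API:

```python
import torch
from diffusers import CogVideoXTransformer3DModel

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Normalization/embedding layers and PEFT layers are skipped by default.
transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)

# get_parameter_dtype finds the dtype-modifying hook and returns its compute dtype.
print(transformer.dtype)  # expected: torch.bfloat16, not torch.float8_e4m3fn
```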

tests/lora/utils.py

Lines changed: 3 additions & 3 deletions
@@ -2100,8 +2100,8 @@ def test_correct_lora_configs_with_different_ranks(self):
         self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
         self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
 
-    def test_layerwise_upcasting_inference_denoiser(self):
-        from diffusers.hooks.layerwise_upcasting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
+    def test_layerwise_casting_inference_denoiser(self):
+        from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
 
         def check_linear_dtype(module, storage_dtype, compute_dtype):
             patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
@@ -2142,7 +2142,7 @@ def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32):
             )
 
             if storage_dtype is not None:
-                denoiser.enable_layerwise_upcasting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
+                denoiser.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
                 check_linear_dtype(denoiser, storage_dtype, compute_dtype)
 
             return pipe
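For context (not part of the diff): the renamed test uses a `check_linear_dtype` helper that is only partially visible in this hunk. The sketch below shows the kind of check it performs, under the assumption that supported layers should hold their weights in the storage dtype while layers matching the default skip patterns keep the compute dtype; the helper name and exact assertions are illustrative, not the test's actual code.

```python
import re

import torch
from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS


def assert_storage_dtypes(module: torch.nn.Module, storage_dtype, compute_dtype):
    # Illustrative helper, not the test's implementation.
    for name, submodule in module.named_modules():
        if not isinstance(submodule, SUPPORTED_PYTORCH_LAYERS):
            continue
        if any(re.search(pattern, name) for pattern in DEFAULT_SKIP_MODULES_PATTERN):
            # Layers matching the skip patterns are left in the original (compute) dtype.
            assert submodule.weight.dtype == compute_dtype
        else:
            # Casted layers should be stored in the low-precision storage dtype.
            assert submodule.weight.dtype == storage_dtype
```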

tests/models/autoencoders/test_models_autoencoder_oobleck.py

Lines changed: 2 additions & 2 deletions
@@ -120,7 +120,7 @@ def test_set_attn_processor_for_determinism(self):
         "1. Make sure `nn::Module::to` works with `torch.nn.utils.weight_norm` wrapped convolution layer.\n"
         "2. Unskip this test."
     )
-    def test_layerwise_upcasting_inference(self):
+    def test_layerwise_casting_inference(self):
         pass
 
     @unittest.skip(
@@ -129,7 +129,7 @@ def test_layerwise_upcasting_inference(self):
         "1. Make sure `nn::Module::to` works with `torch.nn.utils.weight_norm` wrapped convolution layer.\n"
         "2. Unskip this test."
     )
-    def test_layerwise_upcasting_memory(self):
+    def test_layerwise_casting_memory(self):
         pass
 
 

tests/models/autoencoders/test_models_autoencoder_tiny.py

Lines changed: 2 additions & 2 deletions
@@ -178,15 +178,15 @@ def test_effective_gradient_checkpointing(self):
         "1. Change the forward pass to be dtype agnostic.\n"
         "2. Unskip this test."
     )
-    def test_layerwise_upcasting_inference(self):
+    def test_layerwise_casting_inference(self):
         pass
 
     @unittest.skip(
         "The forward pass of AutoencoderTiny creates a torch.float32 tensor. This causes inference in compute_dtype=torch.bfloat16 to fail. To fix:\n"
         "1. Change the forward pass to be dtype agnostic.\n"
         "2. Unskip this test."
     )
-    def test_layerwise_upcasting_memory(self):
+    def test_layerwise_casting_memory(self):
         pass
 
 
