Commit 3d84b9e

non_blocking
1 parent 2663026 commit 3d84b9e

2 files changed: +23 -11 lines changed

src/diffusers/hooks/layerwise_upcasting.py

Lines changed: 13 additions & 7 deletions
@@ -44,20 +44,21 @@ class LayerwiseUpcastingHook(ModelHook):
 
     _is_stateful = False
 
-    def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype) -> None:
+    def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool) -> None:
         self.storage_dtype = storage_dtype
         self.compute_dtype = compute_dtype
+        self.non_blocking = non_blocking
 
     def initialize_hook(self, module: torch.nn.Module):
-        module.to(dtype=self.storage_dtype)
+        module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
         return module
 
     def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
-        module.to(dtype=self.compute_dtype)
+        module.to(dtype=self.compute_dtype, non_blocking=self.non_blocking)
         return args, kwargs
 
     def post_forward(self, module: torch.nn.Module, output):
-        module.to(dtype=self.storage_dtype)
+        module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
         return output
 
 
@@ -67,6 +68,7 @@ def apply_layerwise_upcasting(
     compute_dtype: torch.dtype,
     skip_modules_pattern: List[str] = _DEFAULT_SKIP_MODULES_PATTERN,
     skip_modules_classes: List[Type[torch.nn.Module]] = [],
+    non_blocking: bool = False,
 ) -> torch.nn.Module:
     r"""
     Applies layerwise upcasting to a given module. The module expected here is a Diffusers ModelMixin but it can be any
@@ -84,6 +86,8 @@ def apply_layerwise_upcasting(
             A list of patterns to match the names of the modules to skip during the layerwise upcasting process.
         skip_modules_classes (`List[Type[torch.nn.Module]]`, defaults to `[]`):
             A list of module classes to skip during the layerwise upcasting process.
+        non_blocking (`bool`, defaults to `False`):
+            If `True`, the weight casting operations are non-blocking.
     """
     for name, submodule in module.named_modules():
         if (
@@ -95,12 +99,12 @@ def apply_layerwise_upcasting(
             logger.debug(f'Skipping layerwise upcasting for layer "{name}"')
             continue
         logger.debug(f'Applying layerwise upcasting to layer "{name}"')
-        apply_layerwise_upcasting_hook(submodule, storage_dtype, compute_dtype)
+        apply_layerwise_upcasting_hook(submodule, storage_dtype, compute_dtype, non_blocking)
     return module
 
 
 def apply_layerwise_upcasting_hook(
-    module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype
+    module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool
 ) -> torch.nn.Module:
     r"""
     Applies a `LayerwiseUpcastingHook` to a given module.
@@ -112,11 +116,13 @@ def apply_layerwise_upcasting_hook(
             The dtype to cast the module to before the forward pass.
         compute_dtype (`torch.dtype`):
             The dtype to cast the module to during the forward pass.
+        non_blocking (`bool`):
+            If `True`, the weight casting operations are non-blocking.
 
     Returns:
         `torch.nn.Module`:
             The same module, with the hook attached (the module is modified in place).
     """
     registry = HookRegistry.check_if_exists_or_initialize(module)
-    hook = LayerwiseUpcastingHook(storage_dtype, compute_dtype)
+    hook = LayerwiseUpcastingHook(storage_dtype, compute_dtype, non_blocking)
     registry.register_hook(hook, "layerwise_upcasting")
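
For orientation, a minimal usage sketch of the free function extended above. This is an illustration, not part of the commit: the toy torch.nn.Sequential model, the float8/bfloat16 dtype pair, and the CUDA device are assumed placeholders, and the import path simply mirrors the file touched in this commit.

import torch

# Import path assumed from the file modified above; it may move in later releases.
from diffusers.hooks.layerwise_upcasting import apply_layerwise_upcasting

# Toy stand-in for a Diffusers ModelMixin; per the docstring any nn.Module should work.
model = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.GELU(),
    torch.nn.Linear(64, 64),
).to("cuda")

# Store weights in float8 between forward passes and upcast to bfloat16 on the fly.
# With non_blocking=True, each module.to(dtype=..., non_blocking=True) cast is issued
# without forcing the host to wait, instead of the blocking default.
apply_layerwise_upcasting(
    model,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
    non_blocking=True,
)

with torch.no_grad():
    output = model(torch.randn(1, 64, device="cuda", dtype=torch.bfloat16))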

src/diffusers/models/modeling_utils.py

Lines changed: 10 additions & 4 deletions
@@ -323,6 +323,7 @@ def enable_layerwise_upcasting(
         compute_dtype: Optional[torch.dtype] = None,
         skip_modules_pattern: Optional[List[str]] = None,
         skip_modules_classes: Optional[List[Type[torch.nn.Module]]] = None,
+        non_blocking: bool = False,
     ) -> None:
         r"""
         Activates layerwise upcasting for the current model.
@@ -361,9 +362,12 @@ def enable_layerwise_upcasting(
                 The dtype to which the model should be cast for storage.
             compute_dtype (`torch.dtype`):
                 The dtype to which the model weights should be cast during the forward pass.
-            granularity (`LayerwiseUpcastingGranularity`, defaults to "pytorch_layer"):
-                The granularity of the layerwise upcasting process. Read the documentation of
-                [`~LayerwiseUpcastingGranularity`] for more information.
+            skip_modules_pattern (`List[str]`, *optional*):
+                A list of patterns to match the names of the modules to skip during the layerwise upcasting process.
+            skip_modules_classes (`List[Type[torch.nn.Module]]`, *optional*):
+                A list of module classes to skip during the layerwise upcasting process.
+            non_blocking (`bool`, *optional*, defaults to `False`):
+                If `True`, the weight casting operations are non-blocking.
         """
 
         if skip_modules_pattern is None:
@@ -389,7 +393,9 @@ def enable_layerwise_upcasting(
             logger.info("`compute_dtype` not provided when enabling layerwise upcasting. Using dtype of the model.")
             compute_dtype = self.dtype
 
-        apply_layerwise_upcasting(self, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes)
+        apply_layerwise_upcasting(
+            self, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes, non_blocking
+        )
 
     def save_pretrained(
         self,
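
And a corresponding sketch of the user-facing entry point extended here. AutoencoderKL and the "stabilityai/sd-vae-ft-mse" checkpoint are only illustrative choices of a ModelMixin subclass; the method name and keyword arguments follow the diff above as of this commit.

import torch
from diffusers import AutoencoderKL  # any ModelMixin subclass; chosen only as an example

model = AutoencoderKL.from_pretrained(
    "stabilityai/sd-vae-ft-mse", torch_dtype=torch.bfloat16
).to("cuda")

# Keep weights in float8 between forward passes and upcast each layer to bfloat16
# just in time for its forward. non_blocking defaults to False, so existing callers
# keep the blocking casts; opting in makes the per-layer dtype casts non-blocking.
model.enable_layerwise_upcasting(
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
    non_blocking=True,
)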
