@@ -44,20 +44,21 @@ class LayerwiseUpcastingHook(ModelHook):
 
     _is_stateful = False
 
-    def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype) -> None:
+    def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool) -> None:
         self.storage_dtype = storage_dtype
         self.compute_dtype = compute_dtype
+        self.non_blocking = non_blocking
 
     def initialize_hook(self, module: torch.nn.Module):
-        module.to(dtype=self.storage_dtype)
+        module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
         return module
 
     def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
-        module.to(dtype=self.compute_dtype)
+        module.to(dtype=self.compute_dtype, non_blocking=self.non_blocking)
         return args, kwargs
 
     def post_forward(self, module: torch.nn.Module, output):
-        module.to(dtype=self.storage_dtype)
+        module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
         return output
 
 
@@ -67,6 +68,7 @@ def apply_layerwise_upcasting(
     compute_dtype: torch.dtype,
     skip_modules_pattern: List[str] = _DEFAULT_SKIP_MODULES_PATTERN,
     skip_modules_classes: List[Type[torch.nn.Module]] = [],
+    non_blocking: bool = False,
 ) -> torch.nn.Module:
     r"""
     Applies layerwise upcasting to a given module. The module expected here is a Diffusers ModelMixin but it can be any
@@ -84,6 +86,8 @@ def apply_layerwise_upcasting(
             A list of patterns to match the names of the modules to skip during the layerwise upcasting process.
         skip_modules_classes (`List[Type[torch.nn.Module]]`, defaults to `[]`):
             A list of module classes to skip during the layerwise upcasting process.
+        non_blocking (`bool`, defaults to `False`):
+            If `True`, the weight casting operations are non-blocking.
     """
     for name, submodule in module.named_modules():
         if (
@@ -95,12 +99,12 @@ def apply_layerwise_upcasting(
             logger.debug(f'Skipping layerwise upcasting for layer "{name}"')
             continue
         logger.debug(f'Applying layerwise upcasting to layer "{name}"')
-        apply_layerwise_upcasting_hook(submodule, storage_dtype, compute_dtype)
+        apply_layerwise_upcasting_hook(submodule, storage_dtype, compute_dtype, non_blocking)
     return module
 
 
 def apply_layerwise_upcasting_hook(
-    module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype
+    module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool
 ) -> torch.nn.Module:
     r"""
     Applies a `LayerwiseUpcastingHook` to a given module.
@@ -112,11 +116,13 @@ def apply_layerwise_upcasting_hook(
             The dtype to cast the module to before the forward pass.
         compute_dtype (`torch.dtype`):
             The dtype to cast the module to during the forward pass.
+        non_blocking (`bool`):
+            If `True`, the weight casting operations are non-blocking.
 
     Returns:
         `torch.nn.Module`:
             The same module, with the hook attached (the module is modified in place).
     """
     registry = HookRegistry.check_if_exists_or_initialize(module)
-    hook = LayerwiseUpcastingHook(storage_dtype, compute_dtype)
+    hook = LayerwiseUpcastingHook(storage_dtype, compute_dtype, non_blocking)
     registry.register_hook(hook, "layerwise_upcasting")
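For context, a minimal usage sketch of the new `non_blocking` argument. It assumes the `apply_layerwise_upcasting` signature shown in this diff and an import from `diffusers.hooks`; the toy model, dtypes, and import path are illustrative assumptions, not part of this change.

```python
import torch

from diffusers.hooks import apply_layerwise_upcasting  # assumed import path

# Toy stand-in for a Diffusers ModelMixin; any torch.nn.Module works.
model = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.SiLU(),
    torch.nn.Linear(64, 64),
).to("cuda")

# Store weights in float8 and upcast them to bfloat16 only around each forward
# pass. With non_blocking=True, the module.to(...) casts performed by the hook
# are issued as non-blocking copies, so they may overlap with other GPU work
# instead of forcing a synchronization.
apply_layerwise_upcasting(
    model,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
    non_blocking=True,  # argument introduced by this change
)

with torch.no_grad():
    out = model(torch.randn(1, 64, device="cuda", dtype=torch.bfloat16))
```

As with any non-blocking transfer in PyTorch, the benefit depends on the copies involved; the default `non_blocking=False` keeps the previous synchronous behavior.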