 # limitations under the License.
 
 import functools
-import re
-from enum import Enum
-from typing import Any, Dict, List, Tuple, Type
+from typing import Any, Dict, Tuple
 
 import torch
 
 from ..utils import get_logger
-from .attention import FeedForward, LuminaFeedForward
-from .embeddings import (
-    AttentionPooling,
-    CogVideoXPatchEmbed,
-    CogView3PlusPatchEmbed,
-    GLIGENTextBoundingboxProjection,
-    HunyuanDiTAttentionPool,
-    LuminaPatchEmbed,
-    MochiAttentionPool,
-    PixArtAlphaTextProjection,
-    TimestepEmbedding,
-)
 
 
 logger = get_logger(__name__)  # pylint: disable=invalid-name
@@ -44,6 +30,8 @@ class ModelHook:
     with PyTorch existing hooks is that they get passed along the kwargs.
     """
 
+    _is_stateful = False
+
     def init_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         r"""
         Hook that is executed when a model is initialized.
@@ -95,6 +83,11 @@ def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         """
         return module
 
+    def reset_state(self, module: torch.nn.Module):
+        if self._is_stateful:
+            raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
+        return module
+
 
 class SequentialHook(ModelHook):
     r"""A hook that can contain several hooks and iterates through them at each event."""
@@ -122,34 +115,12 @@ def detach_hook(self, module):
             module = hook.detach_hook(module)
         return module
 
-
-class LayerwiseUpcastingHook(ModelHook):
-    r"""
-    A hook that cast the input tensors and torch.nn.Module to a pre-specified dtype before the forward pass and cast
-    the module back to the original dtype after the forward pass. This is useful when a model is loaded/stored in a
-    lower precision dtype but performs computation in a higher precision dtype. This process may lead to quality loss
-    in the output, but can significantly reduce the memory footprint.
-    """
-
-    def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype) -> None:
-        self.storage_dtype = storage_dtype
-        self.compute_dtype = compute_dtype
-
-    def init_hook(self, module: torch.nn.Module):
-        module.to(dtype=self.storage_dtype)
+    def reset_state(self, module):
+        for hook in self.hooks:
+            if hook._is_stateful:
+                hook.reset_state(module)
         return module
 
-    def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
-        module.to(dtype=self.compute_dtype)
-        # How do we account for LongTensor, BoolTensor, etc.?
-        # args = tuple(align_maybe_tensor_dtype(arg, self.compute_dtype) for arg in args)
-        # kwargs = {k: align_maybe_tensor_dtype(v, self.compute_dtype) for k, v in kwargs.items()}
-        return args, kwargs
-
-    def post_forward(self, module: torch.nn.Module, output):
-        module.to(dtype=self.storage_dtype)
-        return output
-
 
 def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False):
     r"""
@@ -244,148 +215,19 @@ def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> t
     return module
 
 
-def align_maybe_tensor_dtype(input: Any, dtype: torch.dtype) -> Any:
-    r"""
-    Aligns the dtype of a tensor or a list of tensors to a given dtype.
-
-    Args:
-        input (`Any`):
-            The input tensor, list of tensors, or dictionary of tensors to align. If the input is neither of these
-            types, it will be returned as is.
-        dtype (`torch.dtype`):
-            The dtype to align the tensor(s) to.
-    Returns:
-        `Any`:
-            The tensor or list of tensors aligned to the given dtype.
-    """
-    if isinstance(input, torch.Tensor):
-        return input.to(dtype=dtype)
-    if isinstance(input, (list, tuple)):
-        return [align_maybe_tensor_dtype(t, dtype) for t in input]
-    if isinstance(input, dict):
-        return {k: align_maybe_tensor_dtype(v, dtype) for k, v in input.items()}
-    return input
-
-
-class LayerwiseUpcastingGranularity(str, Enum):
-    r"""
-    An enumeration class that defines the granularity of the layerwise upcasting process.
-
-    Granularity can be one of the following:
-    - `DIFFUSERS_LAYER`:
-        Applies layerwise upcasting to the lower-level diffusers layers of the model. This method is applied to
-        only those layers that are a group of linear layers, while excluding precision-critical layers like
-        modulation and normalization layers.
-    - `PYTORCH_LAYER`:
-        Applies layerwise upcasting to lower-level PyTorch primitive layers of the model. This is the most granular
-        level of layerwise upcasting. The memory footprint for inference and training is greatly reduced, while
-        also ensuring important operations like normalization with learned parameters remain unaffected from the
-        downcasting/upcasting process, by default. As not all parameters are casted to lower precision, the memory
-        footprint for storing the model may be slightly higher than the alternatives. This method causes the
-        highest number of casting operations, which may contribute to a slight increase in the overall computation
-        time.
-
-    Note: try and ensure that precision-critical layers like modulation and normalization layers are not casted to
-    lower precision, as this may lead to significant quality loss.
+def reset_stateful_hooks(module: torch.nn.Module, recurse: bool = False):
     """
-
-    DIFFUSERS_LAYER = "diffusers_layer"
-    PYTORCH_LAYER = "pytorch_layer"
-
-
-# fmt: off
-_SUPPORTED_DIFFUSERS_LAYERS = [
-    AttentionPooling, MochiAttentionPool, HunyuanDiTAttentionPool,
-    CogVideoXPatchEmbed, CogView3PlusPatchEmbed, LuminaPatchEmbed,
-    TimestepEmbedding, GLIGENTextBoundingboxProjection, PixArtAlphaTextProjection,
-    FeedForward, LuminaFeedForward,
-]
-
-_SUPPORTED_PYTORCH_LAYERS = [
-    torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
-    torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
-    torch.nn.Linear,
-]
-
-_DEFAULT_PYTORCH_LAYER_SKIP_MODULES_PATTERN = ["pos_embed", "patch_embed", "norm"]
-# fmt: on
-
-
-def apply_layerwise_upcasting_hook(
-    module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype
-) -> torch.nn.Module:
-    r"""
-    Applies a `LayerwiseUpcastingHook` to a given module.
+    Resets the state of all stateful hooks attached to a module.
 
     Args:
         module (`torch.nn.Module`):
-            The module to attach the hook to.
-        storage_dtype (`torch.dtype`):
-            The dtype to cast the module to before the forward pass.
-        compute_dtype (`torch.dtype`):
-            The dtype to cast the module to during the forward pass.
-
-    Returns:
-        `torch.nn.Module`:
-            The same module, with the hook attached (the module is modified in place, so the result can be discarded).
+            The module to reset the stateful hooks from.
     """
-    hook = LayerwiseUpcastingHook(storage_dtype, compute_dtype)
-    return add_hook_to_module(module, hook, append=True)
-
-
-def apply_layerwise_upcasting(
-    module: torch.nn.Module,
-    storage_dtype: torch.dtype,
-    compute_dtype: torch.dtype,
-    granularity: LayerwiseUpcastingGranularity = LayerwiseUpcastingGranularity.PYTORCH_LAYER,
-    skip_modules_pattern: List[str] = [],
-    skip_modules_classes: List[Type[torch.nn.Module]] = [],
-) -> torch.nn.Module:
-    if granularity == LayerwiseUpcastingGranularity.DIFFUSERS_LAYER:
-        return _apply_layerwise_upcasting_diffusers_layer(
-            module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes
-        )
-    if granularity == LayerwiseUpcastingGranularity.PYTORCH_LAYER:
-        return _apply_layerwise_upcasting_pytorch_layer(
-            module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes
-        )
-
-
-def _apply_layerwise_upcasting_diffusers_layer(
-    module: torch.nn.Module,
-    storage_dtype: torch.dtype,
-    compute_dtype: torch.dtype,
-    skip_modules_pattern: List[str] = _DEFAULT_PYTORCH_LAYER_SKIP_MODULES_PATTERN,
-    skip_modules_classes: List[Type[torch.nn.Module]] = [],
-) -> torch.nn.Module:
-    for name, submodule in module.named_modules():
-        if (
-            any(re.search(pattern, name) for pattern in skip_modules_pattern)
-            or any(isinstance(submodule, module_class) for module_class in skip_modules_classes)
-            or not isinstance(submodule, tuple(_SUPPORTED_DIFFUSERS_LAYERS))
-        ):
-            logger.debug(f'Skipping layerwise upcasting for layer "{name}"')
-            continue
-        logger.debug(f'Applying layerwise upcasting to layer "{name}"')
-        apply_layerwise_upcasting_hook(submodule, storage_dtype, compute_dtype)
-    return module
+    if hasattr(module, "_diffusers_hook") and (
+        module._diffusers_hook._is_stateful or isinstance(module._diffusers_hook, SequentialHook)
+    ):
+        module._diffusers_hook.reset_state(module)
 
-
-def _apply_layerwise_upcasting_pytorch_layer(
-    module: torch.nn.Module,
-    storage_dtype: torch.dtype,
-    compute_dtype: torch.dtype,
-    skip_modules_pattern: List[str] = _DEFAULT_PYTORCH_LAYER_SKIP_MODULES_PATTERN,
-    skip_modules_classes: List[Type[torch.nn.Module]] = [],
-) -> torch.nn.Module:
-    for name, submodule in module.named_modules():
-        if (
-            any(re.search(pattern, name) for pattern in skip_modules_pattern)
-            or any(isinstance(submodule, module_class) for module_class in skip_modules_classes)
-            or not isinstance(submodule, tuple(_SUPPORTED_PYTORCH_LAYERS))
-        ):
-            logger.debug(f'Skipping layerwise upcasting for layer "{name}"')
-            continue
-        logger.debug(f'Applying layerwise upcasting to layer "{name}"')
-        apply_layerwise_upcasting_hook(submodule, storage_dtype, compute_dtype)
-    return module
+    if recurse:
+        for child in module.children():
+            reset_stateful_hooks(child, recurse)
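
A minimal sketch of how the stateful-hook API added in this diff might be exercised end to end. `ModelHook`, `add_hook_to_module`, and `reset_stateful_hooks` come from this file; the `CallCounterHook` class, its counter attribute, and the toy `torch.nn.Linear` model are illustrative assumptions, not code from this change, and the example assumes `add_hook_to_module` attaches the hook as `module._diffusers_hook`, which is what `reset_stateful_hooks` inspects.

```python
import torch

# Assumes ModelHook, add_hook_to_module, and reset_stateful_hooks from this module are in scope.


class CallCounterHook(ModelHook):
    """Hypothetical hook that counts forward passes and can be reset between inference runs."""

    _is_stateful = True  # opt this hook into reset_stateful_hooks

    def __init__(self):
        self.num_calls = 0

    def pre_forward(self, module, *args, **kwargs):
        # Called just before the wrapped forward; must return the (args, kwargs) to forward with.
        self.num_calls += 1
        return args, kwargs

    def reset_state(self, module):
        # Stateful hooks must override reset_state; the base class raises NotImplementedError otherwise.
        self.num_calls = 0
        return module


model = torch.nn.Linear(4, 4)
add_hook_to_module(model, CallCounterHook())
model(torch.randn(1, 4))

# Clear any accumulated hook state before starting a new run.
reset_stateful_hooks(model, recurse=True)
```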