@@ -45,6 +45,8 @@ class LayerwiseUpcastingHook(ModelHook):
     in the output, but can significantly reduce the memory footprint.
     """

+    _is_stateful = False
+
     def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype) -> None:
         self.storage_dtype = storage_dtype
         self.compute_dtype = compute_dtype
@@ -56,8 +58,8 @@ def init_hook(self, module: torch.nn.Module):
     def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
         module.to(dtype=self.compute_dtype)
         # How do we account for LongTensor, BoolTensor, etc.?
-        # args = tuple(align_maybe_tensor_dtype(arg, self.compute_dtype) for arg in args)
-        # kwargs = {k: align_maybe_tensor_dtype(v, self.compute_dtype) for k, v in kwargs.items()}
+        # args = tuple(_align_maybe_tensor_dtype(arg, self.compute_dtype) for arg in args)
+        # kwargs = {k: _align_maybe_tensor_dtype(v, self.compute_dtype) for k, v in kwargs.items()}
         return args, kwargs

     def post_forward(self, module: torch.nn.Module, output):
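A minimal standalone sketch of the storage/compute cycle this hook implements around each forward pass; the dtypes and the bare `nn.Linear` are illustrative, and the cast back in `post_forward` is assumed since it is not visible in this hunk:

```python
import torch

# Illustrative choice: weights live in a low-precision storage dtype and are
# upcast to a higher-precision compute dtype only while the forward pass runs.
# torch.float8_e4m3fn requires PyTorch >= 2.1.
storage_dtype, compute_dtype = torch.float8_e4m3fn, torch.bfloat16

linear = torch.nn.Linear(16, 16).to(storage_dtype)  # parameters stored in storage_dtype
x = torch.randn(2, 16, dtype=compute_dtype)

linear.to(compute_dtype)   # what pre_forward does: module.to(dtype=self.compute_dtype)
y = linear(x)              # the forward pass runs entirely in compute_dtype
linear.to(storage_dtype)   # assumed post_forward behavior: cast back for low-memory storage
```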
@@ -105,7 +107,7 @@ class LayerwiseUpcastingGranularity(str, Enum):
     torch.nn.Linear,
 ]

-_DEFAULT_PYTORCH_LAYER_SKIP_MODULES_PATTERN = ["pos_embed", "patch_embed", "norm"]
+_DEFAULT_SKIP_MODULES_PATTERN = ["pos_embed", "patch_embed", "norm"]
 # fmt: on


@@ -114,9 +116,27 @@ def apply_layerwise_upcasting(
     storage_dtype: torch.dtype,
     compute_dtype: torch.dtype,
     granularity: LayerwiseUpcastingGranularity = LayerwiseUpcastingGranularity.PYTORCH_LAYER,
-    skip_modules_pattern: List[str] = [],
+    skip_modules_pattern: List[str] = _DEFAULT_SKIP_MODULES_PATTERN,
     skip_modules_classes: List[Type[torch.nn.Module]] = [],
 ) -> torch.nn.Module:
+    r"""
+    Applies layerwise upcasting to a given module. The module expected here is a Diffusers ModelMixin, but it can be
+    any nn.Module that uses Diffusers layers or PyTorch primitives.
+
+    Args:
+        module (`torch.nn.Module`):
+            The module whose supported submodules will have the layerwise upcasting hook attached.
+        storage_dtype (`torch.dtype`):
+            The dtype the module is cast to after the forward pass, for storage.
+        compute_dtype (`torch.dtype`):
+            The dtype the module is cast to before the forward pass, for computation.
+        granularity (`LayerwiseUpcastingGranularity`, *optional*, defaults to `LayerwiseUpcastingGranularity.PYTORCH_LAYER`):
+            The granularity of the layerwise upcasting process.
+        skip_modules_pattern (`List[str]`, defaults to `["pos_embed", "patch_embed", "norm"]`):
+            A list of patterns matched against submodule names; matching submodules are skipped during upcasting.
+        skip_modules_classes (`List[Type[torch.nn.Module]]`, defaults to `[]`):
+            A list of module classes to skip during the layerwise upcasting process.
+    """
     if granularity == LayerwiseUpcastingGranularity.DIFFUSERS_LAYER:
         return _apply_layerwise_upcasting_diffusers_layer(
             module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes
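A hedged usage sketch of the new entry point. The `diffusers.hooks` import path and the float8 storage dtype are assumptions for illustration; neither is shown in this diff:

```python
import torch

# Assumed import location; the diff does not show where these names are exported from.
from diffusers.hooks import LayerwiseUpcastingGranularity, apply_layerwise_upcasting

# Any nn.Module built from PyTorch primitives works; a toy MLP keeps the example self-contained.
model = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.GELU(),
    torch.nn.Linear(64, 8),
)

# Keep matched layers (here, the nn.Linear modules) in float8 for storage and upcast
# each one to bfloat16 only for the duration of its own forward pass.
apply_layerwise_upcasting(
    model,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
    granularity=LayerwiseUpcastingGranularity.PYTORCH_LAYER,
)

output = model(torch.randn(1, 64, dtype=torch.bfloat16))
print(output.dtype)  # torch.bfloat16
```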
@@ -153,7 +173,7 @@ def _apply_layerwise_upcasting_diffusers_layer(
     module: torch.nn.Module,
     storage_dtype: torch.dtype,
     compute_dtype: torch.dtype,
-    skip_modules_pattern: List[str] = _DEFAULT_PYTORCH_LAYER_SKIP_MODULES_PATTERN,
+    skip_modules_pattern: List[str] = _DEFAULT_SKIP_MODULES_PATTERN,
     skip_modules_classes: List[Type[torch.nn.Module]] = [],
 ) -> torch.nn.Module:
     for name, submodule in module.named_modules():
@@ -173,7 +193,7 @@ def _apply_layerwise_upcasting_pytorch_layer(
     module: torch.nn.Module,
     storage_dtype: torch.dtype,
     compute_dtype: torch.dtype,
-    skip_modules_pattern: List[str] = _DEFAULT_PYTORCH_LAYER_SKIP_MODULES_PATTERN,
+    skip_modules_pattern: List[str] = _DEFAULT_SKIP_MODULES_PATTERN,
     skip_modules_classes: List[Type[torch.nn.Module]] = [],
 ) -> torch.nn.Module:
     for name, submodule in module.named_modules():
@@ -189,7 +209,7 @@ def _apply_layerwise_upcasting_pytorch_layer(
     return module


-def align_maybe_tensor_dtype(input: Any, dtype: torch.dtype) -> Any:
+def _align_maybe_tensor_dtype(input: Any, dtype: torch.dtype) -> Any:
     r"""
     Aligns the dtype of a tensor, or of tensors nested in a list, tuple, or dict, to a given dtype.

@@ -199,14 +219,15 @@ def align_maybe_tensor_dtype(input: Any, dtype: torch.dtype) -> Any:
             types, it will be returned as is.
         dtype (`torch.dtype`):
             The dtype to align the tensor(s) to.
+
     Returns:
         `Any`:
             The tensor or list of tensors aligned to the given dtype.
     """
     if isinstance(input, torch.Tensor):
         return input.to(dtype=dtype)
     if isinstance(input, (list, tuple)):
-        return [align_maybe_tensor_dtype(t, dtype) for t in input]
+        return [_align_maybe_tensor_dtype(t, dtype) for t in input]
     if isinstance(input, dict):
-        return {k: align_maybe_tensor_dtype(v, dtype) for k, v in input.items()}
+        return {k: _align_maybe_tensor_dtype(v, dtype) for k, v in input.items()}
     return input
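For reference, a small sketch of what the renamed helper does with nested containers; note that, as written, it also casts integer tensors, which is exactly the open question flagged in `pre_forward` above:

```python
import torch

batch = {
    "sample": torch.ones(2, dtype=torch.float32),
    "ids": torch.tensor([1, 2]),   # int64
    "scale": 1.0,                  # not a tensor, returned unchanged
}
aligned = _align_maybe_tensor_dtype(batch, torch.bfloat16)
# aligned["sample"].dtype -> torch.bfloat16
# aligned["ids"].dtype    -> torch.bfloat16 (integer tensors get converted too)
# aligned["scale"]        -> 1.0
```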