# limitations under the License.

import re
- from typing import Optional, Tuple, Type
+ from typing import Optional, Tuple, Type, Union

import torch


# fmt: off
- _SUPPORTED_PYTORCH_LAYERS = (
+ SUPPORTED_PYTORCH_LAYERS = (
    torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
    torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
    torch.nn.Linear,
)

- _DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm")
+ DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm", "^proj_in$", "^proj_out$")
# fmt: on

@@ -74,8 +74,8 @@ def apply_layerwise_upcasting(
    module: torch.nn.Module,
    storage_dtype: torch.dtype,
    compute_dtype: torch.dtype,
-     skip_modules_pattern: Optional[Tuple[str]] = _DEFAULT_SKIP_MODULES_PATTERN,
-     skip_modules_classes: Optional[Tuple[Type[torch.nn.Module]]] = None,
+     skip_modules_pattern: Union[str, Tuple[str, ...]] = "default",
+     skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
    non_blocking: bool = False,
    _prefix: str = "",
) -> None:
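With the new signature, `skip_modules_pattern` accepts the `"default"` string sentinel, an explicit tuple of regex patterns, or `None`; the sentinel keeps the built-in defaults while still letting callers pass `None` to disable name-based skipping entirely. A hedged usage sketch of the three call styles, assuming the top-level `from diffusers import apply_layerwise_upcasting` import shown in this file's original docstring example still works and using a throwaway toy model:

```python
import torch
from diffusers import apply_layerwise_upcasting  # import path as in the original docstring example

def make_model() -> torch.nn.Module:
    # Tiny stand-in model; any torch.nn.Module works.
    return torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.LayerNorm(16))

# 1) "default" sentinel: expanded to DEFAULT_SKIP_MODULES_PATTERN inside the function.
apply_layerwise_upcasting(make_model(), storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32)

# 2) Custom tuple of regex patterns: only modules whose names match are skipped.
apply_layerwise_upcasting(
    make_model(),
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.float32,
    skip_modules_pattern=("norm",),
)

# 3) None for both skip arguments: the cast hook is attached to the root module
#    as a whole instead of recursing into supported leaf layers (see the
#    early-return branch further down in this diff).
apply_layerwise_upcasting(
    make_model(),
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.float32,
    skip_modules_pattern=None,
    skip_modules_classes=None,
)
```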
@@ -87,13 +87,14 @@ def apply_layerwise_upcasting(

    ```python
    >>> import torch
-     >>> from diffusers import CogVideoXPipeline, apply_layerwise_upcasting
+     >>> from diffusers import CogVideoXTransformer3DModel

-     >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
-     >>> pipe.to("cuda")
+     >>> transformer = CogVideoXTransformer3DModel.from_pretrained(
+     ...     model_id, subfolder="transformer", torch_dtype=torch.bfloat16
+     ... )

    >>> apply_layerwise_upcasting(
-     ...     pipe.transformer,
+     ...     transformer,
    ...     storage_dtype=torch.float8_e4m3fn,
    ...     compute_dtype=torch.bfloat16,
    ...     skip_modules_pattern=["patch_embed", "norm"],
@@ -109,13 +110,17 @@ def apply_layerwise_upcasting(
            The dtype to cast the module to before/after the forward pass for storage.
        compute_dtype (`torch.dtype`):
            The dtype to cast the module to during the forward pass for computation.
-         skip_modules_pattern (`Tuple[str]`, defaults to `["pos_embed", "patch_embed", "norm"]`):
-             A list of patterns to match the names of the modules to skip during the layerwise upcasting process.
-         skip_modules_classes (`Tuple[Type[torch.nn.Module]]`, defaults to `None`):
+         skip_modules_pattern (`Tuple[str, ...]`, defaults to `"default"`):
+             A list of patterns to match the names of the modules to skip during the layerwise upcasting process. If set
+             to `"default"`, the default patterns are used.
+         skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, defaults to `None`):
            A list of module classes to skip during the layerwise upcasting process.
        non_blocking (`bool`, defaults to `False`):
            If `True`, the weight casting operations are non-blocking.
    """
+     if skip_modules_pattern == "default":
+         skip_modules_pattern = DEFAULT_SKIP_MODULES_PATTERN
+
    if skip_modules_classes is None and skip_modules_pattern is None:
        apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking)
        return
@@ -127,7 +132,7 @@ def apply_layerwise_upcasting(
        logger.debug(f'Skipping layerwise upcasting for layer "{_prefix}"')
        return

-     if isinstance(module, _SUPPORTED_PYTORCH_LAYERS):
+     if isinstance(module, SUPPORTED_PYTORCH_LAYERS):
        logger.debug(f'Applying layerwise upcasting to layer "{_prefix}"')
        apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking)
        return
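Putting the class check and the name check together, here is a small sketch of which submodules would end up with the storage-dtype hook, using a flat walk over `named_modules()` as an approximation of the recursive traversal above; the toy module names are invented for illustration.

```python
import re

import torch

SUPPORTED = (
    torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
    torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
    torch.nn.Linear,
)
SKIP = ("pos_embed", "patch_embed", "norm", "^proj_in$", "^proj_out$")

# Toy module whose names resemble the ones the default patterns target.
model = torch.nn.ModuleDict({
    "patch_embed": torch.nn.Conv2d(3, 8, kernel_size=2),
    "proj_in": torch.nn.Linear(8, 8),
    "attn_to_q": torch.nn.Linear(8, 8),
    "norm_out": torch.nn.LayerNorm(8),
})

for name, submodule in model.named_modules():
    if name == "":  # skip the container itself
        continue
    supported = isinstance(submodule, SUPPORTED)
    skipped = any(re.search(p, name) for p in SKIP)
    # A submodule would receive the storage-dtype hook only if it is a supported
    # layer type and no skip pattern matches its name.
    print(f"{name:12s} supported={supported} skipped={skipped}")
```

Only `attn_to_q` would be cast here: it is a supported `torch.nn.Linear` and no skip pattern matches its name.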