@@ -77,7 +77,6 @@ def apply_layerwise_upcasting(
     skip_modules_pattern: Union[str, Tuple[str, ...]] = "default",
     skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
     non_blocking: bool = False,
-    _prefix: str = "",
 ) -> None:
     r"""
     Applies layerwise upcasting to a given module. The module expected here is a Diffusers ModelMixin but it can be any
@@ -97,7 +96,7 @@ def apply_layerwise_upcasting(
     ...     transformer,
     ...     storage_dtype=torch.float8_e4m3fn,
     ...     compute_dtype=torch.bfloat16,
-    ...     skip_modules_pattern=["patch_embed", "norm"],
+    ...     skip_modules_pattern=["patch_embed", "norm", "proj_out"],
     ...     non_blocking=True,
     ... )
     ```
@@ -112,7 +111,9 @@ def apply_layerwise_upcasting(
             The dtype to cast the module to during the forward pass for computation.
         skip_modules_pattern (`Tuple[str, ...]`, defaults to `"default"`):
             A list of patterns to match the names of the modules to skip during the layerwise upcasting process. If set
-            to `"default"`, the default patterns are used.
+            to `"default"`, the default patterns are used. If set to `None`, no modules are skipped. If set to `None`
+            alongside `skip_modules_classes` being `None`, the layerwise upcasting is applied directly to the module
+            instead of its internal submodules.
         skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, defaults to `None`):
             A list of module classes to skip during the layerwise upcasting process.
         non_blocking (`bool`, defaults to `False`):
@@ -125,6 +126,25 @@ def apply_layerwise_upcasting(
         apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking)
         return

+    _apply_layerwise_upcasting(
+        module,
+        storage_dtype,
+        compute_dtype,
+        skip_modules_pattern,
+        skip_modules_classes,
+        non_blocking,
+    )
+
+
+def _apply_layerwise_upcasting(
+    module: torch.nn.Module,
+    storage_dtype: torch.dtype,
+    compute_dtype: torch.dtype,
+    skip_modules_pattern: Optional[Tuple[str, ...]] = None,
+    skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
+    non_blocking: bool = False,
+    _prefix: str = "",
+) -> None:
     should_skip = (skip_modules_classes is not None and isinstance(module, skip_modules_classes)) or (
         skip_modules_pattern is not None and any(re.search(pattern, _prefix) for pattern in skip_modules_pattern)
     )
@@ -139,7 +159,7 @@ def apply_layerwise_upcasting(

     for name, submodule in module.named_children():
         layer_name = f"{_prefix}.{name}" if _prefix else name
-        apply_layerwise_upcasting(
+        _apply_layerwise_upcasting(
             submodule,
             storage_dtype,
             compute_dtype,
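
For reference, a minimal sketch of how a caller might exercise the refactored entry point after this change. The import path, the `TinyBlock` module, and the skip patterns below are illustrative assumptions, not part of this commit; adjust them to the checkout you are working against.

```python
import torch

# Assumed import path; the function may be exposed differently (or under a
# different name) in released versions of the library.
from diffusers.hooks import apply_layerwise_upcasting


class TinyBlock(torch.nn.Module):
    """Hypothetical toy module used only to illustrate the skip patterns."""

    def __init__(self):
        super().__init__()
        self.proj_in = torch.nn.Linear(16, 16)
        self.norm = torch.nn.LayerNorm(16)
        self.proj_out = torch.nn.Linear(16, 8)

    def forward(self, x):
        return self.proj_out(self.norm(self.proj_in(x)))


# Pattern-based call: the private helper recurses over named submodules and
# skips any whose qualified name matches one of the regex patterns, so only
# `proj_in` ends up with fp8 storage / bf16 compute here.
block_a = TinyBlock()
apply_layerwise_upcasting(
    block_a,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
    skip_modules_pattern=("norm", "proj_out"),
)

# With both skip arguments set to None, the upcasting hook is attached to the
# module itself instead of its internal submodules, per the updated docstring.
block_b = TinyBlock()
apply_layerwise_upcasting(
    block_b,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
    skip_modules_pattern=None,
    skip_modules_classes=None,
)
```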