[core] Layerwise Upcasting #10347

Status: Merged
Changes from 27 commits (55 commits total, all by a-r-r-o-w):

- 36b0c37 update
- 42046c0 update
- 7dc739b make style
- 7ed7141 Merge branch 'main' into layerwise-upcasting-hook
- 1fa4ee5 remove dynamo disable
- da4907e add coauthor
- bc2ada4 update
- 91bfc3d Merge branch 'main' into layerwise-upcasting-hook
- 7c31bb0 update
- 8975bbf update
- 341fbfc update mixin
- 5f898a1 add some basic tests
- 558c64e update
- 7858f2c update
- 2663026 Merge branch 'main' into layerwise-upcasting-hook
- 3d84b9e non_blocking
- 9372647 improvements
- a0f1de7 Merge branch 'main' into layerwise-upcasting-hook
- e586ef3 update
- cfe6318 norm.* -> norm
- 9235f77 Merge branch 'main' into layerwise-upcasting-hook
- 7627415 apply suggestions from review
- b9e1217 add example
- bde103c update hook implementation to the latest changes from pyramid attenti…
- 64e6c9c deinitialize should raise an error
- 7037133 update doc page
- f1b46d6 Merge branch 'main' into layerwise-upcasting-hook
- 390742b Apply suggestions from code review
- 19901e7 update docs
- 3ae32b4 update
- bf797e7 refactor
- d22465a Merge branch 'main' into layerwise-upcasting-hook
- 5956a9e fix _always_upcast_modules for asym ae and vq_model
- 93bd8ee fix lumina embedding forward to not depend on weight dtype
- 77a32a7 refactor tests
- 1335d7e add simple lora inference tests
- a263e1a _always_upcast_modules -> _precision_sensitive_module_patterns
- 93e36ba Merge branch 'main' into layerwise-upcasting-hook
- 245137f remove todo comments about review; revert changes to self.dtype in un…
- b713511 check layer dtypes in lora test
- 4450b1c Merge branch 'main' into layerwise-upcasting-hook
- ed14d26 fix UNet1DModelTests::test_layerwise_upcasting_inference
- 2c9c33f _precision_sensitive_module_patterns -> _skip_layerwise_casting_patte…
- 08211f7 skip test in NCSNppModelTests
- 59e04c3 skip tests for AutoencoderTinyTests
- 0a16826 skip tests for AutoencoderOobleckTests
- 1d306b8 skip tests for UNet1DModelTests - unsupported pytorch operations
- a9364bd layerwise_upcasting -> layerwise_casting
- c4d5a2b skip tests for UNetRLModelTests; needs next pytorch release for curre…
- d175d93 add layerwise fp8 pipeline test
- bf11691 use xfail
- 1c523b2 Apply suggestions from code review
- 7803364 Merge branch 'main' into layerwise-upcasting-hook
- 376adf9 add assertion with fp32 comparison; add tolerance to fp8-fp32 vs fp32…
- 719e8d3 add note about memory consumption on tesla CI runner for failing test
  
    
src/diffusers/hooks/__init__.py (new file, @@ -0,0 +1,5 @@):

```python
from ..utils import is_torch_available


if is_torch_available():
    from .layerwise_upcasting import apply_layerwise_upcasting, apply_layerwise_upcasting_hook
```
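This `__init__.py` only gates the new entry points on `is_torch_available()`; the `layerwise_upcasting` module itself is not shown in this excerpt. As orientation, here is a minimal pure-PyTorch sketch of the technique the PR implements: weights are stored in a low-precision dtype and upcast to the compute dtype only around a layer's forward pass. Everything in this snippet (the dtype choices and the `pre_forward`/`post_forward` helper names) is illustrative, not the PR's API.

```python
import torch
import torch.nn as nn

# Illustrative dtype choices (assumes PyTorch >= 2.1 for float8 dtypes).
storage_dtype = torch.float8_e4m3fn  # weights live in fp8 while the layer is idle
compute_dtype = torch.bfloat16       # dtype actually used for the matmul

linear = nn.Linear(64, 64).to(storage_dtype)

def pre_forward(module: nn.Module) -> None:
    # Upcast parameters just before the forward pass.
    module.to(compute_dtype)

def post_forward(module: nn.Module) -> None:
    # Downcast back to the storage dtype, so idle layers keep a small footprint.
    module.to(storage_dtype)

x = torch.randn(1, 64, dtype=compute_dtype)
pre_forward(linear)
y = linear(x)         # computation runs in bfloat16
post_forward(linear)  # weights are fp8 again; y stays bfloat16
```

The hook machinery in the next file is what lets the PR attach exactly this pre/post behavior to submodules without rewriting their forward methods.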
  
    
src/diffusers/hooks/hooks.py (new file, @@ -0,0 +1,188 @@):

```python
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from typing import Any, Dict, Optional, Tuple

import torch

from ..utils.logging import get_logger


logger = get_logger(__name__)  # pylint: disable=invalid-name


class ModelHook:
    r"""
    A hook that contains callbacks to be executed just before and after the forward method of a model.
    """

    _is_stateful = False

    def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        r"""
        Hook that is executed when a model is initialized.

        Args:
            module (`torch.nn.Module`):
                The module attached to this hook.
        """
        return module

    def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        r"""
        Hook that is executed when a model is deinitialized.

        Args:
            module (`torch.nn.Module`):
                The module attached to this hook.
        """
        module.forward = module._old_forward
        del module._old_forward
        return module

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
        r"""
        Hook that is executed just before the forward method of the model.

        Args:
            module (`torch.nn.Module`):
                The module whose forward pass will be executed just after this event.
            args (`Tuple[Any]`):
                The positional arguments passed to the module.
            kwargs (`Dict[str, Any]`):
                The keyword arguments passed to the module.
        Returns:
            `Tuple[Tuple[Any], Dict[str, Any]]`:
                A tuple with the treated `args` and `kwargs`.
        """
        return args, kwargs

    def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
        r"""
        Hook that is executed just after the forward method of the model.

        Args:
            module (`torch.nn.Module`):
                The module whose forward pass has been executed just before this event.
            output (`Any`):
                The output of the module.
        Returns:
            `Any`: The processed `output`.
        """
        return output

    def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        r"""
        Hook that is executed when the hook is detached from a module.

        Args:
            module (`torch.nn.Module`):
                The module detached from this hook.
        """
        return module

    def reset_state(self, module: torch.nn.Module):
        if self._is_stateful:
            raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
        return module


class HookRegistry:
    def __init__(self, module_ref: torch.nn.Module) -> None:
        super().__init__()

        self.hooks: Dict[str, ModelHook] = {}

        self._module_ref = module_ref
        self._hook_order = []

    def register_hook(self, hook: ModelHook, name: str) -> None:
        if name in self.hooks.keys():
            logger.warning(f"Hook with name {name} already exists, replacing it.")

        if hasattr(self._module_ref, "_old_forward"):
            old_forward = self._module_ref._old_forward
        else:
            old_forward = self._module_ref.forward
            self._module_ref._old_forward = self._module_ref.forward

        self._module_ref = hook.initialize_hook(self._module_ref)

        if hasattr(hook, "new_forward"):
            rewritten_forward = hook.new_forward

            def new_forward(module, *args, **kwargs):
                args, kwargs = hook.pre_forward(module, *args, **kwargs)
                output = rewritten_forward(module, *args, **kwargs)
                return hook.post_forward(module, output)
        else:

            def new_forward(module, *args, **kwargs):
                args, kwargs = hook.pre_forward(module, *args, **kwargs)
                output = old_forward(*args, **kwargs)
                return hook.post_forward(module, output)

        self._module_ref.forward = functools.update_wrapper(
            functools.partial(new_forward, self._module_ref), old_forward
        )

        self.hooks[name] = hook
        self._hook_order.append(name)

    def get_hook(self, name: str) -> Optional[ModelHook]:
        if name not in self.hooks.keys():
            return None
        return self.hooks[name]

    def remove_hook(self, name: str, recurse: bool = True) -> None:
        if name in self.hooks.keys():
            hook = self.hooks[name]
            self._module_ref = hook.deinitalize_hook(self._module_ref)
            del self.hooks[name]
            self._hook_order.remove(name)

        if recurse:
            for module_name, module in self._module_ref.named_modules():
                if module_name == "":
                    continue
                if hasattr(module, "_diffusers_hook"):
                    module._diffusers_hook.remove_hook(name, recurse=False)

    def reset_stateful_hooks(self, recurse: bool = True) -> None:
        for hook_name in self._hook_order:
            hook = self.hooks[hook_name]
            if hook._is_stateful:
                hook.reset_state(self._module_ref)

        if recurse:
            for module_name, module in self._module_ref.named_modules():
                if module_name == "":
                    continue
                if hasattr(module, "_diffusers_hook"):
                    module._diffusers_hook.reset_stateful_hooks(recurse=False)

    @classmethod
    def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry":
        if not hasattr(module, "_diffusers_hook"):
            module._diffusers_hook = cls(module)
        return module._diffusers_hook

    def __repr__(self) -> str:
        hook_repr = ""
        for i, hook_name in enumerate(self._hook_order):
            hook_repr += f"  ({i}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})"
            if i < len(self._hook_order) - 1:
                hook_repr += "\n"
        return f"HookRegistry(\n{hook_repr}\n)"
```
      