Module Group Offloading #10503
Changes from 4 commits
@@ -0,0 +1,5 @@

```python
from ..utils import is_torch_available


if is_torch_available():
    from .group_offloading import apply_group_offloading
```
@@ -0,0 +1,218 @@

```python
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import List, Optional, Union

import torch

from .hooks import HookRegistry, ModelHook


_COMMON_STACK_IDENTIFIERS = {
    "transformer_blocks",
    "single_transformer_blocks",
    "temporal_transformer_blocks",
    "transformer_layers",
    "layers",
    "blocks",
    "down_blocks",
    "up_blocks",
    "mid_blocks",
}


class ModuleGroup:
    def __init__(
        self,
        modules: List[torch.nn.Module],
        offload_device: torch.device,
        onload_device: torch.device,
        offload_leader: torch.nn.Module,
        onload_leader: Optional[torch.nn.Module] = None,
    ) -> None:
        self.modules = modules
        self.offload_device = offload_device
        self.onload_device = onload_device
        self.offload_leader = offload_leader
        self.onload_leader = onload_leader


class GroupOffloadingHook(ModelHook):
    r"""
    A hook that offloads groups of torch.nn.Module to the CPU for storage and onloads them to the accelerator device
    for computation. Each group has one "onload leader" module that is responsible for onloading, and an "offload
    leader" module that is responsible for offloading.

    This implementation assumes the following:
    - For offload_group_patterns="diffusers_block", the leader of a group can be determined automatically. For a
      custom user-provided regex pattern, the module that triggers its forward pass first is considered the leader.
    - The inputs are already on the correct device. This is expected because the hook does not modify the state of
      inputs or outputs at any stage of the forward pass. If an error is raised because the devices of modules and
      inputs do not match during the forward pass of any model in Diffusers, it means the forward pass of the model
      is not written in the expected way. Please open an issue at https://github.com/huggingface/diffusers/issues if
      you encounter such an error.
    """

    def __init__(self, group: ModuleGroup, offload_on_init: bool = True, non_blocking: bool = False) -> None:
        self.group = group
        self.offload_on_init = offload_on_init
        self.non_blocking = non_blocking

    def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        if self.offload_on_init:
            self.offload_(module)
        return module

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
        if self.group.onload_leader is None:
            self.group.onload_leader = module
        self.onload_(module)
        return args, kwargs

    def post_forward(self, module: torch.nn.Module, output):
        self.offload_(module)
        return output

    def onload_(self, module: torch.nn.Module) -> None:
        if self.group.onload_leader == module:
            for group_module in self.group.modules:
                group_module.to(self.group.onload_device, non_blocking=self.non_blocking)

    def offload_(self, module: torch.nn.Module) -> None:
        if self.group.offload_leader == module:
            for group_module in self.group.modules:
                group_module.to(self.group.offload_device, non_blocking=self.non_blocking)
            # TODO: do we need to sync here because of GPU->CPU transfer?
            if self.non_blocking and self.group.offload_device.type == "cpu":
                torch.cpu.synchronize()


def apply_group_offloading(
    module: torch.nn.Module,
    offload_group_patterns: Union[str, List[str]] = "diffusers_block",
    num_blocks_per_group: Optional[int] = None,
    offload_device: torch.device = torch.device("cpu"),
    onload_device: torch.device = torch.device("cuda"),
    force_offload: bool = True,
    non_blocking: bool = False,
) -> None:
    if offload_group_patterns == "diffusers_block":
        if num_blocks_per_group is None:
            raise ValueError(
                'num_blocks_per_group must be provided when using offload_group_patterns="diffusers_block".'
            )
        _apply_group_offloading_diffusers_block(
            module, num_blocks_per_group, offload_device, onload_device, force_offload, non_blocking
        )
    else:
        _apply_group_offloading_group_patterns(
            module, offload_group_patterns, offload_device, onload_device, force_offload, non_blocking
        )


def _apply_group_offloading_diffusers_block(
    module: torch.nn.Module,
    num_blocks_per_group: int,
    offload_device: torch.device,
    onload_device: torch.device,
    force_offload: bool,
    non_blocking: bool,
) -> None:
    # Handle device offloading/onloading for unet/transformer stack modules
    for stack_identifier in _COMMON_STACK_IDENTIFIERS:
        if not hasattr(module, stack_identifier) or not isinstance(
            getattr(module, stack_identifier), torch.nn.ModuleList
        ):
            continue

        stack = getattr(module, stack_identifier)
        num_blocks = len(stack)

        for i in range(0, num_blocks, num_blocks_per_group):
            blocks = stack[i : i + num_blocks_per_group]
            group = ModuleGroup(
                blocks, offload_device, onload_device, offload_leader=blocks[-1], onload_leader=blocks[0]
            )
            should_offload = force_offload or i > 0
            _apply_group_offloading(group, should_offload, non_blocking)

    # Handle device offloading/onloading for non-stack modules
    for name, submodule in module.named_modules():
        name_split = name.split(".")
        if not isinstance(submodule, torch.nn.Module) or name == "" or len(name_split) > 1:
            # We only want the layers that are top-level in the module (encompass all the submodules)
            # for enabling offloading.
            continue
        layer_name = name_split[0]
        if layer_name in _COMMON_STACK_IDENTIFIERS:
            continue
        group = ModuleGroup(
            [submodule], offload_device, onload_device, offload_leader=submodule, onload_leader=submodule
        )
        _apply_group_offloading(group, force_offload, non_blocking)

    # Always keep top-level parameters and buffers on onload_device
    for name, param in module.named_parameters(recurse=False):
        if torch.is_tensor(param.data):
            param.data = param.data.to(onload_device)
    for name, buffer in module.named_buffers(recurse=False):
        if torch.is_tensor(buffer.data):
            buffer.data = buffer.data.to(onload_device)


def _apply_group_offloading_group_patterns(
    module: torch.nn.Module,
    offload_group_patterns: List[str],
    offload_device: torch.device,
    onload_device: torch.device,
    force_offload: bool,
    non_blocking: bool,
) -> None:
    per_group_modules = []
    for i, offload_group_pattern in enumerate(offload_group_patterns):
        group_modules = []
        group_module_names = []
        # Use a separate loop variable so the `module` argument is not shadowed across pattern iterations.
        for name, submodule in module.named_modules():
            if re.search(offload_group_pattern, name) is not None:
                group_modules.append(submodule)
                group_module_names.append(name)
        per_group_modules.append(
            {
                "modules": group_modules,
                "module_names": group_module_names,
            }
        )

    # Check if there are any overlapping modules between groups
    for i, group in enumerate(per_group_modules):
        for j, other_group in enumerate(per_group_modules):
            if j <= i:
                continue
            if any(module_name in group["module_names"] for module_name in other_group["module_names"]):
                raise ValueError(
                    f"Overlapping modules between groups {i} and {j}. Please ensure that offloading group patterns are mutually exclusive."
                )

    # Apply offloading to each group
    for group in per_group_modules:
        # TODO: handle offload leader correctly
        group = ModuleGroup(group["modules"], offload_device, onload_device, offload_leader=group["modules"][-1])
        _apply_group_offloading(group, force_offload, non_blocking)


def _apply_group_offloading(group: ModuleGroup, offload_on_init: bool, non_blocking: bool) -> None:
    for module in group.modules:
        hook = GroupOffloadingHook(group, offload_on_init, non_blocking)
        registry = HookRegistry.check_if_exists_or_initialize(module)
        registry.register_hook(hook, "group_offloading")
```
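For illustration only (not part of this diff), here is a minimal sketch of how the `apply_group_offloading` entry point above could be used. The `TinyModel` class, the block counts, and the assumption that `apply_group_offloading` is in scope (the first file in this diff suggests it is re-exported from the hooks package) are all illustrative; a CUDA device is assumed to be available.

```python
import torch
# Assumed import path based on the __init__.py in this diff:
# from diffusers.hooks import apply_group_offloading


class TinyModel(torch.nn.Module):
    # Toy stand-in for a diffusers transformer: it exposes a `blocks` ModuleList,
    # which is one of the names in _COMMON_STACK_IDENTIFIERS above.
    def __init__(self):
        super().__init__()
        self.proj_in = torch.nn.Linear(16, 16)
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(16, 16) for _ in range(8)])
        self.proj_out = torch.nn.Linear(16, 16)

    def forward(self, x):
        x = self.proj_in(x)
        for block in self.blocks:
            x = block(x)
        return self.proj_out(x)


model = TinyModel()

# Group every 2 consecutive blocks: each group is onloaded to CUDA just before its first
# block (the onload leader) runs and offloaded to CPU after its last block (the offload
# leader) finishes. Non-stack modules (proj_in, proj_out) are offloaded individually.
apply_group_offloading(model, offload_group_patterns="diffusers_block", num_blocks_per_group=2)

x = torch.randn(1, 16, device="cuda")  # inputs are expected to already be on the onload device
out = model(x)
```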
@@ -0,0 +1,164 @@

```python
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from typing import Any, Dict, Tuple

import torch

from ..utils.logging import get_logger


logger = get_logger(__name__)  # pylint: disable=invalid-name


class ModelHook:
    r"""
    A hook that contains callbacks to be executed just before and after the forward method of a model.
    """

    _is_stateful = False

    def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        r"""
        Hook that is executed when a model is initialized.

        Args:
            module (`torch.nn.Module`):
                The module attached to this hook.
        """
        return module

    def deinitialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        r"""
        Hook that is executed when a model is deinitialized.

        Args:
            module (`torch.nn.Module`):
                The module attached to this hook.
        """
        module.forward = module._old_forward
        del module._old_forward
        return module

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
        r"""
        Hook that is executed just before the forward method of the model.

        Args:
            module (`torch.nn.Module`):
                The module whose forward pass will be executed just after this event.
            args (`Tuple[Any]`):
                The positional arguments passed to the module.
            kwargs (`Dict[str, Any]`):
                The keyword arguments passed to the module.

        Returns:
            `Tuple[Tuple[Any], Dict[str, Any]]`:
                A tuple with the treated `args` and `kwargs`.
        """
        return args, kwargs

    def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
        r"""
        Hook that is executed just after the forward method of the model.

        Args:
            module (`torch.nn.Module`):
                The module whose forward pass has been executed just before this event.
            output (`Any`):
                The output of the module.

        Returns:
            `Any`: The processed `output`.
        """
        return output

    def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        r"""
        Hook that is executed when the hook is detached from a module.

        Args:
            module (`torch.nn.Module`):
                The module detached from this hook.
        """
        return module

    def reset_state(self, module: torch.nn.Module):
        if self._is_stateful:
            raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
        return module


class HookRegistry:
    def __init__(self, module_ref: torch.nn.Module) -> None:
        super().__init__()

        self.hooks: Dict[str, ModelHook] = {}

        self._module_ref = module_ref
        self._hook_order = []

    def register_hook(self, hook: ModelHook, name: str) -> None:
        if name in self.hooks.keys():
            logger.warning(f"Hook with name {name} already exists, replacing it.")

        if hasattr(self._module_ref, "_old_forward"):
            old_forward = self._module_ref._old_forward
        else:
            old_forward = self._module_ref.forward
            self._module_ref._old_forward = self._module_ref.forward

        self._module_ref = hook.initialize_hook(self._module_ref)

        if hasattr(hook, "new_forward"):
            new_forward = hook.new_forward
        else:

            def new_forward(module, *args, **kwargs):
                args, kwargs = hook.pre_forward(module, *args, **kwargs)
                output = old_forward(*args, **kwargs)
                return hook.post_forward(module, output)

        new_forward = functools.update_wrapper(new_forward, old_forward)
        self._module_ref.forward = new_forward.__get__(self._module_ref)

        self.hooks[name] = hook
        self._hook_order.append(name)

    def get_hook(self, name: str) -> ModelHook:
        if name not in self.hooks.keys():
            raise ValueError(f"Hook with name {name} not found.")
        return self.hooks[name]

    def remove_hook(self, name: str) -> None:
        if name not in self.hooks.keys():
            raise ValueError(f"Hook with name {name} not found.")
        self.hooks[name].deinitialize_hook(self._module_ref)
        del self.hooks[name]
        self._hook_order.remove(name)

    @classmethod
    def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry":
        if not hasattr(module, "_diffusers_hook"):
            module._diffusers_hook = cls(module)
        return module._diffusers_hook

    def __repr__(self) -> str:
        hook_repr = ""
        for i, hook_name in enumerate(self._hook_order):
            hook_repr += f"  ({i}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})"
            if i < len(self._hook_order) - 1:
                hook_repr += "\n"
        return f"HookRegistry(\n{hook_repr}\n)"
```

Review comment on `class HookRegistry`: This looks good 👍🏽
Review comment: I think it might be better to have this as an attribute within each model.
Reply: Actually, I will remove this completely. This should be applicable to any model containing `ModuleList` or `Sequential`, because we know for sure, at least in Diffusers, that the call order of these layers is sequential and not some unusual access pattern. So the check will just look for those two classes with `isinstance`.
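A minimal sketch of the `isinstance`-based check described in the reply above (an illustration of the suggestion, not code from this PR; the helper name is hypothetical):

```python
import torch


def _gather_block_stacks(module: torch.nn.Module):
    # Instead of matching attribute names against _COMMON_STACK_IDENTIFIERS, collect every
    # direct child that is a ModuleList or Sequential; in Diffusers these are invoked
    # sequentially, so consecutive blocks can be grouped for offloading.
    stacks = []
    for name, child in module.named_children():
        if isinstance(child, (torch.nn.ModuleList, torch.nn.Sequential)):
            stacks.append((name, child))
    return stacks
```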