Merged
Commits
50 commits
d1737e3
update
a-r-r-o-w Jan 9, 2025
2783669
fix
a-r-r-o-w Jan 9, 2025
6a9a3e5
non_blocking; handle parameters and buffers
a-r-r-o-w Jan 10, 2025
c426a34
update
a-r-r-o-w Jan 10, 2025
d579037
Group offloading with cuda stream prefetching (#10516)
a-r-r-o-w Jan 11, 2025
5f33621
Merge branch 'main' into groupwise-offloading
a-r-r-o-w Jan 11, 2025
a8eabd0
update
a-r-r-o-w Jan 12, 2025
deda9a3
Merge branch 'main' into groupwise-offloading
a-r-r-o-w Jan 16, 2025
80ac5a7
copy model hook implementation from pab
a-r-r-o-w Jan 16, 2025
d2a2981
update; ~very workaround based implementation but it seems to work as…
a-r-r-o-w Jan 16, 2025
01c7d22
more workarounds to make it actually work
a-r-r-o-w Jan 16, 2025
22aff34
cleanup
a-r-r-o-w Jan 16, 2025
42bc19b
rewrite
a-r-r-o-w Jan 17, 2025
8c63bf5
update
a-r-r-o-w Jan 19, 2025
e09e716
make sure to sync current stream before overwriting with pinned params
a-r-r-o-w Jan 19, 2025
bf379c1
Merge branch 'main' into groupwise-offloading
a-r-r-o-w Jan 19, 2025
0bf0baf
better check
a-r-r-o-w Jan 19, 2025
b850c75
update
a-r-r-o-w Jan 20, 2025
6ed9c2f
remove hook implementation to not deal with merge conflict
a-r-r-o-w Jan 23, 2025
13dd337
Merge branch 'main' into groupwise-offloading
a-r-r-o-w Jan 23, 2025
073d4bc
re-add hook changes
a-r-r-o-w Jan 23, 2025
8ba2bda
why use more memory when less memory do trick
a-r-r-o-w Jan 23, 2025
b2e838f
why still use slightly more memory when less memory do trick
a-r-r-o-w Jan 23, 2025
f30c55f
Merge branch 'main' into groupwise-offloading
a-r-r-o-w Jan 23, 2025
5ea3d8a
optimise
a-r-r-o-w Jan 26, 2025
db2fd3b
add model tests
a-r-r-o-w Jan 26, 2025
a0160e1
add pipeline tests
a-r-r-o-w Jan 26, 2025
aaa9a53
update docs
a-r-r-o-w Jan 26, 2025
17b2753
Merge branch 'main' into groupwise-offloading
a-r-r-o-w Jan 26, 2025
edf8103
add layernorm and groupnorm
a-r-r-o-w Jan 26, 2025
af62c93
Merge branch 'main' into groupwise-offloading
a-r-r-o-w Jan 28, 2025
f227e15
Merge branch 'main' into groupwise-offloading
a-r-r-o-w Feb 4, 2025
24f9273
address review comments
a-r-r-o-w Feb 4, 2025
8f10d05
improve tests; add docs
a-r-r-o-w Feb 4, 2025
06b411f
improve docs
a-r-r-o-w Feb 4, 2025
8bd7e3b
Merge branch 'main' into groupwise-offloading
a-r-r-o-w Feb 4, 2025
904e470
Apply suggestions from code review
a-r-r-o-w Feb 5, 2025
3172ed5
apply suggestions from code review
a-r-r-o-w Feb 5, 2025
72aa57f
Merge branch 'main' into groupwise-offloading
a-r-r-o-w Feb 5, 2025
aee24bc
update tests
a-r-r-o-w Feb 5, 2025
db125ce
Merge branch 'main' into groupwise-offloading
a-r-r-o-w Feb 6, 2025
3f20e6b
apply suggestions from review
a-r-r-o-w Feb 6, 2025
840576a
enable_group_offloading -> enable_group_offload for naming consistency
a-r-r-o-w Feb 6, 2025
8804d74
raise errors if multiple offloading strategies used; add relevant tests
a-r-r-o-w Feb 6, 2025
954bb7d
handle .to() when group offload applied
a-r-r-o-w Feb 6, 2025
ba6c4a8
Merge branch 'main' into groupwise-offloading
a-r-r-o-w Feb 6, 2025
da88c33
refactor some repeated code
a-r-r-o-w Feb 6, 2025
a872e84
remove unintentional change from merge conflict
a-r-r-o-w Feb 6, 2025
6be43b8
handle .cuda()
a-r-r-o-w Feb 6, 2025
274b84e
Merge branch 'main' into groupwise-offloading
a-r-r-o-w Feb 11, 2025
5 changes: 5 additions & 0 deletions src/diffusers/hooks/__init__.py
@@ -0,0 +1,5 @@
from ..utils import is_torch_available


if is_torch_available():
    from .group_offloading import apply_group_offloading
218 changes: 218 additions & 0 deletions src/diffusers/hooks/group_offloading.py
@@ -0,0 +1,218 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import List, Optional, Union

import torch

from .hooks import HookRegistry, ModelHook


_COMMON_STACK_IDENTIFIERS = {
Collaborator: I think it might be better to have this as an attribute within each model.

Contributor Author: Actually, I will remove this completely. This should be applicable to any model containing ModuleList or Sequential, because we know for sure, at least in Diffusers, that the call order of these layers is sequential and not some weird access pattern. So the check will just look for those two classes with isinstance (a sketch of this check follows the identifier set below).

"transformer_blocks",
"single_transformer_blocks",
"temporal_transformer_blocks",
"transformer_layers",
"layers",
"blocks",
"down_blocks",
"up_blocks",
"mid_blocks",
}
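Following up on the reply above, a minimal sketch of the proposed isinstance-based detection; the helper name `_find_block_stacks` is hypothetical and not part of this PR:

```python
import torch


def _find_block_stacks(module: torch.nn.Module):
    # Hypothetical sketch: collect every ModuleList/Sequential child, which in
    # Diffusers models is known to be called sequentially in the forward pass.
    stacks = []
    for name, submodule in module.named_modules():
        if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
            stacks.append((name, submodule))
    return stacks
```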


class ModuleGroup:
    def __init__(
        self,
        modules: List[torch.nn.Module],
        offload_device: torch.device,
        onload_device: torch.device,
        offload_leader: torch.nn.Module,
        onload_leader: Optional[torch.nn.Module] = None,
    ) -> None:
        self.modules = modules
        self.offload_device = offload_device
        self.onload_device = onload_device
        self.offload_leader = offload_leader
        self.onload_leader = onload_leader


class GroupOffloadingHook(ModelHook):
    r"""
    A hook that offloads groups of torch.nn.Module to the CPU for storage and onloads them to the accelerator
    device for computation. Each group has one "onload leader" module that is responsible for onloading, and an
    "offload leader" module that is responsible for offloading.

    This implementation assumes the following:
    - For offload_group_patterns="diffusers_block", the leader of a group can be determined automatically. For a
      custom user-provided regex pattern, the module that triggers its forward pass first is considered the leader.
    - The inputs are already on the correct device. This is expected because the hook does not modify the state of
      inputs or outputs at any stage of the forward pass. If an error is raised because the devices of modules and
      inputs do not match during the forward pass of any model in Diffusers, it means the forward pass of the model
      is not written in the expected manner. Please open an issue at https://github.com/huggingface/diffusers/issues
      if you encounter such an error.
    """

    def __init__(self, group: ModuleGroup, offload_on_init: bool = True, non_blocking: bool = False) -> None:
        self.group = group
        self.offload_on_init = offload_on_init
        self.non_blocking = non_blocking

    def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        if self.offload_on_init:
            self.offload_(module)
        return module

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
        if self.group.onload_leader is None:
            self.group.onload_leader = module
        self.onload_(module)
        return args, kwargs

    def post_forward(self, module: torch.nn.Module, output):
        self.offload_(module)
        return output

    def onload_(self, module: torch.nn.Module) -> None:
        if self.group.onload_leader == module:
            for group_module in self.group.modules:
                group_module.to(self.group.onload_device, non_blocking=self.non_blocking)

    def offload_(self, module: torch.nn.Module) -> None:
        if self.group.offload_leader == module:
            for group_module in self.group.modules:
                group_module.to(self.group.offload_device, non_blocking=self.non_blocking)
            # TODO: do we need to sync here because of GPU->CPU transfer?
            if self.non_blocking and self.group.offload_device.type == "cpu":
                torch.cpu.synchronize()
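Regarding the TODO above: a `non_blocking=True` device-to-CPU copy returns before the data has actually landed in host memory, so the destination must not be read or overwritten until the issuing stream has been synchronized. The snippet below is a standalone illustration of that hazard and the usual fix; it is not part of the hook and assumes a CUDA device is available.

```python
import torch

if torch.cuda.is_available():
    gpu_tensor = torch.randn(1024, 1024, device="cuda")
    # Pinned host memory is needed for the copy to be truly asynchronous.
    cpu_tensor = torch.empty(gpu_tensor.shape, dtype=gpu_tensor.dtype, pin_memory=True)
    cpu_tensor.copy_(gpu_tensor, non_blocking=True)
    # Without this synchronization, reading cpu_tensor here could observe stale values.
    torch.cuda.synchronize()
    assert torch.equal(cpu_tensor, gpu_tensor.cpu())
```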


def apply_group_offloading(
    module: torch.nn.Module,
    offload_group_patterns: Union[str, List[str]] = "diffusers_block",
    num_blocks_per_group: Optional[int] = None,
    offload_device: torch.device = torch.device("cpu"),
    onload_device: torch.device = torch.device("cuda"),
    force_offload: bool = True,
    non_blocking: bool = False,
) -> None:
    if offload_group_patterns == "diffusers_block":
        if num_blocks_per_group is None:
            raise ValueError('num_blocks_per_group must be provided when using offload_group_patterns="diffusers_block".')
        _apply_group_offloading_diffusers_block(
            module, num_blocks_per_group, offload_device, onload_device, force_offload, non_blocking
        )
    else:
        _apply_group_offloading_group_patterns(
            module, offload_group_patterns, offload_device, onload_device, force_offload, non_blocking
        )
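A minimal usage sketch for the entry point above. The checkpoint and parameter values are illustrative assumptions; the only requirement for this code path is that the model exposes one of the recognized block stacks (here, a `transformer_blocks` `ModuleList`):

```python
import torch

from diffusers import CogVideoXTransformer3DModel
from diffusers.hooks import apply_group_offloading

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Offload groups of 4 transformer blocks to CPU, onloading each group to CUDA
# right before its forward pass runs.
apply_group_offloading(
    transformer,
    offload_group_patterns="diffusers_block",
    num_blocks_per_group=4,
    offload_device=torch.device("cpu"),
    onload_device=torch.device("cuda"),
    non_blocking=True,
)
```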


def _apply_group_offloading_diffusers_block(
    module: torch.nn.Module,
    num_blocks_per_group: int,
    offload_device: torch.device,
    onload_device: torch.device,
    force_offload: bool,
    non_blocking: bool,
) -> None:
    # Handle device offloading/onloading for unet/transformer stack modules
    for stack_identifier in _COMMON_STACK_IDENTIFIERS:
        if not hasattr(module, stack_identifier) or not isinstance(
            getattr(module, stack_identifier), torch.nn.ModuleList
        ):
            continue

        stack = getattr(module, stack_identifier)
        num_blocks = len(stack)

        for i in range(0, num_blocks, num_blocks_per_group):
            blocks = stack[i : i + num_blocks_per_group]
            group = ModuleGroup(
                blocks, offload_device, onload_device, offload_leader=blocks[-1], onload_leader=blocks[0]
            )
            should_offload = force_offload or i > 0
            _apply_group_offloading(group, should_offload, non_blocking)

    # Handle device offloading/onloading for non-stack modules
    for name, submodule in module.named_modules():
        name_split = name.split(".")
        if not isinstance(submodule, torch.nn.Module) or name == "" or len(name_split) > 1:
            # We only want the layers that are top-level in the module (encompass all the submodules)
            # for enabling offloading.
            continue
        layer_name = name_split[0]
        if layer_name in _COMMON_STACK_IDENTIFIERS:
            continue
        group = ModuleGroup(
            [submodule], offload_device, onload_device, offload_leader=submodule, onload_leader=submodule
        )
        _apply_group_offloading(group, force_offload, non_blocking)

    # Always keep parameters and buffers on onload_device
    for name, param in module.named_parameters(recurse=False):
        if torch.is_tensor(param.data):
            param.data = param.data.to(onload_device)
    for name, buffer in module.named_buffers(recurse=False):
        if torch.is_tensor(buffer.data):
            buffer.data = buffer.data.to(onload_device)
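To make the grouping above concrete: the blocks of a stack are chunked into contiguous groups of `num_blocks_per_group`, each group offloads when its last block finishes and onloads when its first block is about to run, and with `force_offload=False` the first group (i == 0) is not offloaded at initialization, so the very first forward pass pays no onload cost. A quick, purely illustrative check of the chunking:

```python
# Illustrative only: how block indices are split into groups.
num_blocks = 30
num_blocks_per_group = 4
groups = [
    list(range(i, min(i + num_blocks_per_group, num_blocks)))
    for i in range(0, num_blocks, num_blocks_per_group)
]
print(groups[0])    # [0, 1, 2, 3]
print(groups[-1])   # [28, 29]
print(len(groups))  # 8
```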


def _apply_group_offloading_group_patterns(
Collaborator: I think these can be consolidated into a single function and use the offload_group_pattern. If we add something like a _group_offload_modules to the Model class, we can just extend it with the offload_group_patterns argument here.

    module: torch.nn.Module,
    offload_group_patterns: List[str],
    offload_device: torch.device,
    onload_device: torch.device,
    force_offload: bool,
    non_blocking: bool,
) -> None:
    per_group_modules = []
    for i, offload_group_pattern in enumerate(offload_group_patterns):
        group_modules = []
        group_module_names = []
        for name, submodule in module.named_modules():
            if re.search(offload_group_pattern, name) is not None:
                group_modules.append(submodule)
                group_module_names.append(name)
        per_group_modules.append(
            {
                "modules": group_modules,
                "module_names": group_module_names,
            }
        )

    # Check if there are any overlapping modules between groups
    for i, group in enumerate(per_group_modules):
        for j, other_group in enumerate(per_group_modules):
            if j <= i:
                continue
            if any(module_name in group["module_names"] for module_name in other_group["module_names"]):
                raise ValueError(
                    f"Overlapping modules between groups {i} and {j}. Please ensure that offloading group patterns are mutually exclusive."
                )

    # Apply offloading to each group
    for group in per_group_modules:
        # TODO: handle offload leader correctly
        group = ModuleGroup(group["modules"], offload_device, onload_device, offload_leader=group["modules"][-1])
        _apply_group_offloading(group, force_offload, non_blocking)
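For the custom-pattern path, the caller provides regex patterns that must partition the module names into non-overlapping groups (otherwise the check above raises). A hypothetical pattern list for a model whose blocks are named `transformer_blocks.0` through `transformer_blocks.29`:

```python
# Hypothetical patterns grouping blocks 0-9, 10-19, and 20-29 by module name.
offload_group_patterns = [
    r"transformer_blocks\.[0-9]$",
    r"transformer_blocks\.1[0-9]$",
    r"transformer_blocks\.2[0-9]$",
]

# Assuming `transformer` is a loaded model with these block names:
# apply_group_offloading(transformer, offload_group_patterns=offload_group_patterns)
```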


def _apply_group_offloading(group: ModuleGroup, offload_on_init: bool, non_blocking: bool) -> None:
    for module in group.modules:
        hook = GroupOffloadingHook(group, offload_on_init, non_blocking)
        registry = HookRegistry.check_if_exists_or_initialize(module)
        registry.register_hook(hook, "group_offloading")
164 changes: 164 additions & 0 deletions src/diffusers/hooks/hooks.py
@@ -0,0 +1,164 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from typing import Any, Dict, Tuple

import torch

from ..utils.logging import get_logger


logger = get_logger(__name__) # pylint: disable=invalid-name


class ModelHook:
    r"""
    A hook that contains callbacks to be executed just before and after the forward method of a model.
    """

    _is_stateful = False

    def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        r"""
        Hook that is executed when a model is initialized.

        Args:
            module (`torch.nn.Module`):
                The module attached to this hook.
        """
        return module

    def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        r"""
        Hook that is executed when a model is deinitialized.

        Args:
            module (`torch.nn.Module`):
                The module attached to this hook.
        """
        module.forward = module._old_forward
        del module._old_forward
        return module

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
        r"""
        Hook that is executed just before the forward method of the model.

        Args:
            module (`torch.nn.Module`):
                The module whose forward pass will be executed just after this event.
            args (`Tuple[Any]`):
                The positional arguments passed to the module.
            kwargs (`Dict[str, Any]`):
                The keyword arguments passed to the module.

        Returns:
            `Tuple[Tuple[Any], Dict[str, Any]]`:
                A tuple with the treated `args` and `kwargs`.
        """
        return args, kwargs

    def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
        r"""
        Hook that is executed just after the forward method of the model.

        Args:
            module (`torch.nn.Module`):
                The module whose forward pass has been executed just before this event.
            output (`Any`):
                The output of the module.

        Returns:
            `Any`: The processed `output`.
        """
        return output

    def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        r"""
        Hook that is executed when the hook is detached from a module.

        Args:
            module (`torch.nn.Module`):
                The module detached from this hook.
        """
        return module

    def reset_state(self, module: torch.nn.Module):
        if self._is_stateful:
            raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
        return module
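To illustrate the callback contract above, here is a minimal stateless hook that only logs tensor shapes around the wrapped forward call. `ShapeLoggingHook` is a hypothetical example, not part of this PR, and the import path assumes this PR's file layout:

```python
import torch

from diffusers.hooks.hooks import ModelHook


class ShapeLoggingHook(ModelHook):
    # Hypothetical example hook: print input/output shapes around forward.
    def pre_forward(self, module, *args, **kwargs):
        shapes = [tuple(a.shape) for a in args if torch.is_tensor(a)]
        print(f"{module.__class__.__name__} inputs: {shapes}")
        return args, kwargs

    def post_forward(self, module, output):
        if torch.is_tensor(output):
            print(f"{module.__class__.__name__} output: {tuple(output.shape)}")
        return output
```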


class HookRegistry:
Collaborator: This looks good 👍🏽

    def __init__(self, module_ref: torch.nn.Module) -> None:
        super().__init__()

        self.hooks: Dict[str, ModelHook] = {}

        self._module_ref = module_ref
        self._hook_order = []

    def register_hook(self, hook: ModelHook, name: str) -> None:
        if name in self.hooks.keys():
            logger.warning(f"Hook with name {name} already exists, replacing it.")

        if hasattr(self._module_ref, "_old_forward"):
            old_forward = self._module_ref._old_forward
        else:
            old_forward = self._module_ref.forward
            self._module_ref._old_forward = self._module_ref.forward

        self._module_ref = hook.initialize_hook(self._module_ref)

        if hasattr(hook, "new_forward"):
            new_forward = hook.new_forward
        else:

            def new_forward(module, *args, **kwargs):
                args, kwargs = hook.pre_forward(module, *args, **kwargs)
                output = old_forward(*args, **kwargs)
                return hook.post_forward(module, output)

        new_forward = functools.update_wrapper(new_forward, old_forward)
        self._module_ref.forward = new_forward.__get__(self._module_ref)

        self.hooks[name] = hook
        self._hook_order.append(name)

    def get_hook(self, name: str) -> ModelHook:
        if name not in self.hooks.keys():
            raise ValueError(f"Hook with name {name} not found.")
        return self.hooks[name]

    def remove_hook(self, name: str) -> None:
        if name not in self.hooks.keys():
            raise ValueError(f"Hook with name {name} not found.")
        self.hooks[name].deinitalize_hook(self._module_ref)
        del self.hooks[name]
        self._hook_order.remove(name)

    @classmethod
    def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry":
        if not hasattr(module, "_diffusers_hook"):
            module._diffusers_hook = cls(module)
        return module._diffusers_hook

    def __repr__(self) -> str:
        hook_repr = ""
        for i, hook_name in enumerate(self._hook_order):
            hook_repr += f" ({i}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})"
            if i < len(self._hook_order) - 1:
                hook_repr += "\n"
        return f"HookRegistry(\n{hook_repr}\n)"