138 changes: 84 additions & 54 deletions src/diffusers/hooks/group_offloading.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import nullcontext
from contextlib import contextmanager, nullcontext
from typing import Dict, List, Optional, Set, Tuple

import torch
@@ -56,23 +56,58 @@ def __init__(
buffers: Optional[List[torch.Tensor]] = None,
non_blocking: bool = False,
stream: Optional[torch.cuda.Stream] = None,
cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
low_cpu_mem_usage=False,
onload_self: bool = True,
) -> None:
self.modules = modules
self.offload_device = offload_device
self.onload_device = onload_device
self.offload_leader = offload_leader
self.onload_leader = onload_leader
self.parameters = parameters
self.buffers = buffers
self.parameters = parameters or []
self.buffers = buffers or []
self.non_blocking = non_blocking or stream is not None
self.stream = stream
self.cpu_param_dict = cpu_param_dict
self.onload_self = onload_self
self.low_cpu_mem_usage = low_cpu_mem_usage

if self.stream is not None and self.cpu_param_dict is None:
raise ValueError("cpu_param_dict must be provided when using stream for data transfer.")
self.cpu_param_dict = self._init_cpu_param_dict()

def _init_cpu_param_dict(self):
cpu_param_dict = {}
if self.stream is None:
return cpu_param_dict

for module in self.modules:
for param in module.parameters():
cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
for buffer in module.buffers():
cpu_param_dict[buffer] = (
buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
)

for param in self.parameters:
cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()

for buffer in self.buffers:
cpu_param_dict[buffer] = buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()

return cpu_param_dict

@contextmanager
def _pinned_memory_tensors(self):
pinned_dict = {}
try:
for param, tensor in self.cpu_param_dict.items():
if not tensor.is_pinned():
pinned_dict[param] = tensor.pin_memory()
else:
pinned_dict[param] = tensor

yield pinned_dict

finally:
pinned_dict = None

def onload_(self):
r"""Onloads the group of modules to the onload_device."""
@@ -82,15 +117,30 @@ def onload_(self):
self.stream.synchronize()

with context:
for group_module in self.modules:
for param in group_module.parameters():
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
for buffer in group_module.buffers():
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
if self.parameters is not None:
if self.stream is not None:
with self._pinned_memory_tensors() as pinned_memory:
for group_module in self.modules:
for param in group_module.parameters():
param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
for buffer in group_module.buffers():
buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)

for param in self.parameters:
param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)

for buffer in self.buffers:
buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)

else:
for group_module in self.modules:
for param in group_module.parameters():
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
for buffer in group_module.buffers():
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)

for param in self.parameters:
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
if self.buffers is not None:

for buffer in self.buffers:
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)

Expand All @@ -101,21 +151,18 @@ def offload_(self):
for group_module in self.modules:
for param in group_module.parameters():
param.data = self.cpu_param_dict[param]
if self.parameters is not None:
for param in self.parameters:
param.data = self.cpu_param_dict[param]
if self.buffers is not None:
for buffer in self.buffers:
buffer.data = self.cpu_param_dict[buffer]
for param in self.parameters:
param.data = self.cpu_param_dict[param]
for buffer in self.buffers:
buffer.data = self.cpu_param_dict[buffer]

else:
for group_module in self.modules:
group_module.to(self.offload_device, non_blocking=self.non_blocking)
if self.parameters is not None:
for param in self.parameters:
param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
if self.buffers is not None:
for buffer in self.buffers:
buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
for param in self.parameters:
param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
for buffer in self.buffers:
buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)


class GroupOffloadingHook(ModelHook):
@@ -284,6 +331,7 @@ def apply_group_offloading(
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
use_stream: bool = False,
low_cpu_mem_usage=False,
) -> None:
r"""
Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -365,10 +413,12 @@
raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")

_apply_group_offloading_block_level(
module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream
module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
)
elif offload_type == "leaf_level":
_apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream)
_apply_group_offloading_leaf_level(
module, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
)
else:
raise ValueError(f"Unsupported offload_type: {offload_type}")

@@ -380,6 +430,7 @@ def _apply_group_offloading_block_level(
onload_device: torch.device,
non_blocking: bool,
stream: Optional[torch.cuda.Stream] = None,
low_cpu_mem_usage: bool = False,
) -> None:
r"""
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
@@ -400,11 +451,6 @@
for overlapping computation and data transfer.
"""

# Create a pinned CPU parameter dict for async data transfer if streams are to be used
cpu_param_dict = None
if stream is not None:
cpu_param_dict = _get_pinned_cpu_param_dict(module)

# Create module groups for ModuleList and Sequential blocks
modules_with_group_offloading = set()
unmatched_modules = []
@@ -425,7 +471,7 @@
onload_leader=current_modules[0],
non_blocking=non_blocking,
stream=stream,
cpu_param_dict=cpu_param_dict,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=stream is None,
)
matched_module_groups.append(group)
@@ -462,7 +508,6 @@ def _apply_group_offloading_block_level(
buffers=buffers,
non_blocking=False,
stream=None,
cpu_param_dict=None,
onload_self=True,
)
next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
@@ -475,6 +520,7 @@ def _apply_group_offloading_leaf_level(
onload_device: torch.device,
non_blocking: bool,
stream: Optional[torch.cuda.Stream] = None,
low_cpu_mem_usage: bool = False,
) -> None:
r"""
This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
@@ -497,11 +543,6 @@
for overlapping computation and data transfer.
"""

# Create a pinned CPU parameter dict for async data transfer if streams are to be used
cpu_param_dict = None
if stream is not None:
cpu_param_dict = _get_pinned_cpu_param_dict(module)

# Create module groups for leaf modules and apply group offloading hooks
modules_with_group_offloading = set()
for name, submodule in module.named_modules():
@@ -515,7 +556,7 @@
onload_leader=submodule,
non_blocking=non_blocking,
stream=stream,
cpu_param_dict=cpu_param_dict,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=True,
)
_apply_group_offloading_hook(submodule, group, None)
@@ -560,7 +601,7 @@ def _apply_group_offloading_leaf_level(
buffers=buffers,
non_blocking=non_blocking,
stream=stream,
cpu_param_dict=cpu_param_dict,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=True,
)
_apply_group_offloading_hook(parent_module, group, None)
@@ -579,7 +620,7 @@
buffers=None,
non_blocking=False,
stream=None,
cpu_param_dict=None,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=True,
)
_apply_lazy_group_offloading_hook(module, unmatched_group, None)
@@ -616,17 +657,6 @@ def _apply_lazy_group_offloading_hook(
registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)


def _get_pinned_cpu_param_dict(module: torch.nn.Module) -> Dict[torch.nn.Parameter, torch.Tensor]:
cpu_param_dict = {}
for param in module.parameters():
param.data = param.data.cpu().pin_memory()
cpu_param_dict[param] = param.data
for buffer in module.buffers():
buffer.data = buffer.data.cpu().pin_memory()
cpu_param_dict[buffer] = buffer.data
return cpu_param_dict


def _gather_parameters_with_no_group_offloading_parent(
module: torch.nn.Module, modules_with_group_offloading: Set[str]
) -> List[torch.nn.Parameter]:
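Taken together, the hook-side changes replace the module-wide `_get_pinned_cpu_param_dict` helper with per-group CPU copies built in `ModuleGroup._init_cpu_param_dict`, gated by the new `low_cpu_mem_usage` flag: when it is set, CPU copies stay unpinned and are only pinned on demand inside the `_pinned_memory_tensors` context manager while a group is onloaded. A minimal usage sketch of the functional entry point (the checkpoint and model class are illustrative, not part of this diff):

```python
import torch

from diffusers import AutoencoderKL
from diffusers.hooks.group_offloading import apply_group_offloading

# Any torch.nn.Module-based diffusers model works the same way; this checkpoint is illustrative.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)

# Stream-based prefetching with the new flag: CPU copies are kept unpinned and only
# pinned temporarily during onload, trading some transfer speed for lower host memory.
apply_group_offloading(
    vae,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
    low_cpu_mem_usage=True,
)
```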
10 changes: 9 additions & 1 deletion src/diffusers/models/modeling_utils.py
@@ -546,6 +546,7 @@ def enable_group_offload(
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
use_stream: bool = False,
low_cpu_mem_usage=False,
) -> None:
r"""
Activates group offloading for the current model.
@@ -584,7 +585,14 @@
f"open an issue at https://github.com/huggingface/diffusers/issues."
)
apply_group_offloading(
self, onload_device, offload_device, offload_type, num_blocks_per_group, non_blocking, use_stream
self,
onload_device,
offload_device,
offload_type,
num_blocks_per_group,
non_blocking,
use_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
)

def save_pretrained(
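On the modeling side, the flag is simply threaded through `ModelMixin.enable_group_offload` into `apply_group_offloading`. A sketch of the user-facing call (checkpoint and group size are illustrative):

```python
import torch

from diffusers import UNet2DConditionModel

# Illustrative checkpoint; any ModelMixin subclass exposes enable_group_offload.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
)

unet.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,   # illustrative group size
    use_stream=True,
    low_cpu_mem_usage=True,   # new flag introduced in this PR
)
```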
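For reference, the pin-on-demand pattern that `_pinned_memory_tensors` relies on can be shown in isolation. This is a simplified sketch under the assumption of a CUDA device, not the PR's exact code: the CPU copy stays unpinned at rest and is pinned only for the duration of the asynchronous host-to-device copy, which is the trade-off the `low_cpu_mem_usage` path makes.

```python
from contextlib import contextmanager

import torch


@contextmanager
def pinned_view(cpu_tensor: torch.Tensor):
    # Pin only for the duration of the transfer; the temporary pinned copy is dropped afterwards.
    pinned = cpu_tensor if cpu_tensor.is_pinned() else cpu_tensor.pin_memory()
    try:
        yield pinned
    finally:
        del pinned


cpu_weight = torch.randn(1024, 1024)  # unpinned CPU copy: low resident host memory
stream = torch.cuda.Stream()
with torch.cuda.stream(stream), pinned_view(cpu_weight) as src:
    gpu_weight = src.to("cuda", non_blocking=True)  # async H2D copy from pinned memory
stream.synchronize()
```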