Merged
100 changes: 94 additions & 6 deletions src/diffusers/hooks/group_offloading.py
@@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from contextlib import contextmanager, nullcontext
from typing import Dict, List, Optional, Set, Tuple, Union

import safetensors.torch
import torch

from ..utils import get_logger, is_accelerate_available
@@ -59,6 +61,8 @@ def __init__(
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
onload_self: bool = True,
offload_to_disk: bool = False,
offload_path: Optional[str] = None,
) -> None:
self.modules = modules
self.offload_device = offload_device
@@ -72,7 +76,29 @@ def __init__(
self.record_stream = record_stream
self.onload_self = onload_self
self.low_cpu_mem_usage = low_cpu_mem_usage
self.cpu_param_dict = self._init_cpu_param_dict()

self.offload_to_disk = offload_to_disk
self.offload_path = offload_path
self._is_offloaded_to_disk = False

if self.offload_to_disk:
if self.offload_path is None:
raise ValueError("`offload_path` must be set when `offload_to_disk=True`.")
self.safetensors_file_path = os.path.join(self.offload_path, f"group_{id(self)}.safetensors")

all_tensors = []
for module in self.modules:
all_tensors.extend(list(module.parameters()))
all_tensors.extend(list(module.buffers()))
all_tensors.extend(self.parameters)
all_tensors.extend(self.buffers)
all_tensors = list(dict.fromkeys(all_tensors)) # Remove duplicates

self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)}
self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()}
self.cpu_param_dict = {}
else:
self.cpu_param_dict = self._init_cpu_param_dict()

if self.stream is None and self.record_stream:
raise ValueError("`record_stream` cannot be True when `stream` is None.")
@@ -124,6 +150,30 @@ def onload_(self):
context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
current_stream = torch_accelerator_module.current_stream() if self.record_stream else None

if self.offload_to_disk:
if self.stream is not None:
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()

with context:
if self.stream is not None:
# Load to CPU, pin, and async copy to device for overlapping transfer and compute
loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
for key, tensor_obj in self.key_to_tensor.items():
pinned_tensor = loaded_cpu_tensors[key].pin_memory()
tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
tensor_obj.data.record_stream(current_stream)
Comment on lines +156 to +162

Contributor:

I think a cleaner approach would be to provide a callable to map_location (assuming we were using torch.load instead of safetensors), which could pin each tensor and move it to the device. Do we know if there is an equivalent to passing a callable with safetensors? If not, this is okay too.

Member Author:

Do we know if there would be other alternatives to this code path? If not, I think it's better as is. From skimming through the documentation of safetensors, I couldn't find any equivalent of map_location.
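
For illustration, a rough sketch of the reviewer's map_location idea follows; it is hedged, not part of the diff: it assumes the group had been serialized with torch.save rather than safetensors, and that the storage objects passed to the callable expose pin_memory() and cuda(), as current PyTorch storages do. The file path is hypothetical.

import torch

def pin_and_move(storage, location):
    # Pin the freshly deserialized CPU storage, then copy it to the accelerator;
    # the pinned staging buffer lets the copy be issued asynchronously.
    return storage.pin_memory().cuda(non_blocking=True)

# map_location may be a callable that is invoked once per storage in the checkpoint.
state_dict = torch.load("group_0.pt", map_location=pin_and_move)  # hypothetical path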

else:
# Load directly to the target device (synchronous)
onload_device = (
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
)
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
for key, tensor_obj in self.key_to_tensor.items():
tensor_obj.data = loaded_tensors[key]
return

if self.stream is not None:
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()
@@ -169,6 +219,18 @@ def onload_(self):
@torch.compiler.disable()
def offload_(self):
r"""Offloads the group of modules to the offload_device."""
if self.offload_to_disk:
Contributor:

(nit): we probably need to refactor this a bit and break it into smaller methods so we don't have to branch and do early returns every time a new feature is added (we can do the refactor once we have everything working, so it's not urgent).

Member Author:

Indeed. Can I do it in an immediate follow-up PR so that it's easier to review?

if not self._is_offloaded_to_disk:
os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
tensors_to_save = {
key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
}
safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
self._is_offloaded_to_disk = True

for tensor_obj in self.tensor_to_key.keys():
tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
Contributor:

What's the reason for this to be different from the non-disk-offload counterpart? That is, is there a reason we're not doing buffer.data.to(self.offload_device, non_blocking=self.non_blocking)?

Member Author:

So, we first free up the accelerator's memory with:

key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()

However, since we're also optimizing for RAM usage (this can be made clearer through documentation, I believe), we need to free up the RAM that is holding the tensor data. After the data has been safely written from RAM to disk, this step replaces the large data tensor in RAM with a memory-less placeholder, which allows that memory to be released.
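
In isolation, the save-then-release pattern described above might look like the following minimal sketch; the tensor names and file path are made up, and the snippet is not part of the diff.

import safetensors.torch
import torch

weights = {"weight_0": torch.randn(4096, 4096), "weight_1": torch.randn(4096)}

# 1. Persist the CPU copies to disk.
safetensors.torch.save_file(weights, "/tmp/group_example.safetensors")

# 2. Swap each large tensor for an uninitialized placeholder with the same shape
#    and dtype. The original buffers are no longer referenced, so the RAM they
#    occupied can be reclaimed, while the metadata needed for a later reload remains.
for name in weights:
    weights[name] = torch.empty_like(weights[name])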

return

torch_accelerator_module = (
getattr(torch, torch.accelerator.current_accelerator().type)
@@ -205,13 +267,12 @@ class GroupOffloadingHook(ModelHook):

_is_stateful = False

def __init__(
self,
group: ModuleGroup,
next_group: Optional[ModuleGroup] = None,
) -> None:
def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None) -> None:
self.group = group
self.next_group = next_group
# map param/buffer name -> file path
self.param_to_path: Dict[str, str] = {}
self.buffer_to_path: Dict[str, str] = {}

def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
if self.group.offload_leader == module:
Expand Down Expand Up @@ -358,6 +419,8 @@ def apply_group_offloading(
onload_device: torch.device,
offload_device: torch.device = torch.device("cpu"),
offload_type: str = "block_level",
offload_to_disk: bool = False,
offload_path: Optional[str] = None,
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
use_stream: bool = False,
@@ -401,6 +464,11 @@ def apply_group_offloading(
offload_type (`str`, defaults to "block_level"):
The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
"block_level".
offload_to_disk (`bool`, defaults to `False`):
If `True`, offload model parameters to disk instead of CPU RAM. This is useful when CPU memory is limited.
Requires `offload_path` to be set.
offload_path (`str`, *optional*):
The path to the directory where offloaded parameters will be stored when `offload_to_disk` is `True`.
num_blocks_per_group (`int`, *optional*):
The number of blocks per group when using offload_type="block_level". This is required when using
offload_type="block_level".
@@ -446,6 +514,8 @@ def apply_group_offloading(
stream = torch.Stream()
else:
raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.")
if offload_to_disk and offload_path is None:
raise ValueError("`offload_path` must be set when `offload_to_disk=True`.")

_raise_error_if_accelerate_model_or_sequential_hook_present(module)

@@ -458,6 +528,8 @@
num_blocks_per_group=num_blocks_per_group,
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk=offload_to_disk,
offload_path=offload_path,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
@@ -468,6 +540,8 @@
module=module,
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk=offload_to_disk,
offload_path=offload_path,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
@@ -481,6 +555,8 @@ def _apply_group_offloading_block_level(
module: torch.nn.Module,
num_blocks_per_group: int,
offload_device: torch.device,
offload_to_disk: bool,
offload_path: Optional[str],
onload_device: torch.device,
non_blocking: bool,
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
@@ -535,6 +611,8 @@ def _apply_group_offloading_block_level(
modules=current_modules,
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk=offload_to_disk,
offload_path=offload_path,
offload_leader=current_modules[-1],
onload_leader=current_modules[0],
non_blocking=non_blocking,
@@ -567,6 +645,8 @@ def _apply_group_offloading_block_level(
modules=unmatched_modules,
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk=offload_to_disk,
offload_path=offload_path,
offload_leader=module,
onload_leader=module,
parameters=parameters,
@@ -586,6 +666,8 @@ def _apply_group_offloading_leaf_level(
module: torch.nn.Module,
offload_device: torch.device,
onload_device: torch.device,
offload_to_disk: bool,
offload_path: Optional[str],
non_blocking: bool,
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
record_stream: Optional[bool] = False,
@@ -629,6 +711,8 @@ def _apply_group_offloading_leaf_level(
modules=[submodule],
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk=offload_to_disk,
offload_path=offload_path,
offload_leader=submodule,
onload_leader=submodule,
non_blocking=non_blocking,
@@ -675,6 +759,8 @@ def _apply_group_offloading_leaf_level(
onload_device=onload_device,
offload_leader=parent_module,
onload_leader=parent_module,
offload_to_disk=offload_to_disk,
offload_path=offload_path,
parameters=parameters,
buffers=buffers,
non_blocking=non_blocking,
@@ -693,6 +779,8 @@ def _apply_group_offloading_leaf_level(
modules=[],
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk=offload_to_disk,
offload_path=offload_path,
offload_leader=module,
onload_leader=module,
parameters=None,
20 changes: 12 additions & 8 deletions src/diffusers/models/modeling_utils.py
@@ -543,6 +545,8 @@ def enable_group_offload(
onload_device: torch.device,
offload_device: torch.device = torch.device("cpu"),
offload_type: str = "block_level",
offload_to_disk: bool = False,
offload_path: Optional[str] = None,
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
use_stream: bool = False,
@@ -588,15 +590,17 @@
f"open an issue at https://github.com/huggingface/diffusers/issues."
)
apply_group_offloading(
self,
onload_device,
offload_device,
offload_type,
num_blocks_per_group,
non_blocking,
use_stream,
record_stream,
module=self,
onload_device=onload_device,
offload_device=offload_device,
offload_type=offload_type,
num_blocks_per_group=num_blocks_per_group,
non_blocking=non_blocking,
use_stream=use_stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
offload_to_disk=offload_to_disk,
offload_path=offload_path,
)

def save_pretrained(
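
Finally, a hedged usage sketch of the updated ModelMixin.enable_group_offload method. A default-config UNet2DModel keeps the example self-contained (in practice the model would come from from_pretrained); the CUDA device and the offload path are assumptions, not part of the diff.

import torch
from diffusers import UNet2DModel

model = UNet2DModel()  # any diffusers ModelMixin instance works here
model.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    offload_to_disk=True,               # new flag added in this PR
    offload_path="/tmp/group_offload",  # required when offload_to_disk=True
)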