66 changes: 62 additions & 4 deletions src/diffusers/hooks/group_offloading.py
@@ -56,6 +56,7 @@ def __init__(
buffers: Optional[List[torch.Tensor]] = None,
non_blocking: bool = False,
stream: Optional[torch.cuda.Stream] = None,
record_stream: Optional[bool] = False,
low_cpu_mem_usage=False,
onload_self: bool = True,
) -> None:
@@ -68,11 +69,14 @@ def __init__(
self.buffers = buffers or []
self.non_blocking = non_blocking or stream is not None
self.stream = stream
self.record_stream = record_stream
self.onload_self = onload_self
self.low_cpu_mem_usage = low_cpu_mem_usage

self.cpu_param_dict = self._init_cpu_param_dict()

if self.stream is None and self.record_stream:
raise ValueError("`record_stream` cannot be True when `stream` is None.")

def _init_cpu_param_dict(self):
cpu_param_dict = {}
if self.stream is None:
@@ -112,6 +116,8 @@ def _pinned_memory_tensors(self):
def onload_(self):
r"""Onloads the group of modules to the onload_device."""
context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
current_stream = torch.cuda.current_stream() if self.record_stream else None

if self.stream is not None:
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()
@@ -122,14 +128,22 @@ def onload_(self):
for group_module in self.modules:
for param in group_module.parameters():
param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
param.data.record_stream(current_stream)
for buffer in group_module.buffers():
buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
buffer.data.record_stream(current_stream)

for param in self.parameters:
param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
param.data.record_stream(current_stream)

for buffer in self.buffers:
buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
buffer.data.record_stream(current_stream)

else:
for group_module in self.modules:
@@ -143,11 +157,14 @@ def onload_(self):

for buffer in self.buffers:
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
buffer.data.record_stream(current_stream)

def offload_(self):
r"""Offloads the group of modules to the offload_device."""
if self.stream is not None:
torch.cuda.current_stream().synchronize()
if not self.record_stream:
torch.cuda.current_stream().synchronize()
for group_module in self.modules:
for param in group_module.parameters():
param.data = self.cpu_param_dict[param]
@@ -331,6 +348,7 @@ def apply_group_offloading(
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
use_stream: bool = False,
record_stream: bool = False,
low_cpu_mem_usage: bool = False,
) -> None:
r"""
Expand Down Expand Up @@ -378,6 +396,10 @@ def apply_group_offloading(
use_stream (`bool`, defaults to `False`):
If True, offloading and onloading are done asynchronously using a CUDA stream. This can be useful for
overlapping computation and data transfer.
record_stream (`bool`, defaults to `False`):
When enabled with `use_stream`, it marks the onloaded tensors as in use by the current CUDA stream, which
lets the offload step skip an explicit stream synchronization. It is faster at the expense of slightly more
memory usage. Refer to the
[PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) for more details.
low_cpu_mem_usage (`bool`, defaults to `False`):
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
@@ -417,11 +439,24 @@ def apply_group_offloading(
raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")

_apply_group_offloading_block_level(
module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
module=module,
num_blocks_per_group=num_blocks_per_group,
offload_device=offload_device,
onload_device=onload_device,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
)
elif offload_type == "leaf_level":
_apply_group_offloading_leaf_level(
module, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
module=module,
offload_device=offload_device,
onload_device=onload_device,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
)
else:
raise ValueError(f"Unsupported offload_type: {offload_type}")
@@ -434,6 +469,7 @@ def _apply_group_offloading_block_level(
onload_device: torch.device,
non_blocking: bool,
stream: Optional[torch.cuda.Stream] = None,
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
) -> None:
r"""
@@ -453,6 +489,14 @@ def _apply_group_offloading_block_level(
stream (`torch.cuda.Stream`, *optional*):
If provided, offloading and onloading are done asynchronously using the provided stream. This can be useful
for overlapping computation and data transfer.
record_stream (`bool`, defaults to `False`):
When enabled together with `stream`, it marks the onloaded tensors as in use by the current CUDA stream,
which lets the offload step skip an explicit stream synchronization. It is faster at the expense of slightly
more memory usage. Refer to the
[PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) for more details.
low_cpu_mem_usage (`bool`, defaults to `False`):
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
the CPU memory is a bottleneck but may counteract the benefits of using streams.
"""

# Create module groups for ModuleList and Sequential blocks
@@ -475,6 +519,7 @@ def _apply_group_offloading_block_level(
onload_leader=current_modules[0],
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=stream is None,
)
@@ -512,6 +557,7 @@ def _apply_group_offloading_block_level(
buffers=buffers,
non_blocking=False,
stream=None,
record_stream=False,
onload_self=True,
)
next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
@@ -524,6 +570,7 @@ def _apply_group_offloading_leaf_level(
onload_device: torch.device,
non_blocking: bool,
stream: Optional[torch.cuda.Stream] = None,
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
) -> None:
r"""
@@ -545,6 +592,14 @@ def _apply_group_offloading_leaf_level(
stream (`torch.cuda.Stream`, *optional*):
If provided, offloading and onloading are done asynchronously using the provided stream. This can be useful
for overlapping computation and data transfer.
record_stream (`bool`, defaults to `False`):
When enabled together with `stream`, it marks the onloaded tensors as in use by the current CUDA stream,
which lets the offload step skip an explicit stream synchronization. It is faster at the expense of slightly
more memory usage. Refer to the
[PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) for more details.
low_cpu_mem_usage (`bool`, defaults to `False`):
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
the CPU memory is a bottleneck but may counteract the benefits of using streams.
"""

# Create module groups for leaf modules and apply group offloading hooks
@@ -560,6 +615,7 @@ def _apply_group_offloading_leaf_level(
onload_leader=submodule,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=True,
)
@@ -605,6 +661,7 @@ def _apply_group_offloading_leaf_level(
buffers=buffers,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=True,
)
@@ -624,6 +681,7 @@ def _apply_group_offloading_leaf_level(
buffers=None,
non_blocking=False,
stream=None,
record_stream=False,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=True,
)
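The hook changes above follow the standard `record_stream` pattern for copies issued on a side stream: instead of synchronizing the compute stream before releasing memory in `offload_`, each onloaded tensor is marked as in use by the compute stream so the caching allocator defers reuse of its storage. Below is a minimal standalone sketch of that pattern, with made-up tensor names and sizes that are not part of diffusers:

```python
import torch

# Illustrative sketch of the record_stream pattern used by the hook above.
# Assumes a CUDA device is available; the tensor and stream names are made up.
if torch.cuda.is_available():
    transfer_stream = torch.cuda.Stream()
    compute_stream = torch.cuda.current_stream()

    # Pinned host memory, analogous to the hook's pinned-memory dict.
    pinned = torch.randn(1024, 1024, pin_memory=True)

    with torch.cuda.stream(transfer_stream):
        # Async host->device copy issued on the transfer stream; the GPU tensor
        # is allocated under that stream's allocator context.
        onloaded = pinned.to("cuda", non_blocking=True)
        # Mark the tensor as used by the compute stream so its memory is not
        # reclaimed early -- this is what replaces the explicit
        # current_stream().synchronize() in offload_().
        onloaded.record_stream(compute_stream)

    # The compute stream must still wait for the copy itself before reading.
    compute_stream.wait_stream(transfer_stream)
    result = onloaded.sum()
```

The trade-off matches the docstring: memory backing the onloaded tensors can stay reserved slightly longer, but the per-offload synchronization of the current stream is skipped.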
2 changes: 2 additions & 0 deletions src/diffusers/models/modeling_utils.py
@@ -546,6 +546,7 @@ def enable_group_offload(
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
use_stream: bool = False,
record_stream: bool = False,
low_cpu_mem_usage=False,
) -> None:
r"""
@@ -594,6 +595,7 @@ def enable_group_offload(
num_blocks_per_group,
non_blocking,
use_stream,
record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
)

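On the user-facing side, the flag is simply forwarded from `enable_group_offload` to `apply_group_offloading`, as the test below exercises. A hedged usage sketch follows; the model class, checkpoint, and input shape are placeholders chosen for illustration, not taken from this diff:

```python
import torch
from diffusers import AutoencoderKL  # placeholder ModelMixin subclass (assumption)

# Illustrative only: stream-based group offloading with record_stream enabled.
model = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")  # example checkpoint (assumption)
model.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,     # record_stream requires use_stream=True
    record_stream=True,  # flag added in this PR
)

with torch.no_grad():
    sample = torch.randn(1, 3, 64, 64).to("cuda")  # dummy input (assumption)
    out = model(sample).sample
```

Passing `record_stream=True` with `use_stream=False` raises the `ValueError` added in the hook's `__init__`, since there is no transfer stream to record against.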
7 changes: 5 additions & 2 deletions tests/models/test_modeling_common.py
@@ -1525,8 +1525,9 @@ def get_memory_usage(storage_dtype, compute_dtype):
or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE
)

@parameterized.expand([False, True])
@require_torch_gpu
def test_group_offloading(self):
def test_group_offloading(self, record_stream):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)

@@ -1566,7 +1567,9 @@ def run_forward(model):

torch.manual_seed(0)
model = self.model_class(**init_dict)
model.enable_group_offload(torch_device, offload_type="leaf_level", use_stream=True)
model.enable_group_offload(
torch_device, offload_type="leaf_level", use_stream=True, record_stream=record_stream
)
output_with_group_offloading4 = run_forward(model)

self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5))