From 2210d285f12afff98970835619af6125d99715d5 Mon Sep 17 00:00:00 2001
From: Dhruv Nair
Date: Tue, 18 Mar 2025 13:27:09 +0100
Subject: [PATCH 1/3] update

---
 src/diffusers/hooks/group_offloading.py | 125 ++++++++++++++----------
 src/diffusers/models/modeling_utils.py  |  10 +-
 2 files changed, 84 insertions(+), 51 deletions(-)

diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index e4b9ed9307ea..9f7d30d3f595 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from contextlib import nullcontext
+from contextlib import contextmanager, nullcontext
 from typing import Dict, List, Optional, Set, Tuple
 
 import torch
@@ -56,7 +56,7 @@ def __init__(
         buffers: Optional[List[torch.Tensor]] = None,
         non_blocking: bool = False,
         stream: Optional[torch.cuda.Stream] = None,
-        cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
+        low_cpu_mem_usage=False,
         onload_self: bool = True,
     ) -> None:
         self.modules = modules
@@ -64,15 +64,42 @@ def __init__(
         self.onload_device = onload_device
         self.offload_leader = offload_leader
         self.onload_leader = onload_leader
-        self.parameters = parameters
-        self.buffers = buffers
+        self.parameters = parameters or []
+        self.buffers = buffers or []
         self.non_blocking = non_blocking or stream is not None
         self.stream = stream
-        self.cpu_param_dict = cpu_param_dict
         self.onload_self = onload_self
+        self.low_cpu_mem_usage = low_cpu_mem_usage
 
-        if self.stream is not None and self.cpu_param_dict is None:
-            raise ValueError("cpu_param_dict must be provided when using stream for data transfer.")
+        self.cpu_param_dict = {}
+        for module in self.modules:
+            for param in module.parameters():
+                self.cpu_param_dict[param] = (
+                    param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
+                )
+
+        for param in self.parameters:
+            self.cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
+
+        for buffer in self.buffers:
+            self.cpu_param_dict[buffer] = (
+                buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
+            )
+
+    @contextmanager
+    def _pinned_memory_tensors(self):
+        pinned_dict = {}
+        try:
+            for param, tensor in self.cpu_param_dict.items():
+                if not tensor.is_pinned():
+                    pinned_dict[param] = tensor.pin_memory()
+                else:
+                    pinned_dict[param] = tensor
+
+            yield pinned_dict
+
+        finally:
+            pinned_dict = None
 
     def onload_(self):
         r"""Onloads the group of modules to the onload_device."""
@@ -82,17 +109,32 @@ def onload_(self):
             self.stream.synchronize()
 
         with context:
-            for group_module in self.modules:
-                for param in group_module.parameters():
-                    param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-                for buffer in group_module.buffers():
-                    buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
-            if self.parameters is not None:
-                for param in self.parameters:
-                    param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-            if self.buffers is not None:
-                for buffer in self.buffers:
-                    buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
+            if self.stream is not None:
+                with self._pinned_memory_tensors() as pinned_memory:
+                    for group_module in self.modules:
+                        for param in group_module.parameters():
+                            param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
+
+                    if self.parameters is not None:
+                        for param in self.parameters:
+                            param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
+
+                    if self.buffers is not None:
+                        for buffer in self.buffers:
+                            buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
+
+            else:
+                for group_module in self.modules:
+                    for param in group_module.parameters():
+                        param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+
+                if self.parameters is not None:
+                    for param in self.parameters:
+                        param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+
+                if self.buffers is not None:
+                    for buffer in self.buffers:
+                        buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
 
     def offload_(self):
         r"""Offloads the group of modules to the offload_device."""
@@ -108,12 +150,12 @@ def offload_(self):
                 for buffer in self.buffers:
                     buffer.data = self.cpu_param_dict[buffer]
         else:
-            for group_module in self.modules:
-                group_module.to(self.offload_device, non_blocking=self.non_blocking)
-            if self.parameters is not None:
+            for module in self.modules:
+                module.to(self.offload_device, non_blocking=self.non_blocking)
+            if self.parameters:
                 for param in self.parameters:
                     param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
-            if self.buffers is not None:
+            if self.buffers:
                 for buffer in self.buffers:
                     buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
 
@@ -284,6 +326,7 @@ def apply_group_offloading(
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
     use_stream: bool = False,
+    low_cpu_mem_usage=False,
 ) -> None:
     r"""
     Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -365,10 +408,12 @@
             raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
 
         _apply_group_offloading_block_level(
-            module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream
+            module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
         )
     elif offload_type == "leaf_level":
-        _apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream)
+        _apply_group_offloading_leaf_level(
+            module, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
+        )
     else:
         raise ValueError(f"Unsupported offload_type: {offload_type}")
 
@@ -380,6 +425,7 @@ def _apply_group_offloading_block_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
+    low_cpu_mem_usage: bool = False,
 ) -> None:
     r"""
     This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
@@ -400,11 +446,6 @@
         for overlapping computation and data transfer.
""" - # Create a pinned CPU parameter dict for async data transfer if streams are to be used - cpu_param_dict = None - if stream is not None: - cpu_param_dict = _get_pinned_cpu_param_dict(module) - # Create module groups for ModuleList and Sequential blocks modules_with_group_offloading = set() unmatched_modules = [] @@ -425,7 +466,7 @@ def _apply_group_offloading_block_level( onload_leader=current_modules[0], non_blocking=non_blocking, stream=stream, - cpu_param_dict=cpu_param_dict, + low_cpu_mem_usage=low_cpu_mem_usage, onload_self=stream is None, ) matched_module_groups.append(group) @@ -462,7 +503,6 @@ def _apply_group_offloading_block_level( buffers=buffers, non_blocking=False, stream=None, - cpu_param_dict=None, onload_self=True, ) next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None @@ -475,6 +515,7 @@ def _apply_group_offloading_leaf_level( onload_device: torch.device, non_blocking: bool, stream: Optional[torch.cuda.Stream] = None, + low_cpu_mem_usage: bool = False, ) -> None: r""" This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory @@ -497,11 +538,6 @@ def _apply_group_offloading_leaf_level( for overlapping computation and data transfer. """ - # Create a pinned CPU parameter dict for async data transfer if streams are to be used - cpu_param_dict = None - if stream is not None: - cpu_param_dict = _get_pinned_cpu_param_dict(module) - # Create module groups for leaf modules and apply group offloading hooks modules_with_group_offloading = set() for name, submodule in module.named_modules(): @@ -515,7 +551,7 @@ def _apply_group_offloading_leaf_level( onload_leader=submodule, non_blocking=non_blocking, stream=stream, - cpu_param_dict=cpu_param_dict, + low_cpu_mem_usage=low_cpu_mem_usage, onload_self=True, ) _apply_group_offloading_hook(submodule, group, None) @@ -560,7 +596,7 @@ def _apply_group_offloading_leaf_level( buffers=buffers, non_blocking=non_blocking, stream=stream, - cpu_param_dict=cpu_param_dict, + low_cpu_mem_usage=low_cpu_mem_usage, onload_self=True, ) _apply_group_offloading_hook(parent_module, group, None) @@ -579,7 +615,7 @@ def _apply_group_offloading_leaf_level( buffers=None, non_blocking=False, stream=None, - cpu_param_dict=None, + low_cpu_mem_usage=low_cpu_mem_usage, onload_self=True, ) _apply_lazy_group_offloading_hook(module, unmatched_group, None) @@ -616,17 +652,6 @@ def _apply_lazy_group_offloading_hook( registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING) -def _get_pinned_cpu_param_dict(module: torch.nn.Module) -> Dict[torch.nn.Parameter, torch.Tensor]: - cpu_param_dict = {} - for param in module.parameters(): - param.data = param.data.cpu().pin_memory() - cpu_param_dict[param] = param.data - for buffer in module.buffers(): - buffer.data = buffer.data.cpu().pin_memory() - cpu_param_dict[buffer] = buffer.data - return cpu_param_dict - - def _gather_parameters_with_no_group_offloading_parent( module: torch.nn.Module, modules_with_group_offloading: Set[str] ) -> List[torch.nn.Parameter]: diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 6983940f139b..351ce7b1772c 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -546,6 +546,7 @@ def enable_group_offload( num_blocks_per_group: Optional[int] = None, non_blocking: bool = False, use_stream: bool = False, + low_cpu_mem_usage=False, ) -> None: r""" Activates group offloading for the current model. 
@@ -584,7 +585,14 @@ def enable_group_offload(
             f"open an issue at https://github.com/huggingface/diffusers/issues."
         )
         apply_group_offloading(
-            self, onload_device, offload_device, offload_type, num_blocks_per_group, non_blocking, use_stream
+            self,
+            onload_device,
+            offload_device,
+            offload_type,
+            num_blocks_per_group,
+            non_blocking,
+            use_stream,
+            low_cpu_mem_usage=low_cpu_mem_usage,
         )
 
     def save_pretrained(

From aead2c50507c0322856a359a2040132fe7b13c89 Mon Sep 17 00:00:00 2001
From: Dhruv Nair
Date: Tue, 18 Mar 2025 17:27:07 +0100
Subject: [PATCH 2/3] update

---
 src/diffusers/hooks/group_offloading.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index 9f7d30d3f595..7f414d5f691b 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -77,6 +77,10 @@ def __init__(
                 self.cpu_param_dict[param] = (
                     param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
                 )
+            for buffer in module.buffers():
+                self.cpu_param_dict[buffer] = (
+                    buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
+                )
 
         for param in self.parameters:
             self.cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
@@ -127,6 +131,8 @@ def onload_(self):
                 for group_module in self.modules:
                     for param in group_module.parameters():
                         param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+                    for param in group_module.buffers():
+                        param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
 
                 if self.parameters is not None:
                     for param in self.parameters:

From 310bd1050849db583034e719efaeb447d3e61bc1 Mon Sep 17 00:00:00 2001
From: Dhruv Nair
Date: Tue, 18 Mar 2025 17:55:40 +0100
Subject: [PATCH 3/3] clean up

---
 src/diffusers/hooks/group_offloading.py | 73 ++++++++++++-------------
 1 file changed, 36 insertions(+), 37 deletions(-)

diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index 7f414d5f691b..11e2db78723a 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -71,24 +71,28 @@ def __init__(
         self.onload_self = onload_self
         self.low_cpu_mem_usage = low_cpu_mem_usage
 
-        self.cpu_param_dict = {}
+        self.cpu_param_dict = self._init_cpu_param_dict()
+
+    def _init_cpu_param_dict(self):
+        cpu_param_dict = {}
+        if self.stream is None:
+            return cpu_param_dict
+
         for module in self.modules:
             for param in module.parameters():
-                self.cpu_param_dict[param] = (
-                    param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
-                )
+                cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
             for buffer in module.buffers():
-                self.cpu_param_dict[buffer] = (
+                cpu_param_dict[buffer] = (
                     buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
                 )
 
         for param in self.parameters:
-            self.cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
+            cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
 
         for buffer in self.buffers:
-            self.cpu_param_dict[buffer] = (
-                buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
-            )
+            cpu_param_dict[buffer] = buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
+
+        return cpu_param_dict
 
     @contextmanager
     def _pinned_memory_tensors(self):
@@ -118,29 +122,27 @@ def onload_(self):
                     for group_module in self.modules:
                         for param in group_module.parameters():
                             param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
+                        for buffer in group_module.buffers():
+                            buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
 
-                    if self.parameters is not None:
-                        for param in self.parameters:
-                            param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
+                    for param in self.parameters:
+                        param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
 
-                    if self.buffers is not None:
-                        for buffer in self.buffers:
-                            buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
+                    for buffer in self.buffers:
+                        buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
 
             else:
                 for group_module in self.modules:
                     for param in group_module.parameters():
                         param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-                    for param in group_module.buffers():
-                        param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+                    for buffer in group_module.buffers():
+                        buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
 
-                if self.parameters is not None:
-                    for param in self.parameters:
-                        param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+                for param in self.parameters:
+                    param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
 
-                if self.buffers is not None:
-                    for buffer in self.buffers:
-                        buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
+                for buffer in self.buffers:
+                    buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
 
     def offload_(self):
         r"""Offloads the group of modules to the offload_device."""
@@ -149,21 +151,18 @@ def offload_(self):
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]
-            if self.parameters is not None:
-                for param in self.parameters:
-                    param.data = self.cpu_param_dict[param]
-            if self.buffers is not None:
-                for buffer in self.buffers:
-                    buffer.data = self.cpu_param_dict[buffer]
+            for param in self.parameters:
+                param.data = self.cpu_param_dict[param]
+            for buffer in self.buffers:
+                buffer.data = self.cpu_param_dict[buffer]
+
         else:
-            for module in self.modules:
-                module.to(self.offload_device, non_blocking=self.non_blocking)
-            if self.parameters:
-                for param in self.parameters:
-                    param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
-            if self.buffers:
-                for buffer in self.buffers:
-                    buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
+            for group_module in self.modules:
+                group_module.to(self.offload_device, non_blocking=self.non_blocking)
+            for param in self.parameters:
+                param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
+            for buffer in self.buffers:
+                buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
 
 
 class GroupOffloadingHook(ModelHook):
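
---

Note for reviewers: a minimal usage sketch of the flag this series adds. The
model class and checkpoint below are illustrative placeholders, not part of
the diffs; any model inheriting ModelMixin picks up the new argument through
enable_group_offload, and the keyword names are taken from the signatures
touched above.

    import torch
    from diffusers import CogVideoXTransformer3DModel

    # Placeholder checkpoint; any diffusers model works the same way.
    transformer = CogVideoXTransformer3DModel.from_pretrained(
        "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
    )

    # use_stream=True onloads groups asynchronously on a CUDA stream. Before
    # this series, that path pinned every CPU copy up front; with
    # low_cpu_mem_usage=True the CPU copies stay unpinned and are pinned on
    # the fly inside _pinned_memory_tensors() during each onload, trading
    # some transfer speed for a smaller resident CPU memory footprint.
    transformer.enable_group_offload(
        onload_device=torch.device("cuda"),
        offload_device=torch.device("cpu"),
        offload_type="leaf_level",
        use_stream=True,
        low_cpu_mem_usage=True,
    )

The tradeoff follows directly from _init_cpu_param_dict: with the flag off,
CPU copies are created once with .pin_memory() at setup time (faster repeated
host-to-device copies, but the pinned copies stay resident for the lifetime of
the hook); with it on, pinning is deferred to the _pinned_memory_tensors
context used during stream onload and released afterwards.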