diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 45fee35ef336..ac25dd061b39 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -14,6 +14,8 @@ import os from contextlib import contextmanager, nullcontext +from dataclasses import dataclass +from enum import Enum from typing import Dict, List, Optional, Set, Tuple, Union import safetensors.torch @@ -46,6 +48,24 @@ # fmt: on +class GroupOffloadingType(str, Enum): + BLOCK_LEVEL = "block_level" + LEAF_LEVEL = "leaf_level" + + +@dataclass +class GroupOffloadingConfig: + onload_device: torch.device + offload_device: torch.device + offload_type: GroupOffloadingType + non_blocking: bool + record_stream: bool + low_cpu_mem_usage: bool + num_blocks_per_group: Optional[int] = None + offload_to_disk_path: Optional[str] = None + stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None + + class ModuleGroup: def __init__( self, @@ -288,9 +308,12 @@ class GroupOffloadingHook(ModelHook): _is_stateful = False - def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None) -> None: + def __init__( + self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, *, config: GroupOffloadingConfig + ) -> None: self.group = group self.next_group = next_group + self.config = config def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: if self.group.offload_leader == module: @@ -436,7 +459,7 @@ def apply_group_offloading( module: torch.nn.Module, onload_device: torch.device, offload_device: torch.device = torch.device("cpu"), - offload_type: str = "block_level", + offload_type: Union[str, GroupOffloadingType] = "block_level", num_blocks_per_group: Optional[int] = None, non_blocking: bool = False, use_stream: bool = False, @@ -478,7 +501,7 @@ def apply_group_offloading( The device to which the group of modules are onloaded. offload_device (`torch.device`, defaults to `torch.device("cpu")`): The device to which the group of modules are offloaded. This should typically be the CPU. Default is CPU. - offload_type (`str`, defaults to "block_level"): + offload_type (`str` or `GroupOffloadingType`, defaults to "block_level"): The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is "block_level". 
offload_to_disk_path (`str`, *optional*, defaults to `None`): @@ -521,6 +544,8 @@ def apply_group_offloading( ``` """ + offload_type = GroupOffloadingType(offload_type) + stream = None if use_stream: if torch.cuda.is_available(): @@ -532,84 +557,45 @@ def apply_group_offloading( if not use_stream and record_stream: raise ValueError("`record_stream` cannot be True when `use_stream=False`.") + if offload_type == GroupOffloadingType.BLOCK_LEVEL and num_blocks_per_group is None: + raise ValueError("`num_blocks_per_group` must be provided when using `offload_type='block_level'.") _raise_error_if_accelerate_model_or_sequential_hook_present(module) - if offload_type == "block_level": - if num_blocks_per_group is None: - raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.") - - _apply_group_offloading_block_level( - module=module, - num_blocks_per_group=num_blocks_per_group, - offload_device=offload_device, - onload_device=onload_device, - offload_to_disk_path=offload_to_disk_path, - non_blocking=non_blocking, - stream=stream, - record_stream=record_stream, - low_cpu_mem_usage=low_cpu_mem_usage, - ) - elif offload_type == "leaf_level": - _apply_group_offloading_leaf_level( - module=module, - offload_device=offload_device, - onload_device=onload_device, - offload_to_disk_path=offload_to_disk_path, - non_blocking=non_blocking, - stream=stream, - record_stream=record_stream, - low_cpu_mem_usage=low_cpu_mem_usage, - ) + config = GroupOffloadingConfig( + onload_device=onload_device, + offload_device=offload_device, + offload_type=offload_type, + num_blocks_per_group=num_blocks_per_group, + non_blocking=non_blocking, + stream=stream, + record_stream=record_stream, + low_cpu_mem_usage=low_cpu_mem_usage, + offload_to_disk_path=offload_to_disk_path, + ) + _apply_group_offloading(module, config) + + +def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConfig) -> None: + if config.offload_type == GroupOffloadingType.BLOCK_LEVEL: + _apply_group_offloading_block_level(module, config) + elif config.offload_type == GroupOffloadingType.LEAF_LEVEL: + _apply_group_offloading_leaf_level(module, config) else: - raise ValueError(f"Unsupported offload_type: {offload_type}") + assert False -def _apply_group_offloading_block_level( - module: torch.nn.Module, - num_blocks_per_group: int, - offload_device: torch.device, - onload_device: torch.device, - non_blocking: bool, - stream: Union[torch.cuda.Stream, torch.Stream, None] = None, - record_stream: Optional[bool] = False, - low_cpu_mem_usage: bool = False, - offload_to_disk_path: Optional[str] = None, -) -> None: +def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None: r""" This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks. - - Args: - module (`torch.nn.Module`): - The module to which group offloading is applied. - offload_device (`torch.device`): - The device to which the group of modules are offloaded. This should typically be the CPU. - offload_to_disk_path (`str`, *optional*, defaults to `None`): - The path to the directory where parameters will be offloaded. Setting this option can be useful in limited - RAM environment settings where a reasonable speed-memory trade-off is desired. - onload_device (`torch.device`): - The device to which the group of modules are onloaded. 
- non_blocking (`bool`): - If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation - and data transfer. - stream (`torch.cuda.Stream`or `torch.Stream`, *optional*): - If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful - for overlapping computation and data transfer. - record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor - as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the - [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more - details. - low_cpu_mem_usage (`bool`, defaults to `False`): - If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This - option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when - the CPU memory is a bottleneck but may counteract the benefits of using streams. """ - if stream is not None and num_blocks_per_group != 1: + + if config.stream is not None and config.num_blocks_per_group != 1: logger.warning( - f"Using streams is only supported for num_blocks_per_group=1. Got {num_blocks_per_group=}. Setting it to 1." + f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1." ) - num_blocks_per_group = 1 + config.num_blocks_per_group = 1 # Create module groups for ModuleList and Sequential blocks modules_with_group_offloading = set() @@ -621,19 +607,19 @@ def _apply_group_offloading_block_level( modules_with_group_offloading.add(name) continue - for i in range(0, len(submodule), num_blocks_per_group): - current_modules = submodule[i : i + num_blocks_per_group] + for i in range(0, len(submodule), config.num_blocks_per_group): + current_modules = submodule[i : i + config.num_blocks_per_group] group = ModuleGroup( modules=current_modules, - offload_device=offload_device, - onload_device=onload_device, - offload_to_disk_path=offload_to_disk_path, + offload_device=config.offload_device, + onload_device=config.onload_device, + offload_to_disk_path=config.offload_to_disk_path, offload_leader=current_modules[-1], onload_leader=current_modules[0], - non_blocking=non_blocking, - stream=stream, - record_stream=record_stream, - low_cpu_mem_usage=low_cpu_mem_usage, + non_blocking=config.non_blocking, + stream=config.stream, + record_stream=config.record_stream, + low_cpu_mem_usage=config.low_cpu_mem_usage, onload_self=True, ) matched_module_groups.append(group) @@ -643,7 +629,7 @@ def _apply_group_offloading_block_level( # Apply group offloading hooks to the module groups for i, group in enumerate(matched_module_groups): for group_module in group.modules: - _apply_group_offloading_hook(group_module, group, None) + _apply_group_offloading_hook(group_module, group, None, config=config) # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately # when the forward pass of this module is called. 
This is because the top-level module is not @@ -658,9 +644,9 @@ def _apply_group_offloading_block_level( unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules] unmatched_group = ModuleGroup( modules=unmatched_modules, - offload_device=offload_device, - onload_device=onload_device, - offload_to_disk_path=offload_to_disk_path, + offload_device=config.offload_device, + onload_device=config.onload_device, + offload_to_disk_path=config.offload_to_disk_path, offload_leader=module, onload_leader=module, parameters=parameters, @@ -670,54 +656,19 @@ def _apply_group_offloading_block_level( record_stream=False, onload_self=True, ) - if stream is None: - _apply_group_offloading_hook(module, unmatched_group, None) + if config.stream is None: + _apply_group_offloading_hook(module, unmatched_group, None, config=config) else: - _apply_lazy_group_offloading_hook(module, unmatched_group, None) + _apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config) -def _apply_group_offloading_leaf_level( - module: torch.nn.Module, - offload_device: torch.device, - onload_device: torch.device, - non_blocking: bool, - stream: Union[torch.cuda.Stream, torch.Stream, None] = None, - record_stream: Optional[bool] = False, - low_cpu_mem_usage: bool = False, - offload_to_disk_path: Optional[str] = None, -) -> None: +def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None: r""" This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory requirements. However, it can be slower compared to other offloading methods due to the excessive number of device synchronizations. When using devices that support streams to overlap data transfer and computation, this method can reduce memory usage without any performance degradation. - - Args: - module (`torch.nn.Module`): - The module to which group offloading is applied. - offload_device (`torch.device`): - The device to which the group of modules are offloaded. This should typically be the CPU. - onload_device (`torch.device`): - The device to which the group of modules are onloaded. - offload_to_disk_path (`str`, *optional*, defaults to `None`): - The path to the directory where parameters will be offloaded. Setting this option can be useful in limited - RAM environment settings where a reasonable speed-memory trade-off is desired. - non_blocking (`bool`): - If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation - and data transfer. - stream (`torch.cuda.Stream` or `torch.Stream`, *optional*): - If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful - for overlapping computation and data transfer. - record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor - as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the - [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more - details. - low_cpu_mem_usage (`bool`, defaults to `False`): - If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This - option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when - the CPU memory is a bottleneck but may counteract the benefits of using streams. 
""" - # Create module groups for leaf modules and apply group offloading hooks modules_with_group_offloading = set() for name, submodule in module.named_modules(): @@ -725,18 +676,18 @@ def _apply_group_offloading_leaf_level( continue group = ModuleGroup( modules=[submodule], - offload_device=offload_device, - onload_device=onload_device, - offload_to_disk_path=offload_to_disk_path, + offload_device=config.offload_device, + onload_device=config.onload_device, + offload_to_disk_path=config.offload_to_disk_path, offload_leader=submodule, onload_leader=submodule, - non_blocking=non_blocking, - stream=stream, - record_stream=record_stream, - low_cpu_mem_usage=low_cpu_mem_usage, + non_blocking=config.non_blocking, + stream=config.stream, + record_stream=config.record_stream, + low_cpu_mem_usage=config.low_cpu_mem_usage, onload_self=True, ) - _apply_group_offloading_hook(submodule, group, None) + _apply_group_offloading_hook(submodule, group, None, config=config) modules_with_group_offloading.add(name) # Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass @@ -767,33 +718,32 @@ def _apply_group_offloading_leaf_level( parameters = parent_to_parameters.get(name, []) buffers = parent_to_buffers.get(name, []) parent_module = module_dict[name] - assert getattr(parent_module, "_diffusers_hook", None) is None group = ModuleGroup( modules=[], - offload_device=offload_device, - onload_device=onload_device, + offload_device=config.offload_device, + onload_device=config.onload_device, offload_leader=parent_module, onload_leader=parent_module, - offload_to_disk_path=offload_to_disk_path, + offload_to_disk_path=config.offload_to_disk_path, parameters=parameters, buffers=buffers, - non_blocking=non_blocking, - stream=stream, - record_stream=record_stream, - low_cpu_mem_usage=low_cpu_mem_usage, + non_blocking=config.non_blocking, + stream=config.stream, + record_stream=config.record_stream, + low_cpu_mem_usage=config.low_cpu_mem_usage, onload_self=True, ) - _apply_group_offloading_hook(parent_module, group, None) + _apply_group_offloading_hook(parent_module, group, None, config=config) - if stream is not None: + if config.stream is not None: # When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer # and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that will find the # execution order and apply prefetching in the correct order. 
unmatched_group = ModuleGroup( modules=[], - offload_device=offload_device, - onload_device=onload_device, - offload_to_disk_path=offload_to_disk_path, + offload_device=config.offload_device, + onload_device=config.onload_device, + offload_to_disk_path=config.offload_to_disk_path, offload_leader=module, onload_leader=module, parameters=None, @@ -801,23 +751,25 @@ def _apply_group_offloading_leaf_level( non_blocking=False, stream=None, record_stream=False, - low_cpu_mem_usage=low_cpu_mem_usage, + low_cpu_mem_usage=config.low_cpu_mem_usage, onload_self=True, ) - _apply_lazy_group_offloading_hook(module, unmatched_group, None) + _apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config) def _apply_group_offloading_hook( module: torch.nn.Module, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, + *, + config: GroupOffloadingConfig, ) -> None: registry = HookRegistry.check_if_exists_or_initialize(module) # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent # is the current module. In such cases, we don't want to overwrite the existing group offloading hook. if registry.get_hook(_GROUP_OFFLOADING) is None: - hook = GroupOffloadingHook(group, next_group) + hook = GroupOffloadingHook(group, next_group, config=config) registry.register_hook(hook, _GROUP_OFFLOADING) @@ -825,13 +777,15 @@ def _apply_lazy_group_offloading_hook( module: torch.nn.Module, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, + *, + config: GroupOffloadingConfig, ) -> None: registry = HookRegistry.check_if_exists_or_initialize(module) # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent # is the current module. In such cases, we don't want to overwrite the existing group offloading hook. 
if registry.get_hook(_GROUP_OFFLOADING) is None: - hook = GroupOffloadingHook(group, next_group) + hook = GroupOffloadingHook(group, next_group, config=config) registry.register_hook(hook, _GROUP_OFFLOADING) lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook() @@ -898,15 +852,48 @@ def _raise_error_if_accelerate_model_or_sequential_hook_present(module: torch.nn ) -def _is_group_offload_enabled(module: torch.nn.Module) -> bool: +def _get_top_level_group_offload_hook(module: torch.nn.Module) -> Optional[GroupOffloadingHook]: for submodule in module.modules(): - if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None: - return True - return False + if hasattr(submodule, "_diffusers_hook"): + group_offloading_hook = submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) + if group_offloading_hook is not None: + return group_offloading_hook + return None + + +def _is_group_offload_enabled(module: torch.nn.Module) -> bool: + top_level_group_offload_hook = _get_top_level_group_offload_hook(module) + return top_level_group_offload_hook is not None def _get_group_onload_device(module: torch.nn.Module) -> torch.device: - for submodule in module.modules(): - if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None: - return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device + top_level_group_offload_hook = _get_top_level_group_offload_hook(module) + if top_level_group_offload_hook is not None: + return top_level_group_offload_hook.config.onload_device raise ValueError("Group offloading is not enabled for the provided module.") + + +def _maybe_remove_and_reapply_group_offloading(module: torch.nn.Module) -> None: + r""" + Removes the group offloading hook from the module and re-applies it. This is useful when the module has been + modified in-place and the group offloading hook references-to-tensors needs to be updated. The in-place + modification can happen in a number of ways, for example, fusing QKV or unloading/loading LoRAs on-the-fly. + + In this implementation, we make an assumption that group offloading has only been applied at the top-level module, + and therefore all submodules have the same onload and offload devices. If this assumption is not true, say in the + case where user has applied group offloading at multiple levels, this function will not work as expected. + + There is some performance penalty associated with doing this when non-default streams are used, because we need to + retrace the execution order of the layers with `LazyPrefetchGroupOffloadingHook`. 
+ """ + top_level_group_offload_hook = _get_top_level_group_offload_hook(module) + + if top_level_group_offload_hook is None: + return + + registry = HookRegistry.check_if_exists_or_initialize(module) + registry.remove_hook(_GROUP_OFFLOADING, recurse=True) + registry.remove_hook(_LAYER_EXECUTION_TRACKER, recurse=True) + registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=True) + + _apply_group_offloading(module, top_level_group_offload_hook.config) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index e6941a521d06..562a21dbbb74 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -25,6 +25,7 @@ from huggingface_hub import model_info from huggingface_hub.constants import HF_HUB_OFFLINE +from ..hooks.group_offloading import _is_group_offload_enabled, _maybe_remove_and_reapply_group_offloading from ..models.modeling_utils import ModelMixin, load_state_dict from ..utils import ( USE_PEFT_BACKEND, @@ -391,7 +392,9 @@ def _load_lora_into_text_encoder( adapter_name = get_adapter_name(text_encoder) # if prefix is not None and not state_dict: @@ -433,30 +440,36 @@ def _func_optionally_disable_offloading(_pipeline): Returns: tuple: - A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True. + A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` or `is_group_offload` is True. """ is_model_cpu_offload = False is_sequential_cpu_offload = False + is_group_offload = False if _pipeline is not None and _pipeline.hf_device_map is None: for _, component in _pipeline.components.items(): - if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): - if not is_model_cpu_offload: - is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload) - if not is_sequential_cpu_offload: - is_sequential_cpu_offload = ( - isinstance(component._hf_hook, AlignDevicesHook) - or hasattr(component._hf_hook, "hooks") - and isinstance(component._hf_hook.hooks[0], AlignDevicesHook) - ) + if not isinstance(component, nn.Module): + continue + is_group_offload = is_group_offload or _is_group_offload_enabled(component) + if not hasattr(component, "_hf_hook"): + continue + is_model_cpu_offload = is_model_cpu_offload or isinstance(component._hf_hook, CpuOffload) + is_sequential_cpu_offload = is_sequential_cpu_offload or ( + isinstance(component._hf_hook, AlignDevicesHook) + or hasattr(component._hf_hook, "hooks") + and isinstance(component._hf_hook.hooks[0], AlignDevicesHook) + ) - logger.info( - "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." - ) - if is_sequential_cpu_offload or is_model_cpu_offload: - remove_hook_from_module(component, recurse=is_sequential_cpu_offload) + if is_sequential_cpu_offload or is_model_cpu_offload: + logger.info( + "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." 
+ ) + for _, component in _pipeline.components.items(): + if not isinstance(component, nn.Module) or not hasattr(component, "_hf_hook"): + continue + remove_hook_from_module(component, recurse=is_sequential_cpu_offload) - return (is_model_cpu_offload, is_sequential_cpu_offload) + return (is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload) class LoraBaseMixin: diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 343623071340..3670243de859 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -22,6 +22,7 @@ import safetensors import torch +from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading from ..utils import ( MIN_PEFT_VERSION, USE_PEFT_BACKEND, @@ -256,7 +257,9 @@ def load_lora_adapter( # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks # otherwise loading LoRA weights will lead to an error. - is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline) + is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading( + _pipeline + ) peft_kwargs = {} if is_peft_version(">=", "0.13.1"): peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage @@ -347,6 +350,10 @@ def map_state_dict_for_hotswap(sd): _pipeline.enable_model_cpu_offload() elif is_sequential_cpu_offload: _pipeline.enable_sequential_cpu_offload() + elif is_group_offload: + for component in _pipeline.components.values(): + if isinstance(component, torch.nn.Module): + _maybe_remove_and_reapply_group_offloading(component) # Unsafe code /> if prefix is not None and not state_dict: @@ -687,6 +694,8 @@ def unload_lora(self): if hasattr(self, "peft_config"): del self.peft_config + _maybe_remove_and_reapply_group_offloading(self) + def disable_lora(self): """ Disables the active LoRA layers of the underlying model. diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 68be84119177..c9b6a7d7d862 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -22,6 +22,7 @@ import torch.nn.functional as F from huggingface_hub.utils import validate_hf_hub_args +from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading from ..models.embeddings import ( ImageProjection, IPAdapterFaceIDImageProjection, @@ -203,6 +204,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys()) is_model_cpu_offload = False is_sequential_cpu_offload = False + is_group_offload = False if is_lora: deprecation_message = "Using the `load_attn_procs()` method has been deprecated and will be removed in a future version. Please use `load_lora_adapter()`." @@ -211,7 +213,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict if is_custom_diffusion: attn_processors = self._process_custom_diffusion(state_dict=state_dict) elif is_lora: - is_model_cpu_offload, is_sequential_cpu_offload = self._process_lora( + is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._process_lora( state_dict=state_dict, unet_identifier_key=self.unet_name, network_alphas=network_alphas, @@ -230,7 +232,9 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict # For LoRA, the UNet is already offloaded at this stage as it is handled inside `_process_lora`. 
if is_custom_diffusion and _pipeline is not None: - is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline=_pipeline) + is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading( + _pipeline=_pipeline + ) # only custom diffusion needs to set attn processors self.set_attn_processor(attn_processors) @@ -241,6 +245,10 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict _pipeline.enable_model_cpu_offload() elif is_sequential_cpu_offload: _pipeline.enable_sequential_cpu_offload() + elif is_group_offload: + for component in _pipeline.components.values(): + if isinstance(component, torch.nn.Module): + _maybe_remove_and_reapply_group_offloading(component) # Unsafe code /> def _process_custom_diffusion(self, state_dict): @@ -307,6 +315,7 @@ def _process_lora( is_model_cpu_offload = False is_sequential_cpu_offload = False + is_group_offload = False state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict if len(state_dict_to_be_used) > 0: @@ -356,7 +365,9 @@ def _process_lora( # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks # otherwise loading LoRA weights will lead to an error - is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline) + is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading( + _pipeline + ) peft_kwargs = {} if is_peft_version(">=", "0.13.1"): peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage @@ -389,7 +400,7 @@ def _process_lora( if warn_msg: logger.warning(warn_msg) - return is_model_cpu_offload, is_sequential_cpu_offload + return is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload @classmethod # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading diff --git a/tests/lora/test_lora_layers_cogvideox.py b/tests/lora/test_lora_layers_cogvideox.py index bd7b33445c37..565d6db69727 100644 --- a/tests/lora/test_lora_layers_cogvideox.py +++ b/tests/lora/test_lora_layers_cogvideox.py @@ -16,6 +16,7 @@ import unittest import torch +from parameterized import parameterized from transformers import AutoTokenizer, T5EncoderModel from diffusers import ( @@ -28,6 +29,7 @@ from diffusers.utils.testing_utils import ( floats_tensor, require_peft_backend, + require_torch_accelerator, ) @@ -127,6 +129,13 @@ def test_simple_inference_with_text_denoiser_lora_unfused(self): def test_lora_scale_kwargs_match_fusion(self): super().test_lora_scale_kwargs_match_fusion(expected_atol=9e-3, expected_rtol=9e-3) + @parameterized.expand([("block_level", True), ("leaf_level", False)]) + @require_torch_accelerator + def test_group_offloading_inference_denoiser(self, offload_type, use_stream): + # TODO: We don't run the (leaf_level, True) test here that is enabled for other models. 
+ # The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338 + super()._test_group_offloading_inference_denoiser(offload_type, use_stream) + @unittest.skip("Not supported in CogVideoX.") def test_simple_inference_with_text_denoiser_block_scale(self): pass diff --git a/tests/lora/test_lora_layers_cogview4.py b/tests/lora/test_lora_layers_cogview4.py index 23573bcb214e..b7367d9b0946 100644 --- a/tests/lora/test_lora_layers_cogview4.py +++ b/tests/lora/test_lora_layers_cogview4.py @@ -18,10 +18,17 @@ import numpy as np import torch +from parameterized import parameterized from transformers import AutoTokenizer, GlmModel from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler -from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps, torch_device +from diffusers.utils.testing_utils import ( + floats_tensor, + require_peft_backend, + require_torch_accelerator, + skip_mps, + torch_device, +) sys.path.append(".") @@ -141,6 +148,13 @@ def test_simple_inference_save_pretrained(self): "Loading from saved checkpoints should give same results.", ) + @parameterized.expand([("block_level", True), ("leaf_level", False)]) + @require_torch_accelerator + def test_group_offloading_inference_denoiser(self, offload_type, use_stream): + # TODO: We don't run the (leaf_level, True) test here that is enabled for other models. + # The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338 + super()._test_group_offloading_inference_denoiser(offload_type, use_stream) + @unittest.skip("Not supported in CogView4.") def test_simple_inference_with_text_denoiser_block_scale(self): pass diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 93dc4a2c37e3..acd6f5f34361 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -39,6 +39,7 @@ is_torch_version, require_peft_backend, require_peft_version_greater, + require_torch_accelerator, require_transformers_version_greater, skip_mps, torch_device, @@ -2355,3 +2356,73 @@ def test_inference_load_delete_load_adapters(self): pipe.load_lora_weights(tmpdirname) output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3)) + + def _test_group_offloading_inference_denoiser(self, offload_type, use_stream): + from diffusers.hooks.group_offloading import _get_top_level_group_offload_hook + + onload_device = torch_device + offload_device = torch.device("cpu") + + components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0]) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config) + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + + with tempfile.TemporaryDirectory() as tmpdirname: + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights( + save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + + components, _, _ = self.get_dummy_components(self.scheduler_classes[0]) + pipe = 
self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + check_if_lora_correctly_set(denoiser) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + # Test group offloading with load_lora_weights + denoiser.enable_group_offload( + onload_device=onload_device, + offload_device=offload_device, + offload_type=offload_type, + num_blocks_per_group=1, + use_stream=use_stream, + ) + group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser) + self.assertTrue(group_offload_hook_1 is not None) + output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] + + # Test group offloading after removing the lora + pipe.unload_lora_weights() + group_offload_hook_2 = _get_top_level_group_offload_hook(denoiser) + self.assertTrue(group_offload_hook_2 is not None) + output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] # noqa: F841 + + # Add the lora again and check if group offloading works + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + check_if_lora_correctly_set(denoiser) + group_offload_hook_3 = _get_top_level_group_offload_hook(denoiser) + self.assertTrue(group_offload_hook_3 is not None) + output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0] + + self.assertTrue(np.allclose(output_1, output_3, atol=1e-3, rtol=1e-3)) + + @parameterized.expand([("block_level", True), ("leaf_level", False), ("leaf_level", True)]) + @require_torch_accelerator + def test_group_offloading_inference_denoiser(self, offload_type, use_stream): + for cls in inspect.getmro(self.__class__): + if "test_group_offloading_inference_denoiser" in cls.__dict__ and cls is not PeftLoraLoaderMixinTests: + # Skip this test if it is overwritten by child class. We need to do this because parameterized + # materializes the test methods on invocation which cannot be overridden. + return + self._test_group_offloading_inference_denoiser(offload_type, use_stream)
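
Reviewer note, appended for context: `apply_group_offloading` now normalizes `offload_type` through `GroupOffloadingType(offload_type)` before doing anything else, so invalid values fail fast and the final `assert False` branch in `_apply_group_offloading` is genuinely unreachable. A small illustrative snippet (not part of the patch; the behavior follows from the `str`-backed `Enum` defined above):

```python
from diffusers.hooks.group_offloading import GroupOffloadingType

# Lookup by value returns the enum member, and passing a member through is a no-op.
assert GroupOffloadingType("block_level") is GroupOffloadingType.BLOCK_LEVEL
assert GroupOffloadingType(GroupOffloadingType.LEAF_LEVEL) is GroupOffloadingType.LEAF_LEVEL

# Because the enum subclasses str, existing string comparisons keep working.
assert GroupOffloadingType.LEAF_LEVEL == "leaf_level"

# Unknown values raise before any hooks are installed.
try:
    GroupOffloadingType("blockwise")
except ValueError as exc:
    print(exc)  # e.g. "'blockwise' is not a valid GroupOffloadingType"
```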
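
For a concrete picture of the refactor, here is a minimal end-to-end sketch. The toy module, tensor shapes, and device choice are placeholders, but the functions exercised (`apply_group_offloading`, `_get_top_level_group_offload_hook`, `_maybe_remove_and_reapply_group_offloading`) are the ones added or modified in this patch. The key change is that every `GroupOffloadingHook` now carries the full `GroupOffloadingConfig`, which is what lets the remove-and-reapply step rebuild the hooks without the caller restating any arguments:

```python
import torch

from diffusers.hooks.group_offloading import (
    _get_top_level_group_offload_hook,
    _maybe_remove_and_reapply_group_offloading,
    apply_group_offloading,
)


class TinyModel(torch.nn.Module):
    # Stand-in for a diffusers transformer: block_level offloading groups the
    # children of this ModuleList.
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(4)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


onload_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = TinyModel()
apply_group_offloading(
    model,
    onload_device=onload_device,
    offload_device=torch.device("cpu"),
    offload_type="block_level",  # a plain string or GroupOffloadingType.BLOCK_LEVEL both work now
    num_blocks_per_group=2,
)

# The hook stores the whole config instead of loose attributes.
hook = _get_top_level_group_offload_hook(model)
assert hook is not None and hook.config.offload_type == "block_level"

with torch.no_grad():
    model(torch.randn(1, 8, device=onload_device))

# After an in-place modification of the model (the motivating cases are loading,
# unloading, or fusing LoRA weights), the stored config is enough to tear the
# hooks down and rebuild them without re-specifying anything.
_maybe_remove_and_reapply_group_offloading(model)
```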
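
On the loader side, `_optionally_disable_offloading` now returns a three-element tuple, and every call site is updated to unpack `(is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)`. The restore step that follows LoRA loading can be summarized by the hypothetical helper below; `restore_offloading` and its `pipeline` argument are invented names for illustration, while the branch bodies mirror the loader code in this patch:

```python
import torch

from diffusers.hooks.group_offloading import _maybe_remove_and_reapply_group_offloading


def restore_offloading(pipeline, is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload):
    # Hypothetical condensation of the post-load step in `load_lora_adapter`,
    # `load_attn_procs`, and `_process_lora` above.
    if is_model_cpu_offload:
        pipeline.enable_model_cpu_offload()
    elif is_sequential_cpu_offload:
        pipeline.enable_sequential_cpu_offload()
    elif is_group_offload:
        # Unlike the two accelerate-based schemes, group offloading is per-module
        # state rather than a pipeline-level toggle, so the hooks are rebuilt on
        # each torch.nn.Module component individually.
        for component in pipeline.components.values():
            if isinstance(component, torch.nn.Module):
                _maybe_remove_and_reapply_group_offloading(component)
```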