Commit 4fa85c7

add model_manager and global offloading method
1 parent 806e8e6 commit 4fa85c7

File tree

4 files changed: +337 −10 lines changed


src/diffusers/guider.py

Lines changed: 15 additions & 3 deletions
@@ -32,9 +32,21 @@
 
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
 def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
-    """
-    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
-    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+    r"""
+    Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+    Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+    Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+    Args:
+        noise_cfg (`torch.Tensor`):
+            The predicted noise tensor for the guided diffusion process.
+        noise_pred_text (`torch.Tensor`):
+            The predicted noise tensor for the text-guided diffusion process.
+        guidance_rescale (`float`, *optional*, defaults to 0.0):
+            A rescale factor applied to the noise predictions.
+
+    Returns:
+        noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
     """
     std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
     std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
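
For context, `rescale_noise_cfg` is applied after the classifier-free-guidance combination step. The sketch below is illustrative only: the tensor shapes and the `guidance_scale` / `guidance_rescale` values are assumptions, not part of this commit.

    import torch

    # hypothetical predictions from one denoising step (shapes are illustrative)
    noise_pred_uncond = torch.randn(2, 4, 128, 128)
    noise_pred_text = torch.randn(2, 4, 128, 128)

    # standard classifier-free guidance combination
    guidance_scale = 7.5
    noise_cfg = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

    # rescale toward the std of the text-conditioned prediction to counter overexposure
    noise_cfg = rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.7)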

src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py

Lines changed: 5 additions & 6 deletions
@@ -877,12 +877,6 @@ def prepare_latents(
                 f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
             )
 
-        latents_mean = latents_std = None
-        if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
-            latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
-        if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
-            latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
-
         # Offload text encoder if `enable_model_cpu_offload` was enabled
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.text_encoder_2.to("cpu")
@@ -896,6 +890,11 @@ def prepare_latents(
             init_latents = image
 
         else:
+            latents_mean = latents_std = None
+            if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
+                latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
+            if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
+                latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
             # make sure the VAE is in float32 mode, as it overflows in float16
             if self.vae.config.force_upcast:
                 image = image.float()
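
The moved block only matters in the branch that actually VAE-encodes the image. For context, a sketch of the normalization these tensors typically feed into later in `prepare_latents` (not shown in this hunk; `device` and `dtype` are assumed local variables):

    # assumed downstream use of latents_mean / latents_std (illustrative, not part of this hunk)
    if latents_mean is not None and latents_std is not None:
        latents_mean = latents_mean.to(device=device, dtype=dtype)
        latents_std = latents_std.to(device=device, dtype=dtype)
        init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
    else:
        init_latents = self.vae.config.scaling_factor * init_latents
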
Lines changed: 316 additions & 0 deletions
@@ -0,0 +1,316 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
from itertools import combinations
from typing import List, Optional, Union

import torch

from ..utils import (
    is_accelerate_available,
    logging,
)


if is_accelerate_available():
    from accelerate.hooks import ModelHook, add_hook_to_module, remove_hook_from_module
    from accelerate.state import PartialState
    from accelerate.utils import send_to_device
    from accelerate.utils.memory import clear_device_cache

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# YiYi Notes: copied from modeling_utils.py (decide later where to put this)
def get_memory_footprint(self, return_buffers=True):
    r"""
    Get the memory footprint of a model. This will return the memory footprint of the current model in bytes. Useful
    to benchmark the memory footprint of the current model and design some tests. Solution inspired from the PyTorch
    discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2

    Arguments:
        return_buffers (`bool`, *optional*, defaults to `True`):
            Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers are
            tensors that do not require gradients and are not registered as parameters, e.g. mean and std in batch
            norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2
    """
    mem = sum([param.nelement() * param.element_size() for param in self.parameters()])
    if return_buffers:
        mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()])
        mem = mem + mem_bufs
    return mem


class CustomOffloadHook(ModelHook):
    """
    A hook that offloads a model on the CPU until its forward pass is called. It ensures the model and its inputs are
    on the given device. Optionally offloads other models to the CPU before the forward pass is called.

    Args:
        execution_device (`str`, `int` or `torch.device`, *optional*):
            The device on which the model should be executed. Will default to the MPS device if it's available, then
            GPU 0 if there is a GPU, and finally to the CPU.
    """

    def __init__(
        self,
        execution_device: Optional[Union[str, int, torch.device]] = None,
        other_hooks: Optional[List["UserCustomOffloadHook"]] = None,
        offload_strategy: Optional["AutoOffloadStrategy"] = None,
    ):
        self.execution_device = execution_device if execution_device is not None else PartialState().default_device
        self.other_hooks = other_hooks
        self.offload_strategy = offload_strategy
        self.model_id = None

    def set_strategy(self, offload_strategy: "AutoOffloadStrategy"):
        self.offload_strategy = offload_strategy

    def add_other_hook(self, hook: "UserCustomOffloadHook"):
        """
        Add a hook to the list of hooks to consider for offloading.
        """
        if self.other_hooks is None:
            self.other_hooks = []
        self.other_hooks.append(hook)

    def init_hook(self, module):
        return module.to("cpu")

    def pre_forward(self, module, *args, **kwargs):
        if module.device != self.execution_device:
            if self.other_hooks is not None:
                hooks_to_offload = [hook for hook in self.other_hooks if hook.model.device == self.execution_device]
                # offload all other hooks
                import time

                # YiYi Notes: only logging time for now to monitor the overhead of offloading strategy (remove later)
                start_time = time.perf_counter()
                if self.offload_strategy is not None:
                    hooks_to_offload = self.offload_strategy(
                        hooks=hooks_to_offload,
                        model_id=self.model_id,
                        model=module,
                        execution_device=self.execution_device,
                    )
                end_time = time.perf_counter()
                logger.info(
                    f" time taken to apply offload strategy for {self.model_id}: {(end_time - start_time):.2f} seconds"
                )

                for hook in hooks_to_offload:
                    logger.info(
                        f"moving {self.model_id} to {self.execution_device}, offloading {hook.model_id} to cpu"
                    )
                    hook.offload()

                if hooks_to_offload:
                    clear_device_cache()
            module.to(self.execution_device)
        return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device)


class UserCustomOffloadHook:
    """
    A simple hook grouping a model and a `CustomOffloadHook`, which provides easy APIs to call the init method of the
    hook or remove it entirely.
    """

    def __init__(self, model_id, model, hook):
        self.model_id = model_id
        self.model = model
        self.hook = hook

    def offload(self):
        self.hook.init_hook(self.model)

    def attach(self):
        add_hook_to_module(self.model, self.hook)
        self.hook.model_id = self.model_id

    def remove(self):
        remove_hook_from_module(self.model)
        self.hook.model_id = None

    def add_other_hook(self, hook: "UserCustomOffloadHook"):
        self.hook.add_other_hook(hook)


def custom_offload_with_hook(
    model_id: str,
    model: torch.nn.Module,
    execution_device: Union[str, int, torch.device] = None,
    offload_strategy: Optional["AutoOffloadStrategy"] = None,
):
    hook = CustomOffloadHook(execution_device=execution_device, offload_strategy=offload_strategy)
    user_hook = UserCustomOffloadHook(model_id=model_id, model=model, hook=hook)
    user_hook.attach()
    return user_hook


class AutoOffloadStrategy:
    """
    Offload strategy that should be used with `CustomOffloadHook` to automatically offload models to the CPU based on
    the available memory on the device.
    """

    def __init__(self, size_estimation_margin=0.1):
        self.size_estimation_margin = size_estimation_margin

    def __call__(self, hooks, model_id, model, execution_device):
        if len(hooks) == 0:
            return []

        current_module_size = get_memory_footprint(model)
        current_module_size *= 1 + self.size_estimation_margin

        mem_on_device = torch.cuda.mem_get_info(execution_device.index)[0]
        if current_module_size < mem_on_device:
            return []

        min_memory_offload = current_module_size - mem_on_device
        logger.info(f" search for models to offload in order to free up {min_memory_offload / 1024**3:.2f} GB memory")

        # exclude models that are not currently loaded on the device
        module_sizes = dict(
            sorted(
                {hook.model_id: get_memory_footprint(hook.model) for hook in hooks}.items(),
                key=lambda x: x[1],
                reverse=True,
            )
        )

        def search_best_candidate(module_sizes, min_memory_offload):
            """
            Search for the optimal combination of models to offload to CPU, given a dictionary of module sizes and a
            minimum amount of memory to free. The chosen combination is the one whose total size is the smallest that
            is still larger than `min_memory_offload`.
            """
            model_ids = list(module_sizes.keys())
            best_candidate = None
            best_size = float("inf")
            for r in range(1, len(model_ids) + 1):
                for candidate_model_ids in combinations(model_ids, r):
                    candidate_size = sum(
                        module_sizes[candidate_model_id] for candidate_model_id in candidate_model_ids
                    )
                    if candidate_size < min_memory_offload:
                        continue
                    else:
                        if best_candidate is None or candidate_size < best_size:
                            best_candidate = candidate_model_ids
                            best_size = candidate_size

            return best_candidate

        best_offload_model_ids = search_best_candidate(module_sizes, min_memory_offload)

        if best_offload_model_ids is None:
            # if no combination is found, meaning that we cannot meet the memory requirement, offload all models
            logger.warning("no combination of models to offload to cpu is found, offloading all models")
            hooks_to_offload = hooks
        else:
            hooks_to_offload = [hook for hook in hooks if hook.model_id in best_offload_model_ids]

        return hooks_to_offload


class ModelManager:
    def __init__(self):
        self.models = OrderedDict()
        self.model_hooks = None
        self._auto_offload_enabled = False

    def add(self, model_id, model):
        if model_id not in self.models:
            self.models[model_id] = model
        if self._auto_offload_enabled:
            self.enable_auto_cpu_offload(self._auto_offload_device)

    def remove(self, model_id):
        self.models.pop(model_id)
        if self._auto_offload_enabled:
            self.enable_auto_cpu_offload(self._auto_offload_device)

    def enable_auto_cpu_offload(self, device, size_estimation_margin=0.1):
        # strip any accelerate hooks that are already attached before installing the offload hooks
        for model_id, model in self.models.items():
            if isinstance(model, torch.nn.Module) and hasattr(model, "_hf_hook"):
                remove_hook_from_module(model, recurse=True)

        self.disable_auto_cpu_offload()
        offload_strategy = AutoOffloadStrategy(size_estimation_margin=size_estimation_margin)
        device = torch.device(device)
        if device.index is None:
            device = torch.device(f"{device.type}:{0}")
        all_hooks = []
        for model_id, model in self.models.items():
            hook = custom_offload_with_hook(model_id, model, device, offload_strategy=offload_strategy)
            all_hooks.append(hook)

        # make every hook aware of the other hooks that share its execution device
        for hook in all_hooks:
            other_hooks = [h for h in all_hooks if h is not hook]
            for other_hook in other_hooks:
                if other_hook.hook.execution_device == hook.hook.execution_device:
                    hook.add_other_hook(other_hook)

        self.model_hooks = all_hooks
        self._auto_offload_enabled = True
        self._auto_offload_device = device

    def disable_auto_cpu_offload(self):
        if self.model_hooks is None:
            self._auto_offload_enabled = False
            return

        for hook in self.model_hooks:
            hook.offload()
            hook.remove()
        if self.model_hooks:
            clear_device_cache()
        self.model_hooks = None
        self._auto_offload_enabled = False

    def __repr__(self):
        col_widths = {
            "id": max(15, max(len(id) for id in self.models.keys())),
            "class": max(25, max(len(model.__class__.__name__) for model in self.models.values())),
            "device": 10,
            "dtype": 15,
            "size": 10,
        }

        # Create the header
        sep_line = "=" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n"
        dash_line = "-" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n"

        output = "ModelManager:\n" + sep_line

        # Column headers
        output += f"{'Model ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}} | "
        output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | Size (GB) \n"
        output += dash_line

        # Model entries
        for model_id, model in self.models.items():
            device = model.device
            dtype = model.dtype
            size_bytes = get_memory_footprint(model)
            size_gb = size_bytes / (1024**3)

            output += f"{model_id:<{col_widths['id']}} | {model.__class__.__name__:<{col_widths['class']}} | "
            output += f"{str(device):<{col_widths['device']}} | {str(dtype):<{col_widths['dtype']}} | {size_gb:.2f}\n"

        output += sep_line
        return output
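
A minimal usage sketch of the new `ModelManager`, assuming the class and its offload machinery above are importable from this new module (the module path is not shown in this diff) and that `accelerate` is installed; the checkpoint name, model IDs, and device string are illustrative:

    import torch
    from diffusers import StableDiffusionXLPipeline

    # load a pipeline on CPU first; models are moved to the GPU lazily by the hooks
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    )

    manager = ModelManager()
    manager.add("unet", pipe.unet)
    manager.add("text_encoder", pipe.text_encoder)
    manager.add("text_encoder_2", pipe.text_encoder_2)
    manager.add("vae", pipe.vae)

    # attach the offload hooks: each model stays on CPU until its forward pass runs,
    # and other models are offloaded first if the GPU would otherwise run out of memory
    manager.enable_auto_cpu_offload("cuda")

    print(manager)  # __repr__ prints a table of model id, class, device, dtype, and size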

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ def retrieve_timesteps(
     sigmas: Optional[List[float]] = None,
     **kwargs,
 ):
-    """
+    r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
 
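For context, a sketch of how `retrieve_timesteps` is typically called; the full signature is not shown in this hunk, so the argument names and the two return values below are assumptions based on the surrounding docstring, and `pipe` stands for an already-loaded pipeline whose scheduler supports `set_timesteps`:

    # assumed call pattern for the helper documented above
    timesteps, num_inference_steps = retrieve_timesteps(
        scheduler=pipe.scheduler,
        num_inference_steps=30,
        device="cuda",
        sigmas=None,  # optional custom sigmas, forwarded to scheduler.set_timesteps
    )
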
0 commit comments