Commit 44fb2fd

add

1 parent 74b6752 commit 44fb2fd
1 file changed: +365 −0 lines changed

@@ -0,0 +1,365 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
from itertools import combinations
from typing import List, Optional, Union

import torch

from ..utils import (
    is_accelerate_available,
    logging,
)


if is_accelerate_available():
    from accelerate.hooks import ModelHook, add_hook_to_module, remove_hook_from_module
    from accelerate.state import PartialState
    from accelerate.utils import send_to_device
    from accelerate.utils.memory import clear_device_cache
    from accelerate.utils.modeling import convert_file_size_to_int

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# YiYi Notes: copied from modeling_utils.py (decide later where to put this)
def get_memory_footprint(self, return_buffers=True):
    r"""
    Get the memory footprint of a model. This will return the memory footprint of the current model in bytes.
    Useful for benchmarking the memory footprint of the current model and designing some tests. Solution inspired
    by the PyTorch discussion: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2

    Arguments:
        return_buffers (`bool`, *optional*, defaults to `True`):
            Whether to include the size of the buffer tensors in the computation of the memory footprint. Buffers
            are tensors that do not require gradients and are not registered as parameters, e.g. the mean and std
            in batch norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2
    """
    mem = sum(param.nelement() * param.element_size() for param in self.parameters())
    if return_buffers:
        mem_bufs = sum(buf.nelement() * buf.element_size() for buf in self.buffers())
        mem = mem + mem_bufs
    return mem
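
# A minimal usage sketch (illustrative, not part of the module): because the
# function takes the module as its first argument, it can be called as a free
# function on any `torch.nn.Module`:
#
#     model = torch.nn.Linear(1024, 1024)
#     total = get_memory_footprint(model)  # parameters + buffers, in bytes
#     params_only = get_memory_footprint(model, return_buffers=False)
#     print(f"{total / 1024**2:.2f} MB")   # ~4 MB for fp32 weights + bias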


class CustomOffloadHook(ModelHook):
    """
    A hook that offloads a model to the CPU until its forward pass is called. It ensures the model and its inputs
    are on the given device. Optionally offloads other models to the CPU before the forward pass is called.

    Args:
        execution_device (`str`, `int` or `torch.device`, *optional*):
            The device on which the model should be executed. Will default to the MPS device if it's available,
            then GPU 0 if there is a GPU, and finally to the CPU.
    """

    def __init__(
        self,
        execution_device: Optional[Union[str, int, torch.device]] = None,
        other_hooks: Optional[List["UserCustomOffloadHook"]] = None,
        offload_strategy: Optional["AutoOffloadStrategy"] = None,
    ):
        self.execution_device = execution_device if execution_device is not None else PartialState().default_device
        self.other_hooks = other_hooks
        self.offload_strategy = offload_strategy
        self.model_id = None

    def set_strategy(self, offload_strategy: "AutoOffloadStrategy"):
        self.offload_strategy = offload_strategy

    def add_other_hook(self, hook: "UserCustomOffloadHook"):
        """
        Add a hook to the list of hooks to consider for offloading.
        """
        if self.other_hooks is None:
            self.other_hooks = []
        self.other_hooks.append(hook)

    def init_hook(self, module):
        return module.to("cpu")

    def pre_forward(self, module, *args, **kwargs):
        if module.device != self.execution_device:
            if self.other_hooks is not None:
                hooks_to_offload = [hook for hook in self.other_hooks if hook.model.device == self.execution_device]
                # offload all other hooks
                import time

                # YiYi Notes: only logging time for now to monitor the overhead of the offload strategy (remove later)
                start_time = time.perf_counter()
                if self.offload_strategy is not None:
                    hooks_to_offload = self.offload_strategy(
                        hooks=hooks_to_offload,
                        model_id=self.model_id,
                        model=module,
                        execution_device=self.execution_device,
                    )
                end_time = time.perf_counter()
                logger.info(
                    f" time taken to apply offload strategy for {self.model_id}: {(end_time - start_time):.2f} seconds"
                )

                for hook in hooks_to_offload:
                    logger.info(
                        f"moving {self.model_id} to {self.execution_device}, offloading {hook.model_id} to cpu"
                    )
                    hook.offload()

                if hooks_to_offload:
                    clear_device_cache()
            module.to(self.execution_device)
        return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device)
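
# Note on the control flow above (a summary, not new behavior): the hook acts
# only when the wrapped module is not already on the execution device. It first
# asks the offload strategy which peer models to evict, moves those to the CPU
# and clears the device cache, and only then moves the module itself; finally,
# the positional and keyword arguments are sent to the execution device.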


class UserCustomOffloadHook:
    """
    A simple hook grouping a model and a `CustomOffloadHook`, which provides easy APIs to call the init method of
    the hook or remove it entirely.
    """

    def __init__(self, model_id, model, hook):
        self.model_id = model_id
        self.model = model
        self.hook = hook

    def offload(self):
        self.hook.init_hook(self.model)

    def attach(self):
        add_hook_to_module(self.model, self.hook)
        self.hook.model_id = self.model_id

    def remove(self):
        remove_hook_from_module(self.model)
        self.hook.model_id = None

    def add_other_hook(self, hook: "UserCustomOffloadHook"):
        self.hook.add_other_hook(hook)


def custom_offload_with_hook(
    model_id: str,
    model: torch.nn.Module,
    execution_device: Optional[Union[str, int, torch.device]] = None,
    offload_strategy: Optional["AutoOffloadStrategy"] = None,
):
    hook = CustomOffloadHook(execution_device=execution_device, offload_strategy=offload_strategy)
    user_hook = UserCustomOffloadHook(model_id=model_id, model=model, hook=hook)
    user_hook.attach()
    return user_hook
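
# A minimal wiring sketch (illustrative; `unet` and `vae` are assumed to be
# already-loaded `torch.nn.Module`s, not part of this module): attach a hook to
# each model and register them with each other, so either can be evicted when
# the other needs the device:
#
#     unet_hook = custom_offload_with_hook("unet", unet, "cuda:0")
#     vae_hook = custom_offload_with_hook("vae", vae, "cuda:0")
#     unet_hook.add_other_hook(vae_hook)
#     vae_hook.add_other_hook(unet_hook)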


class AutoOffloadStrategy:
    """
    Offload strategy that should be used with `CustomOffloadHook` to automatically offload models to the CPU based
    on the available memory on the device.
    """

    def __init__(self, memory_reserve_margin="3GB"):
        self.memory_reserve_margin = convert_file_size_to_int(memory_reserve_margin)

    def __call__(self, hooks, model_id, model, execution_device):
        if len(hooks) == 0:
            return []

        current_module_size = get_memory_footprint(model)

        mem_on_device = torch.cuda.mem_get_info(execution_device.index)[0]
        mem_on_device = mem_on_device - self.memory_reserve_margin
        if current_module_size < mem_on_device:
            return []

        min_memory_offload = current_module_size - mem_on_device
        logger.info(f" search for models to offload in order to free up {min_memory_offload / 1024**3:.2f} GB memory")

        # exclude models that are not currently loaded on the device
        module_sizes = dict(
            sorted(
                {hook.model_id: get_memory_footprint(hook.model) for hook in hooks}.items(),
                key=lambda x: x[1],
                reverse=True,
            )
        )

        def search_best_candidate(module_sizes, min_memory_offload):
            """
            Search for the optimal combination of models to offload to the CPU, given a dictionary of module sizes
            and a minimum amount of memory to free. The chosen combination of models should add up to the smallest
            total size that is larger than `min_memory_offload`.
            """
            model_ids = list(module_sizes.keys())
            best_candidate = None
            best_size = float("inf")
            for r in range(1, len(model_ids) + 1):
                for candidate_model_ids in combinations(model_ids, r):
                    candidate_size = sum(
                        module_sizes[candidate_model_id] for candidate_model_id in candidate_model_ids
                    )
                    if candidate_size < min_memory_offload:
                        continue
                    if best_candidate is None or candidate_size < best_size:
                        best_candidate = candidate_model_ids
                        best_size = candidate_size

            return best_candidate

        best_offload_model_ids = search_best_candidate(module_sizes, min_memory_offload)

        if best_offload_model_ids is None:
            # if no combination is found, meaning we cannot meet the memory requirement, offload all models
            logger.warning("no combination of models to offload to cpu is found, offloading all models")
            hooks_to_offload = hooks
        else:
            hooks_to_offload = [hook for hook in hooks if hook.model_id in best_offload_model_ids]

        return hooks_to_offload
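
# A worked example for the search above (hypothetical sizes): with
# module_sizes = {"unet": 5 GB, "text_encoder": 2 GB, "vae": 1 GB} and
# min_memory_offload = 2.5 GB, the smallest qualifying combination is
# ("text_encoder", "vae") at 3 GB, so only those two models are evicted.
# Note the search enumerates all combinations, so its cost grows exponentially
# with the number of hooks on the device.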


class ComponentsManager:
    def __init__(self):
        self.components = OrderedDict()
        self.model_hooks = None
        self._auto_offload_enabled = False

    def add(self, name, component):
        if name not in self.components:
            self.components[name] = component
            if self._auto_offload_enabled:
                self.enable_auto_cpu_offload(self._auto_offload_device)

    def remove(self, name):
        self.components.pop(name)
        if self._auto_offload_enabled:
            self.enable_auto_cpu_offload(self._auto_offload_device)

    def get(self, names: Union[str, List[str]]):
        if isinstance(names, str):
            if names not in self.components:
                raise ValueError(f"Component '{names}' not found in ComponentsManager")
            return self.components[names]
        elif isinstance(names, list):
            return {n: self.components[n] for n in names}
        else:
            raise ValueError(f"Invalid type for names: {type(names)}")
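
    # A quick sketch of the registry API above (component names are
    # hypothetical):
    #
    #     manager = ComponentsManager()
    #     manager.add("vae", vae)
    #     vae = manager.get("vae")       # a single name returns the component
    #     subset = manager.get(["vae"])  # a list of names returns a dict
    #     manager.remove("vae")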

    def enable_auto_cpu_offload(self, device, memory_reserve_margin="3GB"):
        for name, component in self.components.items():
            if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"):
                remove_hook_from_module(component, recurse=True)

        self.disable_auto_cpu_offload()
        offload_strategy = AutoOffloadStrategy(memory_reserve_margin=memory_reserve_margin)
        device = torch.device(device)
        if device.index is None:
            device = torch.device(f"{device.type}:{0}")
        all_hooks = []
        for name, component in self.components.items():
            if isinstance(component, torch.nn.Module):
                hook = custom_offload_with_hook(name, component, device, offload_strategy=offload_strategy)
                all_hooks.append(hook)

        for hook in all_hooks:
            other_hooks = [h for h in all_hooks if h is not hook]
            for other_hook in other_hooks:
                if other_hook.hook.execution_device == hook.hook.execution_device:
                    hook.add_other_hook(other_hook)

        self.model_hooks = all_hooks
        self._auto_offload_enabled = True
        self._auto_offload_device = device

    def disable_auto_cpu_offload(self):
        if self.model_hooks is None:
            self._auto_offload_enabled = False
            return

        for hook in self.model_hooks:
            hook.offload()
            hook.remove()
        if self.model_hooks:
            clear_device_cache()
        self.model_hooks = None
        self._auto_offload_enabled = False
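
    # A minimal usage sketch for the two methods above (illustrative; assumes
    # models were registered with `add`):
    #
    #     manager.enable_auto_cpu_offload("cuda", memory_reserve_margin="2GB")
    #     ...  # run inference; models swap between CPU and GPU on demand
    #     manager.disable_auto_cpu_offload()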

    def __repr__(self):
        # `default=0` guards against an empty manager, where `max()` on an
        # empty sequence would raise a ValueError
        col_widths = {
            "id": max(15, max((len(name) for name in self.components.keys()), default=0)),
            "class": max(
                25, max((len(component.__class__.__name__) for component in self.components.values()), default=0)
            ),
            "device": 10,
            "dtype": 15,
            "size": 10,
        }

        # Create the header lines
        sep_line = "=" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n"
        dash_line = "-" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n"

        output = "Components:\n" + sep_line

        # Separate components into models and others
        models = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}
        others = {k: v for k, v in self.components.items() if not isinstance(v, torch.nn.Module)}

        # Models section
        if models:
            output += "Models:\n" + dash_line
            # Column headers
            output += f"{'Model ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}} | "
            output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | Size (GB)\n"
            output += dash_line

            # Model entries
            for name, component in models.items():
                device = component.device
                dtype = component.dtype
                size_bytes = get_memory_footprint(component)
                size_gb = size_bytes / (1024**3)

                output += f"{name:<{col_widths['id']}} | {component.__class__.__name__:<{col_widths['class']}} | "
                output += (
                    f"{str(device):<{col_widths['device']}} | {str(dtype):<{col_widths['dtype']}} | {size_gb:.2f}\n"
                )
            output += dash_line

        # Other components section
        if others:
            if models:  # Add an extra newline if we had a models section
                output += "\n"
            output += "Other Components:\n" + dash_line
            # Column headers for other components
            output += f"{'Component ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}}\n"
            output += dash_line

            # Other component entries
            for name, component in others.items():
                output += f"{name:<{col_widths['id']}} | {component.__class__.__name__:<{col_widths['class']}}\n"
            output += dash_line

        return output

    def add_from_pretrained(self, pretrained_model_name_or_path, **kwargs):
        from ..pipelines.pipeline_utils import DiffusionPipeline

        pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs)
        for name, component in pipe.components.items():
            if name not in self.components and component is not None:
                self.add(name, component)
            elif name in self.components:
                logger.warning(
                    f"Component '{name}' already exists in ComponentsManager and will not be added. To add it, either:\n"
                    f"1. Remove the existing component with remove('{name}')\n"
                    f"2. Use a different name: add('{name}_2', component)"
                )
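
# A minimal end-to-end sketch (illustrative; the checkpoint id is an assumption):
#
#     manager = ComponentsManager()
#     manager.add_from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
#     manager.enable_auto_cpu_offload("cuda")
#     print(manager)  # tabular summary of registered models and other components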
