
Commit e0d5079

start implementing disk offloading in group.
1 parent 7c6e9ef commit e0d5079

File tree

2 files changed: +313 −4 lines changed


go.diff

Lines changed: 218 additions & 0 deletions
@@ -0,0 +1,218 @@
diff --git a/diffusers/hooks/offload.py b/diffusers/hooks/offload.py
--- a/diffusers/hooks/offload.py
+++ b/diffusers/hooks/offload.py
@@ -1,6 +1,10 @@
 import os
-import torch
+import torch
+from safetensors.torch import save_file, load_file
+
-from typing import Optional, Union
+from typing import Dict, Optional, Union
 from torch import nn
 from .module_group import ModuleGroup
@@ -25,6 +29,32 @@ from .hooks import HookRegistry
 from .hooks import GroupOffloadingHook, LazyPrefetchGroupOffloadingHook

+# -------------------------------------------------------------------------------
+# Helpers for disk/NVMe offload using safetensors
+# -------------------------------------------------------------------------------
+def _offload_tensor_to_disk_st(tensor: torch.Tensor, path: str) -> None:
+    """
+    Serialize a tensor to disk in safetensors format, staging it through a
+    detached CPU copy.
+    """
+    os.makedirs(os.path.dirname(path), exist_ok=True)
+    cpu_t = tensor.detach().cpu()
+    save_file({"0": cpu_t}, path)
+    # Drop the local reference only; the caller is responsible for releasing
+    # its own (possibly GPU-resident) copy of the tensor.
+    del tensor
+
+def _load_tensor_from_disk_st(
+    path: str, device: torch.device, non_blocking: bool
+) -> torch.Tensor:
+    """
+    Load a tensor back in with safetensors.
+    - If non_blocking on CUDA: load into pinned CPU memory, then .to(cuda, non_blocking=True).
+    - Otherwise: direct load_file(device=...).
+    """
+    # fast path: load directly onto the target device
+    if not (non_blocking and device.type == "cuda"):
+        data = load_file(path, device=str(device))
+        return data["0"]
+    # pinned-CPU staging so the host-to-device copy can overlap compute
+    data = load_file(path, device="cpu")
+    cpu_t = data["0"].pin_memory()
+    return cpu_t.to(device, non_blocking=True)
+
+
 def apply_group_offloading(
     module: torch.nn.Module,
     onload_device: torch.device,
-    offload_device: torch.device = torch.device("cpu"),
-    offload_type: str = "block_level",
+    offload_device: torch.device = torch.device("cpu"),
+    *,
+    offload_to_disk: bool = False,
+    offload_path: Optional[str] = None,
+    offload_type: str = "block_level",
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
     use_stream: bool = False,
@@ -37,6 +67,10 @@ def apply_group_offloading(
     Example:
         ```python
         >>> apply_group_offloading(... )
+        # to store params on NVMe:
+        >>> apply_group_offloading(
+        ...     model,
+        ...     onload_device=torch.device("cuda"),
+        ...     offload_to_disk=True,
+        ...     offload_path="/mnt/nvme1/offload",
+        ...     offload_type="block_level",
+        ...     num_blocks_per_group=1,
+        ... )
         ```
     """

@@ -69,6 +103,10 @@ def apply_group_offloading(
         if num_blocks_per_group is None:
             raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
+        if offload_to_disk and offload_path is None:
+            raise ValueError("`offload_path` must be set when `offload_to_disk=True`.")

         _apply_group_offloading_block_level(
             module=module,
+            offload_to_disk=offload_to_disk,
+            offload_path=offload_path,
             num_blocks_per_group=num_blocks_per_group,
             offload_device=offload_device,
             onload_device=onload_device,
@@ -79,6 +117,11 @@
     elif offload_type == "leaf_level":
+        if offload_to_disk and offload_path is None:
+            raise ValueError("`offload_path` must be set when `offload_to_disk=True`.")
         _apply_group_offloading_leaf_level(
             module=module,
+            offload_to_disk=offload_to_disk,
+            offload_path=offload_path,
             offload_device=offload_device,
             onload_device=onload_device,
             non_blocking=non_blocking,
@@ -107,10 +150,16 @@ def _apply_group_offloading_block_level(
-    module: torch.nn.Module,
-    num_blocks_per_group: int,
-    offload_device: torch.device,
-    onload_device: torch.device,
+    module: torch.nn.Module,
+    num_blocks_per_group: int,
+    offload_device: torch.device,
+    offload_to_disk: bool,
+    offload_path: Optional[str],
+    onload_device: torch.device,
     non_blocking: bool,
     stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
     record_stream: Optional[bool] = False,
     low_cpu_mem_usage: bool = False,
 ) -> None:
     """
@@ -138,7 +187,9 @@ def _apply_group_offloading_block_level(
     for i in range(0, len(submodule), num_blocks_per_group):
         current_modules = submodule[i : i + num_blocks_per_group]
         group = ModuleGroup(
-            modules=current_modules,
+            modules=current_modules,
+            offload_to_disk=offload_to_disk,
+            offload_path=offload_path,
             offload_device=offload_device,
             onload_device=onload_device,
             offload_leader=current_modules[-1],
@@ -187,10 +238,14 @@ def _apply_group_offloading_block_level(
     unmatched_group = ModuleGroup(
         modules=unmatched_modules,
-        offload_device=offload_device,
+        offload_to_disk=offload_to_disk,
+        offload_path=offload_path,
+        offload_device=offload_device,
         onload_device=onload_device,
         offload_leader=module,
         onload_leader=module,
+        # other args omitted for brevity...
     )

     if stream is None:
@@ -216,10 +271,16 @@ def _apply_group_offloading_leaf_level(
-    module: torch.nn.Module,
-    offload_device: torch.device,
-    onload_device: torch.device,
-    non_blocking: bool,
+    module: torch.nn.Module,
+    offload_device: torch.device,
+    offload_to_disk: bool,
+    offload_path: Optional[str],
+    onload_device: torch.device,
+    non_blocking: bool,
     stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
     record_stream: Optional[bool] = False,
     low_cpu_mem_usage: bool = False,
 ) -> None:
     """
@@ -229,7 +290,9 @@ def _apply_group_offloading_leaf_level(
     for name, submodule in module.named_modules():
         if not isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
             continue
-        group = ModuleGroup(
+        group = ModuleGroup(
+            offload_to_disk=offload_to_disk,
+            offload_path=offload_path,
             modules=[submodule],
             offload_device=offload_device,
             onload_device=onload_device,
@@ -317,10 +380,14 @@ def _apply_group_offloading_leaf_level(
     parent_module = module_dict[name]
     assert getattr(parent_module, "_diffusers_hook", None) is None
-    group = ModuleGroup(
+    group = ModuleGroup(
+        offload_to_disk=offload_to_disk,
+        offload_path=offload_path,
         modules=[],
         offload_device=offload_device,
         onload_device=onload_device,
+        # additional args omitted for brevity...
     )
     _apply_group_offloading_hook(parent_module, group, None)

@@ -360,6 +427,38 @@ def _apply_lazy_group_offloading_hook(
     registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)


+# -------------------------------------------------------------------------------
+# Patch GroupOffloadingHook to use safetensors disk offload
+# -------------------------------------------------------------------------------
+class GroupOffloadingHook:
+    def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup]):
+        self.group = group
+        self.next_group = next_group
+        # map param/buffer name -> file path
+        self.param_to_path: Dict[str, str] = {}
+        self.buffer_to_path: Dict[str, str] = {}
+
+    def offload_parameters(self, module: nn.Module):
+        for name, param in module.named_parameters(recurse=False):
+            if self.group.offload_to_disk:
+                # include id(module) so two instances of the same class cannot
+                # overwrite each other's files
+                filename = f"{module.__class__.__name__}_{id(module)}__{name}.safetensors"
+                path = os.path.join(self.group.offload_path, filename)
+                _offload_tensor_to_disk_st(param.data, path)
+                self.param_to_path[name] = path
+                # point the parameter at a meta placeholder so the on-device
+                # storage is actually freed; onload_parameters restores it
+                param.data = torch.empty_like(param.data, device="meta")
+            else:
+                param.data = param.data.to(self.group.offload_device, non_blocking=self.group.non_blocking)
+
+    def onload_parameters(self, module: nn.Module):
+        for name, param in module.named_parameters(recurse=False):
+            if self.group.offload_to_disk:
+                path = self.param_to_path[name]
+                param.data = _load_tensor_from_disk_st(path, self.group.onload_device, self.group.non_blocking)
+            else:
+                param.data = param.data.to(self.group.onload_device, non_blocking=self.group.non_blocking)
+
+    # analogous changes for buffers...
+
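For reference, the save/reload cycle that `_offload_tensor_to_disk_st` and `_load_tensor_from_disk_st` implement can be exercised standalone with safetensors. A minimal sketch; the `roundtrip` helper and the `/tmp/offload` path are illustrative, not part of the patch:

```python
import os

import torch
from safetensors.torch import load_file, save_file


def roundtrip(tensor: torch.Tensor, path: str, device: torch.device, non_blocking: bool = False) -> torch.Tensor:
    # stage through a detached CPU copy and write in safetensors format
    os.makedirs(os.path.dirname(path), exist_ok=True)
    save_file({"0": tensor.detach().cpu()}, path)
    if non_blocking and device.type == "cuda":
        # pinned staging lets the host-to-device copy overlap compute
        cpu_t = load_file(path, device="cpu")["0"].pin_memory()
        return cpu_t.to(device, non_blocking=True)
    return load_file(path, device=str(device))["0"]


t = torch.randn(4, 4)
restored = roundtrip(t, "/tmp/offload/Linear__weight.safetensors", torch.device("cpu"))
assert torch.equal(t, restored)
```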

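`ModuleGroup` and the hook registry are internal to diffusers, so the hook's parameter handling is easiest to sanity-check in isolation. The sketch below replays the same offload/onload loop on a bare `nn.Linear`; the function names and the meta-device placeholder are assumptions of this sketch, not API from the commit:

```python
import os

import torch
from safetensors.torch import load_file, save_file
from torch import nn


def offload_params_to_disk(module: nn.Module, directory: str) -> dict:
    os.makedirs(directory, exist_ok=True)
    paths = {}
    for name, param in module.named_parameters(recurse=False):
        path = os.path.join(directory, f"{module.__class__.__name__}_{id(module)}__{name}.safetensors")
        save_file({"0": param.data.detach().cpu()}, path)
        paths[name] = path
        # drop the live storage; shape/dtype metadata survives on the meta device
        param.data = torch.empty_like(param.data, device="meta")
    return paths


def onload_params(module: nn.Module, paths: dict, device: torch.device) -> None:
    for name, param in module.named_parameters(recurse=False):
        param.data = load_file(paths[name], device=str(device))["0"]


layer = nn.Linear(8, 8)
paths = offload_params_to_disk(layer, "/tmp/offload")
onload_params(layer, paths, torch.device("cpu"))
layer(torch.randn(1, 8))  # runs again once parameters are restored
```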
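The patch stops at `# analogous changes for buffers...`. A hypothetical sketch of what those methods might look like, mirroring the parameter methods and assuming the same `ModuleGroup` fields; these would live on the patched `GroupOffloadingHook` and are not part of the commit:

```python
# Hypothetical buffer counterparts for the patched GroupOffloadingHook above.
# Buffers are plain tensors, so .data can be rebound the same way as for
# parameters.
def offload_buffers(self, module: nn.Module):
    for name, buffer in module.named_buffers(recurse=False):
        if self.group.offload_to_disk:
            filename = f"{module.__class__.__name__}_{id(module)}__buffer__{name}.safetensors"
            path = os.path.join(self.group.offload_path, filename)
            _offload_tensor_to_disk_st(buffer.data, path)
            self.buffer_to_path[name] = path
            buffer.data = torch.empty_like(buffer.data, device="meta")
        else:
            buffer.data = buffer.data.to(self.group.offload_device, non_blocking=self.group.non_blocking)


def onload_buffers(self, module: nn.Module):
    for name, buffer in module.named_buffers(recurse=False):
        if self.group.offload_to_disk:
            path = self.buffer_to_path[name]
            buffer.data = _load_tensor_from_disk_st(path, self.group.onload_device, self.group.non_blocking)
        else:
            buffer.data = buffer.data.to(self.group.onload_device, non_blocking=self.group.non_blocking)
```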