Commit 15f98db

update
1 parent f36ba9f commit 15f98db

File tree: 1 file changed (+58, -80 lines)


src/diffusers/hooks/group_offloading.py

Lines changed: 58 additions & 80 deletions
@@ -101,7 +101,7 @@ def __init__(
         self.offload_to_disk_path = offload_to_disk_path
         self._is_offloaded_to_disk = False

-        if self.offload_to_disk_path:
+        if self.offload_to_disk_path is not None:
             # Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well.
             self.group_id = group_id if group_id is not None else str(id(self))
             short_hash = _compute_group_hash(self.group_id)
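
The change above swaps a truthiness test for an explicit `None` check, for the same reason the neighbouring comment gives for `group_id`: an empty string is falsy but may still be a deliberately supplied value. A minimal sketch of the difference (variable name illustrative, not the diffusers code):

```python
offload_to_disk_path = ""                 # falsy, yet a user-supplied value
assert not offload_to_disk_path           # a truthiness test would skip disk offload
assert offload_to_disk_path is not None   # the explicit check still honors it
```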
@@ -121,6 +121,12 @@ def __init__(
         else:
             self.cpu_param_dict = self._init_cpu_param_dict()

+        self._torch_accelerator_module = (
+            getattr(torch, torch.accelerator.current_accelerator().type)
+            if hasattr(torch, "accelerator")
+            else torch.cuda
+        )
+
     def _init_cpu_param_dict(self):
         cpu_param_dict = {}
         if self.stream is None:
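
The new `_torch_accelerator_module` attribute caches the backend-specific stream API once in `__init__` instead of recomputing it in `onload_` and `_offload_to_memory` (the later hunks delete the duplicated lookups). A standalone sketch of the same fallback, assuming a PyTorch build that ships `torch.accelerator` and a machine with an accelerator present; older builds fall through to `torch.cuda`:

```python
import torch

# Resolve the per-backend module (torch.cuda, torch.xpu, ...) once.
torch_accelerator_module = (
    getattr(torch, torch.accelerator.current_accelerator().type)
    if hasattr(torch, "accelerator")
    else torch.cuda
)

# Either way, the resolved module exposes the stream API the hooks rely on:
# .Stream(), .stream(...), and .current_stream().
print(torch_accelerator_module.current_stream())
```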
@@ -144,16 +150,12 @@ def _init_cpu_param_dict(self):

     @contextmanager
     def _pinned_memory_tensors(self):
-        pinned_dict = {}
         try:
-            for param, tensor in self.cpu_param_dict.items():
-                if not tensor.is_pinned():
-                    pinned_dict[param] = tensor.pin_memory()
-                else:
-                    pinned_dict[param] = tensor
-
+            pinned_dict = {
+                param: tensor.pin_memory() if not tensor.is_pinned() else tensor
+                for param, tensor in self.cpu_param_dict.items()
+            }
             yield pinned_dict
-
         finally:
             pinned_dict = None
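
The rewritten context manager is the same loop expressed as a dict comprehension; behavior is unchanged. The pinning matters because only page-locked host memory can back a truly asynchronous host-to-device copy. A minimal sketch, assuming a CUDA device is available:

```python
import torch

cpu_tensor = torch.randn(1024, 1024)  # ordinary pageable host memory
# Pin only when needed, exactly as the comprehension above does.
pinned = cpu_tensor if cpu_tensor.is_pinned() else cpu_tensor.pin_memory()

# With a pinned source, non_blocking=True lets the copy overlap compute;
# from pageable memory the same call degrades to a synchronous copy.
gpu_tensor = pinned.to("cuda", non_blocking=True)
```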
@@ -179,77 +181,47 @@ def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None)
             source = pinned_memory[buffer] if pinned_memory else buffer.data
             self._transfer_tensor_to_device(buffer, source, current_stream)

-    def _onload_from_disk(self, current_stream):
+    def _onload_from_disk(self):
         if self.stream is not None:
-            loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
-
-            for key, tensor_obj in self.key_to_tensor.items():
-                self.cpu_param_dict[tensor_obj] = loaded_cpu_tensors[key]
-
-            with self._pinned_memory_tensors() as pinned_memory:
-                for key, tensor_obj in self.key_to_tensor.items():
-                    self._transfer_tensor_to_device(tensor_obj, pinned_memory[tensor_obj], current_stream)
-
-            self.cpu_param_dict.clear()
-
-        else:
-            onload_device = (
-                self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
-            )
-            loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
-            for key, tensor_obj in self.key_to_tensor.items():
-                tensor_obj.data = loaded_tensors[key]
+            # Wait for previous Host->Device transfer to complete
+            self.stream.synchronize()

-    def _onload_from_memory(self, current_stream):
-        if self.stream is not None:
-            with self._pinned_memory_tensors() as pinned_memory:
-                self._process_tensors_from_modules(pinned_memory, current_stream)
-        else:
-            self._process_tensors_from_modules(None, current_stream)
+        context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
+        current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None

-    @torch.compiler.disable()
-    def onload_(self):
-        torch_accelerator_module = (
-            getattr(torch, torch.accelerator.current_accelerator().type)
-            if hasattr(torch, "accelerator")
-            else torch.cuda
-        )
-        context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
-        current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
+        with context:
+            # Load to CPU (if using streams) or directly to target device, pin, and async copy to device
+            device = self.onload_device if self.stream is None else "cpu"
+            loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device)

-        if self.offload_to_disk_path:
             if self.stream is not None:
-                # Wait for previous Host->Device transfer to complete
-                self.stream.synchronize()
-
-            with context:
-                if self.stream is not None:
-                    # Load to CPU, pin, and async copy to device for overlapping transfer and compute
-                    loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
-                    for key, tensor_obj in self.key_to_tensor.items():
-                        pinned_tensor = loaded_cpu_tensors[key].pin_memory()
-                        tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
-                        if self.record_stream:
-                            tensor_obj.data.record_stream(current_stream)
-                else:
-                    # Load directly to the target device (synchronous)
-                    onload_device = (
-                        self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
-                    )
-                    loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
-                    for key, tensor_obj in self.key_to_tensor.items():
-                        tensor_obj.data = loaded_tensors[key]
-            return
+                for key, tensor_obj in self.key_to_tensor.items():
+                    pinned_tensor = loaded_tensors[key].pin_memory()
+                    tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+                    if self.record_stream:
+                        tensor_obj.data.record_stream(current_stream)
+            else:
+                onload_device = (
+                    self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
+                )
+                loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
+                for key, tensor_obj in self.key_to_tensor.items():
+                    tensor_obj.data = loaded_tensors[key]

+    def _onload_from_memory(self):
         if self.stream is not None:
             # Wait for previous Host->Device transfer to complete
             self.stream.synchronize()

+        context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
+        current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None
+
         with context:
-            if self.offload_to_disk_path:
-                self._onload_from_disk(current_stream)
+            if self.stream is not None:
+                with self._pinned_memory_tensors() as pinned_memory:
+                    self._process_tensors_from_modules(pinned_memory, current_stream)
             else:
-                self._onload_from_memory(current_stream)
+                self._process_tensors_from_modules(None, current_stream)

     def _offload_to_disk(self):
         # TODO: we can potentially optimize this code path by checking if the _all_ the desired
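
The bulk of this hunk moves the dispatch and stream bookkeeping that used to live in `onload_` into `_onload_from_disk` and `_onload_from_memory` themselves, so each helper is self-contained: synchronize the previous transfer, choose the stream context, then copy. A reduced sketch of the overlap pattern both helpers now share, assuming a CUDA device (shapes are illustrative):

```python
import torch

transfer_stream = torch.cuda.Stream()
compute_stream = torch.cuda.current_stream()
src = torch.randn(4096, 4096).pin_memory()  # pinned host staging buffer

transfer_stream.synchronize()               # wait for the previous H2D copy
with torch.cuda.stream(transfer_stream):
    dst = src.to("cuda", non_blocking=True)  # async copy on the side stream
    # record_stream tells the caching allocator that `dst` is also consumed
    # on the compute stream, so its memory is not recycled too early.
    dst.record_stream(compute_stream)
```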
@@ -270,14 +242,10 @@ def _offload_to_disk(self):
             tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)

     def _offload_to_memory(self):
-        torch_accelerator_module = (
-            getattr(torch, torch.accelerator.current_accelerator().type)
-            if hasattr(torch, "accelerator")
-            else torch.cuda
-        )
         if self.stream is not None:
             if not self.record_stream:
-                torch_accelerator_module.current_stream().synchronize()
+                self._torch_accelerator_module.current_stream().synchronize()
+
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]
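
`_offload_to_memory` is the mirror image: when `record_stream` was not used during onload, a single barrier on the compute stream must run before `param.data` is rebound to the CPU copies, otherwise in-flight kernels could still be reading the GPU tensors. Roughly (a sketch, not the diffusers code):

```python
import torch

if torch.cuda.is_available():
    # One explicit barrier replaces per-tensor record_stream tracking:
    torch.cuda.current_stream().synchronize()  # all queued kernels are done
    # ...only now is it safe to drop the GPU copies and keep the CPU ones.
```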
@@ -288,15 +256,23 @@ def _offload_to_memory(self):

         else:
             for group_module in self.modules:
-                group_module.to(self.offload_device, non_blocking=self.non_blocking)
+                group_module.to(self.offload_device, non_blocking=False)
             for param in self.parameters:
-                param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
+                param.data = param.data.to(self.offload_device, non_blocking=False)
             for buffer in self.buffers:
-                buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
+                buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
+
+    @torch.compiler.disable()
+    def onload_(self):
+        r"""Onloads the group of parameters to the onload_device."""
+        if self.offload_to_disk_path is not None:
+            self._onload_from_disk()
+        else:
+            self._onload_from_memory()

     @torch.compiler.disable()
     def offload_(self):
-        r"""Offloads the group of modules to the offload_device."""
+        r"""Offloads the group of parameters to the offload_device."""
         if self.offload_to_disk_path:
             self._offload_to_disk()
         else:
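
Two things happen above: the no-stream offload path now forces `non_blocking=False`, and `onload_` becomes a thin dispatcher over the disk and memory helpers. The first change matters because an asynchronous device-to-CPU copy can return before the data has landed, and with no stream there is no later synchronization point to make that safe. A sketch of the hazard, assuming a CUDA device:

```python
import torch

gpu = torch.randn(1 << 20, device="cuda")
cpu = gpu.to("cpu", non_blocking=True)    # may still be in flight here
torch.cuda.synchronize()                  # required before trusting `cpu`

safe = gpu.to("cpu", non_blocking=False)  # synchronous: valid on return
```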
@@ -462,8 +438,8 @@ def pre_forward(self, module, *args, **kwargs):

 def apply_group_offloading(
     module: torch.nn.Module,
-    onload_device: torch.device,
-    offload_device: torch.device = torch.device("cpu"),
+    onload_device: Union[str, torch.device],
+    offload_device: Union[str, torch.device] = torch.device("cpu"),
     offload_type: Union[str, GroupOffloadingType] = "block_level",
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
@@ -549,6 +525,8 @@ def apply_group_offloading(
     ```
     """

+    onload_device = torch.device(onload_device) if isinstance(onload_device, str) else onload_device
+    offload_device = torch.device(offload_device) if isinstance(offload_device, str) else offload_device
     offload_type = GroupOffloadingType(offload_type)

     stream = None
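
The last two hunks widen `apply_group_offloading` so callers may pass plain strings for the devices, which are normalized to `torch.device` before use. A hedged usage sketch mirroring the docstring example already in this file (the checkpoint choice comes from that example, not from this diff):

```python
import torch
from diffusers import CogVideoXTransformer3DModel
from diffusers.hooks import apply_group_offloading

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Strings are now accepted and converted to torch.device internally.
apply_group_offloading(
    transformer,
    onload_device="cuda",
    offload_device="cpu",
    offload_type="block_level",
    num_blocks_per_group=2,
)
```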
