
Commit 8963162

Merge branch 'main' into support-diffusers-ckpt-gguf
2 parents 3f67ed0 + 69cdc25

3 files changed (+177, -105 lines)

src/diffusers/hooks/group_offloading.py

Lines changed: 89 additions & 104 deletions
@@ -95,7 +95,7 @@ def __init__(
         self.offload_to_disk_path = offload_to_disk_path
         self._is_offloaded_to_disk = False
 
-        if self.offload_to_disk_path:
+        if self.offload_to_disk_path is not None:
             # Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well.
             self.group_id = group_id if group_id is not None else str(id(self))
             short_hash = _compute_group_hash(self.group_id)
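
Both checks touched in this hunk distinguish "not provided" (None) from "provided but falsy"; the inline comment about `group_id` names exactly this concern, since an empty string is a legal id. A minimal standalone Python sketch of the difference (illustrative values only):

# Illustrative only: truthiness collapses "" into the fallback, an explicit None check does not.
group_id = ""
by_truthiness = group_id or "generated-id"                            # -> "generated-id" (the "" is lost)
by_none_check = group_id if group_id is not None else "generated-id"  # -> "" (kept, matching the diff)
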
@@ -115,6 +115,12 @@ def __init__(
         else:
             self.cpu_param_dict = self._init_cpu_param_dict()
 
+        self._torch_accelerator_module = (
+            getattr(torch, torch.accelerator.current_accelerator().type)
+            if hasattr(torch, "accelerator")
+            else torch.cuda
+        )
+
     def _init_cpu_param_dict(self):
         cpu_param_dict = {}
         if self.stream is None:
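
The newly cached `_torch_accelerator_module` uses the common pattern of resolving the device-specific torch submodule through `torch.accelerator` where available (newer PyTorch) and falling back to `torch.cuda` otherwise. A standalone sketch of that resolution, assuming some accelerator (e.g. CUDA) is present:

import torch

# Resolve e.g. torch.cuda or torch.xpu once, falling back to torch.cuda on
# PyTorch versions that predate the torch.accelerator API.
if hasattr(torch, "accelerator"):
    accelerator_module = getattr(torch, torch.accelerator.current_accelerator().type)
else:
    accelerator_module = torch.cuda

side_stream = accelerator_module.Stream()             # side stream for async transfers
default_stream = accelerator_module.current_stream()  # stream running the model's kernels
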
@@ -138,112 +144,76 @@ def _init_cpu_param_dict(self):
 
     @contextmanager
     def _pinned_memory_tensors(self):
-        pinned_dict = {}
         try:
-            for param, tensor in self.cpu_param_dict.items():
-                if not tensor.is_pinned():
-                    pinned_dict[param] = tensor.pin_memory()
-                else:
-                    pinned_dict[param] = tensor
-
+            pinned_dict = {
+                param: tensor.pin_memory() if not tensor.is_pinned() else tensor
+                for param, tensor in self.cpu_param_dict.items()
+            }
             yield pinned_dict
-
         finally:
             pinned_dict = None
 
-    def _transfer_tensor_to_device(self, tensor, source_tensor, current_stream=None):
+    def _transfer_tensor_to_device(self, tensor, source_tensor):
         tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
-        if self.record_stream and current_stream is not None:
-            tensor.data.record_stream(current_stream)
+        if self.record_stream:
+            tensor.data.record_stream(self._torch_accelerator_module.current_stream())
 
-    def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None):
+    def _process_tensors_from_modules(self, pinned_memory=None):
         for group_module in self.modules:
             for param in group_module.parameters():
                 source = pinned_memory[param] if pinned_memory else param.data
-                self._transfer_tensor_to_device(param, source, current_stream)
+                self._transfer_tensor_to_device(param, source)
             for buffer in group_module.buffers():
                 source = pinned_memory[buffer] if pinned_memory else buffer.data
-                self._transfer_tensor_to_device(buffer, source, current_stream)
+                self._transfer_tensor_to_device(buffer, source)
 
         for param in self.parameters:
             source = pinned_memory[param] if pinned_memory else param.data
-            self._transfer_tensor_to_device(param, source, current_stream)
+            self._transfer_tensor_to_device(param, source)
 
         for buffer in self.buffers:
             source = pinned_memory[buffer] if pinned_memory else buffer.data
-            self._transfer_tensor_to_device(buffer, source, current_stream)
+            self._transfer_tensor_to_device(buffer, source)
 
-    def _onload_from_disk(self, current_stream):
+    def _onload_from_disk(self):
         if self.stream is not None:
-            loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
-
-            for key, tensor_obj in self.key_to_tensor.items():
-                self.cpu_param_dict[tensor_obj] = loaded_cpu_tensors[key]
-
-            with self._pinned_memory_tensors() as pinned_memory:
-                for key, tensor_obj in self.key_to_tensor.items():
-                    self._transfer_tensor_to_device(tensor_obj, pinned_memory[tensor_obj], current_stream)
-
-            self.cpu_param_dict.clear()
-
-        else:
-            onload_device = (
-                self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
-            )
-            loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
-            for key, tensor_obj in self.key_to_tensor.items():
-                tensor_obj.data = loaded_tensors[key]
+            # Wait for previous Host->Device transfer to complete
+            self.stream.synchronize()
 
-    def _onload_from_memory(self, current_stream):
-        if self.stream is not None:
-            with self._pinned_memory_tensors() as pinned_memory:
-                self._process_tensors_from_modules(pinned_memory, current_stream)
-        else:
-            self._process_tensors_from_modules(None, current_stream)
+        context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
+        current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None
 
-    @torch.compiler.disable()
-    def onload_(self):
-        torch_accelerator_module = (
-            getattr(torch, torch.accelerator.current_accelerator().type)
-            if hasattr(torch, "accelerator")
-            else torch.cuda
-        )
-        context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
-        current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
+        with context:
+            # Load to CPU (if using streams) or directly to target device, pin, and async copy to device
+            device = str(self.onload_device) if self.stream is None else "cpu"
+            loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device)
 
-        if self.offload_to_disk_path:
             if self.stream is not None:
-                # Wait for previous Host->Device transfer to complete
-                self.stream.synchronize()
-
-            with context:
-                if self.stream is not None:
-                    # Load to CPU, pin, and async copy to device for overlapping transfer and compute
-                    loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
-                    for key, tensor_obj in self.key_to_tensor.items():
-                        pinned_tensor = loaded_cpu_tensors[key].pin_memory()
-                        tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
-                        if self.record_stream:
-                            tensor_obj.data.record_stream(current_stream)
-                else:
-                    # Load directly to the target device (synchronous)
-                    onload_device = (
-                        self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
-                    )
-                    loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
-                    for key, tensor_obj in self.key_to_tensor.items():
-                        tensor_obj.data = loaded_tensors[key]
-            return
+                for key, tensor_obj in self.key_to_tensor.items():
+                    pinned_tensor = loaded_tensors[key].pin_memory()
+                    tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+                    if self.record_stream:
+                        tensor_obj.data.record_stream(current_stream)
+            else:
+                onload_device = (
+                    self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
+                )
+                loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
+                for key, tensor_obj in self.key_to_tensor.items():
+                    tensor_obj.data = loaded_tensors[key]
 
+    def _onload_from_memory(self):
         if self.stream is not None:
             # Wait for previous Host->Device transfer to complete
             self.stream.synchronize()
 
+        context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
         with context:
-            if self.offload_to_disk_path:
-                self._onload_from_disk(current_stream)
+            if self.stream is not None:
+                with self._pinned_memory_tensors() as pinned_memory:
+                    self._process_tensors_from_modules(pinned_memory)
             else:
-                self._onload_from_memory(current_stream)
+                self._process_tensors_from_modules(None)
 
     def _offload_to_disk(self):
         # TODO: we can potentially optimize this code path by checking if the _all_ the desired
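
The rewritten onload path is the standard pinned-memory transfer recipe: materialize the tensor on the CPU, pin it, issue a non-blocking host-to-device copy on a side stream, and either call `record_stream` or synchronize before the consumer reads the result. A minimal sketch of that recipe outside the class (assumes CUDA; tensor names are illustrative):

import torch

onload_device = torch.device("cuda")
side_stream = torch.cuda.Stream()
consumer_stream = torch.cuda.current_stream()  # stream that will later run the forward pass

cpu_tensor = torch.randn(1024, 1024)           # stand-in for a weight loaded via safetensors
pinned = cpu_tensor.pin_memory()               # page-locked memory enables truly async H2D copies

with torch.cuda.stream(side_stream):
    gpu_tensor = pinned.to(onload_device, non_blocking=True)
    # Tell the caching allocator that consumer_stream will read this tensor, so its memory
    # is not reused before that stream has ordered the copy (the record_stream path).
    gpu_tensor.record_stream(consumer_stream)

# Without record_stream, the alternative used by the hook is to synchronize the side
# stream before the forward pass touches gpu_tensor.
side_stream.synchronize()
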
@@ -264,33 +234,36 @@ def _offload_to_disk(self):
             tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
 
     def _offload_to_memory(self):
-        torch_accelerator_module = (
-            getattr(torch, torch.accelerator.current_accelerator().type)
-            if hasattr(torch, "accelerator")
-            else torch.cuda
-        )
         if self.stream is not None:
             if not self.record_stream:
-                torch_accelerator_module.current_stream().synchronize()
+                self._torch_accelerator_module.current_stream().synchronize()
+
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]
             for param in self.parameters:
                 param.data = self.cpu_param_dict[param]
             for buffer in self.buffers:
                 buffer.data = self.cpu_param_dict[buffer]
-
         else:
             for group_module in self.modules:
-                group_module.to(self.offload_device, non_blocking=self.non_blocking)
+                group_module.to(self.offload_device, non_blocking=False)
             for param in self.parameters:
-                param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
+                param.data = param.data.to(self.offload_device, non_blocking=False)
             for buffer in self.buffers:
-                buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
+                buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
+
+    @torch.compiler.disable()
+    def onload_(self):
+        r"""Onloads the group of parameters to the onload_device."""
+        if self.offload_to_disk_path is not None:
+            self._onload_from_disk()
+        else:
+            self._onload_from_memory()
 
     @torch.compiler.disable()
     def offload_(self):
-        r"""Offloads the group of modules to the offload_device."""
+        r"""Offloads the group of parameters to the offload_device."""
         if self.offload_to_disk_path:
             self._offload_to_disk()
         else:
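
In the no-stream branch of `_offload_to_memory`, the copies back to the offload device are now explicitly blocking. This follows the usual rule that a `non_blocking=True` copy toward the CPU is only safe when the destination is pinned and the stream is synchronized before the data is read; a blocking copy sidesteps that hazard. A small illustrative sketch (assumes CUDA):

import torch

gpu_tensor = torch.randn(8, device="cuda")

# Safe and simple: block until the copy is complete (what the no-stream branch does now).
cpu_copy = gpu_tensor.to("cpu", non_blocking=False)

# Async variant: copy into pinned memory and synchronize before reading; otherwise the
# host may observe incomplete data.
pinned_dst = torch.empty_like(cpu_copy).pin_memory()
pinned_dst.copy_(gpu_tensor, non_blocking=True)
torch.cuda.current_stream().synchronize()
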
@@ -307,11 +280,9 @@ class GroupOffloadingHook(ModelHook):
 
     _is_stateful = False
 
-    def __init__(
-        self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, *, config: GroupOffloadingConfig
-    ) -> None:
+    def __init__(self, group: ModuleGroup, *, config: GroupOffloadingConfig) -> None:
         self.group = group
-        self.next_group = next_group
+        self.next_group: Optional[ModuleGroup] = None
         self.config = config
 
     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
@@ -331,9 +302,23 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
         if self.group.onload_leader == module:
             if self.group.onload_self:
                 self.group.onload_()
-            if self.next_group is not None and not self.next_group.onload_self:
+
+            should_onload_next_group = self.next_group is not None and not self.next_group.onload_self
+            if should_onload_next_group:
                 self.next_group.onload_()
 
+            should_synchronize = (
+                not self.group.onload_self and self.group.stream is not None and not should_onload_next_group
+            )
+            if should_synchronize:
+                # If this group didn't onload itself, it means it was asynchronously onloaded by the
+                # previous group. We need to synchronize the side stream to ensure parameters
+                # are completely loaded to proceed with forward pass. Without this, uninitialized
+                # weights will be used in the computation, leading to incorrect results
+                # Also, we should only do this synchronization if we don't already do it from the sync call in
+                # self.next_group.onload_, hence the `not should_onload_next_group` check.
+                self.group.stream.synchronize()
+
         args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
         kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
         return args, kwargs
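
The two new flags make the prefetch/wait decision in `pre_forward` explicit: prefetch the next group when it cannot onload itself, and wait on the side stream only when this group was onloaded asynchronously by its predecessor and no `onload_` call (which already synchronizes the stream) is about to run. A standalone sketch of just that boolean logic, with a hypothetical helper name:

from typing import Optional

def plan_pre_forward(onload_self: bool, has_stream: bool, next_onload_self: Optional[bool]):
    # Returns (onload_now, onload_next, synchronize); next_onload_self is None when there is no next group.
    # Hypothetical helper mirroring the boolean structure of GroupOffloadingHook.pre_forward.
    onload_now = onload_self
    onload_next = next_onload_self is not None and not next_onload_self
    synchronize = not onload_self and has_stream and not onload_next
    return onload_now, onload_next, synchronize

# A group that was prefetched by its predecessor (onload_self=False) and has nothing
# further to prefetch must wait for the side stream before its forward pass:
assert plan_pre_forward(False, True, None) == (False, False, True)
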
@@ -459,8 +444,8 @@ def pre_forward(self, module, *args, **kwargs):
 
 def apply_group_offloading(
     module: torch.nn.Module,
-    onload_device: torch.device,
-    offload_device: torch.device = torch.device("cpu"),
+    onload_device: Union[str, torch.device],
+    offload_device: Union[str, torch.device] = torch.device("cpu"),
     offload_type: Union[str, GroupOffloadingType] = "block_level",
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
@@ -546,6 +531,8 @@ def apply_group_offloading(
     ```
     """
 
+    onload_device = torch.device(onload_device) if isinstance(onload_device, str) else onload_device
+    offload_device = torch.device(offload_device) if isinstance(offload_device, str) else offload_device
     offload_type = GroupOffloadingType(offload_type)
 
     stream = None
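
With the normalization above, callers can pass plain strings for both devices. A minimal usage sketch with a toy module (`use_stream` comes from the full signature not shown in this hunk, and the import path follows diffusers' public hooks API; both are assumptions here):

import torch
from diffusers.hooks import apply_group_offloading

class TinyBlocks(torch.nn.Module):
    # Toy stand-in: block_level offloading groups the children of ModuleList/Sequential containers.
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(64, 64) for _ in range(4)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

model = TinyBlocks()
apply_group_offloading(
    model,
    onload_device="cuda",      # plain strings are accepted after this change
    offload_device="cpu",
    offload_type="block_level",
    num_blocks_per_group=2,
    use_stream=True,           # overlap transfers with compute; requires an accelerator with streams
)
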
@@ -633,7 +620,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
     # Apply group offloading hooks to the module groups
     for i, group in enumerate(matched_module_groups):
         for group_module in group.modules:
-            _apply_group_offloading_hook(group_module, group, None, config=config)
+            _apply_group_offloading_hook(group_module, group, config=config)
 
     # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
     # when the forward pass of this module is called. This is because the top-level module is not
@@ -662,9 +649,9 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
         group_id=f"{module.__class__.__name__}_unmatched_group",
     )
     if config.stream is None:
-        _apply_group_offloading_hook(module, unmatched_group, None, config=config)
+        _apply_group_offloading_hook(module, unmatched_group, config=config)
     else:
-        _apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
+        _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
 
 
 def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
@@ -693,7 +680,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
             onload_self=True,
             group_id=name,
         )
-        _apply_group_offloading_hook(submodule, group, None, config=config)
+        _apply_group_offloading_hook(submodule, group, config=config)
         modules_with_group_offloading.add(name)
 
     # Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
@@ -740,7 +727,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
             onload_self=True,
             group_id=name,
         )
-        _apply_group_offloading_hook(parent_module, group, None, config=config)
+        _apply_group_offloading_hook(parent_module, group, config=config)
 
     if config.stream is not None:
         # When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
@@ -762,13 +749,12 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
         onload_self=True,
         group_id=_GROUP_ID_LAZY_LEAF,
     )
-    _apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
+    _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
 
 
 def _apply_group_offloading_hook(
     module: torch.nn.Module,
     group: ModuleGroup,
-    next_group: Optional[ModuleGroup] = None,
     *,
     config: GroupOffloadingConfig,
 ) -> None:
@@ -777,14 +763,13 @@ def _apply_group_offloading_hook(
     # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
     # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
     if registry.get_hook(_GROUP_OFFLOADING) is None:
-        hook = GroupOffloadingHook(group, next_group, config=config)
+        hook = GroupOffloadingHook(group, config=config)
         registry.register_hook(hook, _GROUP_OFFLOADING)
 
 
 def _apply_lazy_group_offloading_hook(
     module: torch.nn.Module,
     group: ModuleGroup,
-    next_group: Optional[ModuleGroup] = None,
     *,
     config: GroupOffloadingConfig,
 ) -> None:
@@ -793,7 +778,7 @@ def _apply_lazy_group_offloading_hook(
     # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
     # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
     if registry.get_hook(_GROUP_OFFLOADING) is None:
-        hook = GroupOffloadingHook(group, next_group, config=config)
+        hook = GroupOffloadingHook(group, config=config)
         registry.register_hook(hook, _GROUP_OFFLOADING)
 
     lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
