Commit cfd6ec7

[refactor] condense group offloading (#11990)
* update
* update
* refactor
* add test
* address review comment
* nit
1 parent 1082c46 commit cfd6ec7

File tree

2 files changed: +161, -102 lines changed


src/diffusers/hooks/group_offloading.py

Lines changed: 74 additions & 102 deletions
@@ -95,7 +95,7 @@ def __init__(
         self.offload_to_disk_path = offload_to_disk_path
         self._is_offloaded_to_disk = False

-        if self.offload_to_disk_path:
+        if self.offload_to_disk_path is not None:
             # Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well.
             self.group_id = group_id if group_id is not None else str(id(self))
             short_hash = _compute_group_hash(self.group_id)
@@ -115,6 +115,12 @@ def __init__(
         else:
             self.cpu_param_dict = self._init_cpu_param_dict()

+        self._torch_accelerator_module = (
+            getattr(torch, torch.accelerator.current_accelerator().type)
+            if hasattr(torch, "accelerator")
+            else torch.cuda
+        )
+
     def _init_cpu_param_dict(self):
         cpu_param_dict = {}
         if self.stream is None:
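The cached `self._torch_accelerator_module` replaces the ad-hoc lookups that previously lived in `onload_()` and `_offload_to_memory()`. As a rough standalone sketch of the same resolution pattern (the helper name below is ours, and the extra `None` guard is an assumption, not part of the diff):

```python
import torch


def get_torch_accelerator_module():
    # Hypothetical helper mirroring the pattern cached as `self._torch_accelerator_module`:
    # prefer the generic torch.accelerator API where available, fall back to torch.cuda.
    if hasattr(torch, "accelerator") and torch.accelerator.current_accelerator() is not None:
        return getattr(torch, torch.accelerator.current_accelerator().type)
    return torch.cuda


# On a CUDA machine this resolves to torch.cuda, on XPU to torch.xpu, so calls like
# .current_stream() and .synchronize() stay device-agnostic.
print(get_torch_accelerator_module())
```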
@@ -138,112 +144,76 @@ def _init_cpu_param_dict(self):

     @contextmanager
     def _pinned_memory_tensors(self):
-        pinned_dict = {}
         try:
-            for param, tensor in self.cpu_param_dict.items():
-                if not tensor.is_pinned():
-                    pinned_dict[param] = tensor.pin_memory()
-                else:
-                    pinned_dict[param] = tensor
-
+            pinned_dict = {
+                param: tensor.pin_memory() if not tensor.is_pinned() else tensor
+                for param, tensor in self.cpu_param_dict.items()
+            }
             yield pinned_dict
-
         finally:
             pinned_dict = None

-    def _transfer_tensor_to_device(self, tensor, source_tensor, current_stream=None):
+    def _transfer_tensor_to_device(self, tensor, source_tensor):
         tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
-        if self.record_stream and current_stream is not None:
-            tensor.data.record_stream(current_stream)
+        if self.record_stream:
+            tensor.data.record_stream(self._torch_accelerator_module.current_stream())

-    def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None):
+    def _process_tensors_from_modules(self, pinned_memory=None):
         for group_module in self.modules:
             for param in group_module.parameters():
                 source = pinned_memory[param] if pinned_memory else param.data
-                self._transfer_tensor_to_device(param, source, current_stream)
+                self._transfer_tensor_to_device(param, source)
             for buffer in group_module.buffers():
                 source = pinned_memory[buffer] if pinned_memory else buffer.data
-                self._transfer_tensor_to_device(buffer, source, current_stream)
+                self._transfer_tensor_to_device(buffer, source)

         for param in self.parameters:
             source = pinned_memory[param] if pinned_memory else param.data
-            self._transfer_tensor_to_device(param, source, current_stream)
+            self._transfer_tensor_to_device(param, source)

         for buffer in self.buffers:
             source = pinned_memory[buffer] if pinned_memory else buffer.data
-            self._transfer_tensor_to_device(buffer, source, current_stream)
+            self._transfer_tensor_to_device(buffer, source)

-    def _onload_from_disk(self, current_stream):
+    def _onload_from_disk(self):
         if self.stream is not None:
-            loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
-
-            for key, tensor_obj in self.key_to_tensor.items():
-                self.cpu_param_dict[tensor_obj] = loaded_cpu_tensors[key]
-
-            with self._pinned_memory_tensors() as pinned_memory:
-                for key, tensor_obj in self.key_to_tensor.items():
-                    self._transfer_tensor_to_device(tensor_obj, pinned_memory[tensor_obj], current_stream)
-
-            self.cpu_param_dict.clear()
+            # Wait for previous Host->Device transfer to complete
+            self.stream.synchronize()

-        else:
-            onload_device = (
-                self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
-            )
-            loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
-            for key, tensor_obj in self.key_to_tensor.items():
-                tensor_obj.data = loaded_tensors[key]
+        context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
+        current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None

-    def _onload_from_memory(self, current_stream):
-        if self.stream is not None:
-            with self._pinned_memory_tensors() as pinned_memory:
-                self._process_tensors_from_modules(pinned_memory, current_stream)
-        else:
-            self._process_tensors_from_modules(None, current_stream)
-
-    @torch.compiler.disable()
-    def onload_(self):
-        torch_accelerator_module = (
-            getattr(torch, torch.accelerator.current_accelerator().type)
-            if hasattr(torch, "accelerator")
-            else torch.cuda
-        )
-        context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
-        current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
+        with context:
+            # Load to CPU (if using streams) or directly to target device, pin, and async copy to device
+            device = str(self.onload_device) if self.stream is None else "cpu"
+            loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device)

-        if self.offload_to_disk_path:
             if self.stream is not None:
-                # Wait for previous Host->Device transfer to complete
-                self.stream.synchronize()
-
-            with context:
-                if self.stream is not None:
-                    # Load to CPU, pin, and async copy to device for overlapping transfer and compute
-                    loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
-                    for key, tensor_obj in self.key_to_tensor.items():
-                        pinned_tensor = loaded_cpu_tensors[key].pin_memory()
-                        tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
-                        if self.record_stream:
-                            tensor_obj.data.record_stream(current_stream)
-                else:
-                    # Load directly to the target device (synchronous)
-                    onload_device = (
-                        self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
-                    )
-                    loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
-                    for key, tensor_obj in self.key_to_tensor.items():
-                        tensor_obj.data = loaded_tensors[key]
-            return
+                for key, tensor_obj in self.key_to_tensor.items():
+                    pinned_tensor = loaded_tensors[key].pin_memory()
+                    tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+                    if self.record_stream:
+                        tensor_obj.data.record_stream(current_stream)
+            else:
+                onload_device = (
+                    self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
+                )
+                loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
+                for key, tensor_obj in self.key_to_tensor.items():
+                    tensor_obj.data = loaded_tensors[key]

+    def _onload_from_memory(self):
         if self.stream is not None:
             # Wait for previous Host->Device transfer to complete
             self.stream.synchronize()

+        context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
         with context:
-            if self.offload_to_disk_path:
-                self._onload_from_disk(current_stream)
+            if self.stream is not None:
+                with self._pinned_memory_tensors() as pinned_memory:
+                    self._process_tensors_from_modules(pinned_memory)
             else:
-                self._onload_from_memory(current_stream)
+                self._process_tensors_from_modules(None)

     def _offload_to_disk(self):
         # TODO: we can potentially optimize this code path by checking if the _all_ the desired
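All of the onload paths above reduce to one pattern: stage tensors in pinned CPU memory, copy them to the accelerator with `non_blocking=True` on a side stream, and (optionally) `record_stream` so the caching allocator does not recycle the memory while another stream may still read it. A minimal self-contained sketch of that pattern, assuming a CUDA device (tensor shapes and names are illustrative, not from the diffusers code):

```python
import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
    copy_stream = torch.cuda.Stream()
    compute_stream = torch.cuda.current_stream()  # stream that will consume the data

    # Stage on CPU and pin so the host->device copy can run asynchronously.
    cpu_tensor = torch.randn(1024, 1024)
    pinned = cpu_tensor if cpu_tensor.is_pinned() else cpu_tensor.pin_memory()

    with torch.cuda.stream(copy_stream):
        # Asynchronous host->device copy issued on the side stream.
        gpu_tensor = pinned.to(device, non_blocking=True)
        # Mark the tensor as used by the compute stream so the caching allocator
        # does not recycle its memory while that stream may still read it.
        gpu_tensor.record_stream(compute_stream)

    # The compute stream must wait for the copy before consuming the tensor.
    compute_stream.wait_stream(copy_stream)
    print(gpu_tensor.sum().item())
```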
@@ -264,14 +234,10 @@ def _offload_to_disk(self):
             tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)

     def _offload_to_memory(self):
-        torch_accelerator_module = (
-            getattr(torch, torch.accelerator.current_accelerator().type)
-            if hasattr(torch, "accelerator")
-            else torch.cuda
-        )
         if self.stream is not None:
             if not self.record_stream:
-                torch_accelerator_module.current_stream().synchronize()
+                self._torch_accelerator_module.current_stream().synchronize()
+
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]
@@ -282,15 +248,23 @@ def _offload_to_memory(self):

         else:
             for group_module in self.modules:
-                group_module.to(self.offload_device, non_blocking=self.non_blocking)
+                group_module.to(self.offload_device, non_blocking=False)
             for param in self.parameters:
-                param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
+                param.data = param.data.to(self.offload_device, non_blocking=False)
             for buffer in self.buffers:
-                buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
+                buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
+
+    @torch.compiler.disable()
+    def onload_(self):
+        r"""Onloads the group of parameters to the onload_device."""
+        if self.offload_to_disk_path is not None:
+            self._onload_from_disk()
+        else:
+            self._onload_from_memory()

     @torch.compiler.disable()
     def offload_(self):
-        r"""Offloads the group of modules to the offload_device."""
+        r"""Offloads the group of parameters to the offload_device."""
         if self.offload_to_disk_path:
             self._offload_to_disk()
         else:
@@ -307,11 +281,9 @@ class GroupOffloadingHook(ModelHook):

     _is_stateful = False

-    def __init__(
-        self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, *, config: GroupOffloadingConfig
-    ) -> None:
+    def __init__(self, group: ModuleGroup, *, config: GroupOffloadingConfig) -> None:
         self.group = group
-        self.next_group = next_group
+        self.next_group: Optional[ModuleGroup] = None
         self.config = config

     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
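With this change, `next_group` is no longer a constructor argument; it starts as `None` and is assigned later, once the execution order of the groups is known (the stream-based prefetching in this file relies on that ordering). A hypothetical, heavily simplified illustration of such a post-construction linking step (the `Group`/`Hook` classes and the `link_groups_in_execution_order` helper are ours, not the diffusers implementation):

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Group:
    name: str


@dataclass
class Hook:
    group: Group
    # Filled in after construction, once the execution order is known,
    # mirroring `self.next_group: Optional[ModuleGroup] = None` above.
    next_group: Optional[Group] = None


def link_groups_in_execution_order(hooks: List[Hook]) -> None:
    # Hypothetical helper: point each hook at the group that runs next,
    # so it can prefetch that group's weights during its own forward.
    for current, following in zip(hooks, hooks[1:]):
        current.next_group = following.group


hooks = [Hook(Group("block_0")), Hook(Group("block_1")), Hook(Group("block_2"))]
link_groups_in_execution_order(hooks)
print([h.next_group.name if h.next_group else None for h in hooks])
```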
@@ -459,8 +431,8 @@ def pre_forward(self, module, *args, **kwargs):

 def apply_group_offloading(
     module: torch.nn.Module,
-    onload_device: torch.device,
-    offload_device: torch.device = torch.device("cpu"),
+    onload_device: Union[str, torch.device],
+    offload_device: Union[str, torch.device] = torch.device("cpu"),
     offload_type: Union[str, GroupOffloadingType] = "block_level",
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
@@ -546,6 +518,8 @@ def apply_group_offloading(
     ```
     """

+    onload_device = torch.device(onload_device) if isinstance(onload_device, str) else onload_device
+    offload_device = torch.device(offload_device) if isinstance(offload_device, str) else offload_device
     offload_type = GroupOffloadingType(offload_type)

     stream = None
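Because of the normalization added above, `onload_device` and `offload_device` may now be passed as plain strings. A hedged usage sketch assuming a CUDA device is available (the toy `TinyModel` is ours; any `torch.nn.Module` with block-like children works the same way):

```python
import torch
from torch import nn

from diffusers.hooks.group_offloading import apply_group_offloading


# Toy module standing in for a real diffusers model; block_level offloading
# groups the ModuleList children into per-group on/offload units.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(64, 64) for _ in range(4)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


model = TinyModel()

# After this commit, devices may be passed as plain strings;
# they are normalized to torch.device internally.
apply_group_offloading(
    model,
    onload_device="cuda",
    offload_device="cpu",
    offload_type="block_level",
    num_blocks_per_group=2,
)

# Inputs are expected on the onload device; blocks are moved there on demand.
with torch.no_grad():
    out = model(torch.randn(2, 64, device="cuda"))
print(out.shape)
```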
@@ -633,7 +607,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
     # Apply group offloading hooks to the module groups
     for i, group in enumerate(matched_module_groups):
         for group_module in group.modules:
-            _apply_group_offloading_hook(group_module, group, None, config=config)
+            _apply_group_offloading_hook(group_module, group, config=config)

     # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
     # when the forward pass of this module is called. This is because the top-level module is not
@@ -662,9 +636,9 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
         group_id=f"{module.__class__.__name__}_unmatched_group",
     )
     if config.stream is None:
-        _apply_group_offloading_hook(module, unmatched_group, None, config=config)
+        _apply_group_offloading_hook(module, unmatched_group, config=config)
     else:
-        _apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
+        _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)


 def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
@@ -693,7 +667,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
             onload_self=True,
             group_id=name,
         )
-        _apply_group_offloading_hook(submodule, group, None, config=config)
+        _apply_group_offloading_hook(submodule, group, config=config)
         modules_with_group_offloading.add(name)

     # Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
@@ -740,7 +714,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
             onload_self=True,
             group_id=name,
         )
-        _apply_group_offloading_hook(parent_module, group, None, config=config)
+        _apply_group_offloading_hook(parent_module, group, config=config)

     if config.stream is not None:
         # When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
@@ -762,13 +736,12 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
         onload_self=True,
         group_id=_GROUP_ID_LAZY_LEAF,
     )
-    _apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
+    _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)


 def _apply_group_offloading_hook(
     module: torch.nn.Module,
     group: ModuleGroup,
-    next_group: Optional[ModuleGroup] = None,
     *,
     config: GroupOffloadingConfig,
 ) -> None:
@@ -777,14 +750,13 @@ def _apply_group_offloading_hook(
     # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
     # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
     if registry.get_hook(_GROUP_OFFLOADING) is None:
-        hook = GroupOffloadingHook(group, next_group, config=config)
+        hook = GroupOffloadingHook(group, config=config)
         registry.register_hook(hook, _GROUP_OFFLOADING)


 def _apply_lazy_group_offloading_hook(
     module: torch.nn.Module,
     group: ModuleGroup,
-    next_group: Optional[ModuleGroup] = None,
     *,
     config: GroupOffloadingConfig,
 ) -> None:
@@ -793,7 +765,7 @@ def _apply_lazy_group_offloading_hook(
     # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
     # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
     if registry.get_hook(_GROUP_OFFLOADING) is None:
-        hook = GroupOffloadingHook(group, next_group, config=config)
+        hook = GroupOffloadingHook(group, config=config)
         registry.register_hook(hook, _GROUP_OFFLOADING)

     lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
