Commit b3fa8c6

remove cpu param dict
1 parent 720be2b commit b3fa8c6

File tree

1 file changed: +20 -84 lines changed


src/diffusers/hooks/group_offloading.py

Lines changed: 20 additions & 84 deletions
@@ -59,7 +59,6 @@ def __init__(
         buffers: Optional[List[torch.Tensor]] = None,
         non_blocking: bool = False,
         stream: Optional[torch.cuda.Stream] = None,
-        cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
         onload_self: bool = True,
     ) -> None:
         self.modules = modules
@@ -71,47 +70,10 @@ def __init__(
         self.buffers = buffers
         self.non_blocking = non_blocking or stream is not None
         self.stream = stream
-        self.cpu_param_dict = cpu_param_dict
         self.onload_self = onload_self
-        # We still track if we've prepared the CPU dict for compatibility
-        self.cpu_dict_prepared = False
-
-        if self.stream is not None and self.cpu_param_dict is None:
-            # Now we'll create the dict on demand
-            self.cpu_param_dict = {}
-
-    def pin_memory_(self):
-        r"""Prepare the CPU parameter dictionary for reference (no pinning)."""
-        if self.stream is not None and not self.cpu_dict_prepared:
-            # Create a reference-only CPU parameter dict without pinning
-            for group_module in self.modules:
-                for param in group_module.parameters():
-                    if param.device == self.offload_device:
-                        # Store a reference without pinning
-                        self.cpu_param_dict[param] = param.data
-            if self.parameters is not None:
-                for param in self.parameters:
-                    if param.device == self.offload_device:
-                        # Store a reference without pinning
-                        self.cpu_param_dict[param] = param.data
-            self.cpu_dict_prepared = True
-            # For API compatibility
-            self.pinned_memory = True
-
-    def unpin_memory_(self):
-        r"""No-op method kept for API compatibility."""
-        # No need to unpin since we're not pinning memory anymore
-        if self.stream is not None and self.pinned_memory:
-            # Just mark as unpinned for compatibility
-            self.pinned_memory = False
-            self.cpu_dict_prepared = False

     def onload_(self):
         r"""Onloads the group of modules to the onload_device."""
-        # Prepare CPU dict before onloading
-        if self.stream is not None and not self.cpu_dict_prepared:
-            self.pin_memory_()  # This now just prepares the CPU dict
-
         context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
         if self.stream is not None:
             # Wait for previous Host->Device transfer to complete
@@ -129,20 +91,19 @@ def onload_(self):

     def offload_(self):
         r"""Offloads the group of modules to the offload_device."""
+        # Synchronize if using stream
         if self.stream is not None:
             torch.cuda.current_stream().synchronize()
-            for group_module in self.modules:
-                for param in group_module.parameters():
-                    param.data = self.cpu_param_dict[param]
-        else:
-            for group_module in self.modules:
-                group_module.to(self.offload_device, non_blocking=self.non_blocking)
-            if self.parameters is not None:
-                for param in self.parameters:
-                    param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
-            if self.buffers is not None:
-                for buffer in self.buffers:
-                    buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
+
+        # Use regular to() method for all cases - much simpler!
+        for group_module in self.modules:
+            group_module.to(self.offload_device, non_blocking=self.non_blocking)
+        if self.parameters is not None:
+            for param in self.parameters:
+                param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
+        if self.buffers is not None:
+            for buffer in self.buffers:
+                buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)

         # After offloading, we can unpin the memory if configured to do so
         # We'll keep it pinned by default for better performance
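Aside: a minimal, self-contained sketch of the offload path the new `offload_()` converges on, where module, parameter, and buffer tensors are all moved with plain `.to()` and no CPU-side parameter dict is kept. The toy block, the `extra_params`/`extra_buffers` lists, and the device choice below are assumptions for illustration, not part of this commit.

```python
import torch
import torch.nn as nn

# Illustrative stand-ins for one offload group (hypothetical, not from the diff):
block = nn.Sequential(nn.Linear(8, 8), nn.ReLU())  # plays the role of group.modules[0]
extra_params = [nn.Parameter(torch.randn(4))]       # plays the role of group.parameters
extra_buffers = [torch.zeros(4)]                    # plays the role of group.buffers

offload_device = torch.device("cpu")
non_blocking = False

# Same shape as the simplified offload_(): plain .to() for modules, parameters, and buffers.
block.to(offload_device, non_blocking=non_blocking)
for param in extra_params:
    param.data = param.data.to(offload_device, non_blocking=non_blocking)
for buffer in extra_buffers:
    buffer.data = buffer.data.to(offload_device, non_blocking=non_blocking)
```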
@@ -162,18 +123,13 @@ def __init__(
         self,
         group: ModuleGroup,
         next_group: Optional[ModuleGroup] = None,
-        unpin_after_use: bool = False,
     ) -> None:
         self.group = group
         self.next_group = next_group
-        self.unpin_after_use = unpin_after_use

     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         if self.group.offload_leader == module:
-            # Make sure we prepare CPU dict first (if using streams) before offloading
-            if self.group.stream is not None and not self.group.cpu_dict_prepared:
-                self.group.pin_memory_()  # This now just prepares the CPU dict
-            # Now it's safe to offload
+            # Offload to CPU
             self.group.offload_()
         return module

@@ -199,9 +155,6 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
     def post_forward(self, module: torch.nn.Module, output):
         if self.group.offload_leader == module:
             self.group.offload_()
-            # This is now a no-op but kept for API compatibility
-            if self.unpin_after_use and self.group.cpu_dict_prepared:
-                self.group.unpin_memory_()
         return output


@@ -316,7 +269,6 @@ def apply_group_offloading(
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
     use_stream: bool = False,
-    unpin_after_use: bool = False,
 ) -> None:
     r"""
     Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -363,8 +315,6 @@ def apply_group_offloading(
         use_stream (`bool`, defaults to `False`):
             If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
             overlapping computation and data transfer.
-        unpin_after_use (`bool`, defaults to `False`):
-            Legacy parameter kept for API compatibility. Has no effect as we no longer use pinned memory.

     Example:
         ```python
@@ -382,7 +332,6 @@ def apply_group_offloading(
         ...     offload_type="block_level",
         ...     num_blocks_per_group=2,
         ...     use_stream=True,
-        ...     unpin_after_use=False,  # Legacy parameter, no effect
         ... )
         ```
     """
@@ -409,11 +358,10 @@ def apply_group_offloading(
             onload_device,
             non_blocking,
             stream,
-            unpin_after_use,
         )
     elif offload_type == "leaf_level":
         _apply_group_offloading_leaf_level(
-            module, offload_device, onload_device, non_blocking, stream, unpin_after_use
+            module, offload_device, onload_device, non_blocking, stream
         )
     else:
         raise ValueError(f"Unsupported offload_type: {offload_type}")
@@ -426,7 +374,6 @@ def _apply_group_offloading_block_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
-    unpin_after_use: bool = False,
 ) -> None:
     r"""
     This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
@@ -447,9 +394,7 @@ def _apply_group_offloading_block_level(
             for overlapping computation and data transfer.
     """

-    # With progressive pinning approach, we'll initialize an empty CPU parameter dict
-    # and pin memory only when needed by each group
-    cpu_param_dict = {} if stream is not None else None
+    # We no longer need a CPU parameter dictionary

     # Create module groups for ModuleList and Sequential blocks
     modules_with_group_offloading = set()
@@ -471,7 +416,6 @@ def _apply_group_offloading_block_level(
             onload_leader=current_modules[0],
             non_blocking=non_blocking,
             stream=stream,
-            cpu_param_dict=cpu_param_dict,
             onload_self=stream is None,
         )
         matched_module_groups.append(group)
@@ -485,7 +429,7 @@ def _apply_group_offloading_block_level(
         )

         for group_module in group.modules:
-            _apply_group_offloading_hook(group_module, group, next_group, unpin_after_use)
+            _apply_group_offloading_hook(group_module, group, next_group)

     # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
     # when the forward pass of this module is called. This is because the top-level module is not
@@ -508,11 +452,10 @@ def _apply_group_offloading_block_level(
         buffers=buffers,
         non_blocking=False,
         stream=None,
-        cpu_param_dict=None,
         onload_self=True,
     )
     next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
-    _apply_group_offloading_hook(module, unmatched_group, next_group, unpin_after_use)
+    _apply_group_offloading_hook(module, unmatched_group, next_group)


 def _apply_group_offloading_leaf_level(
@@ -521,7 +464,6 @@ def _apply_group_offloading_leaf_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
-    unpin_after_use: bool = False,
 ) -> None:
     r"""
     This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
@@ -544,9 +486,7 @@ def _apply_group_offloading_leaf_level(
             for overlapping computation and data transfer.
     """

-    # With progressive pinning approach, we'll initialize an empty CPU parameter dict
-    # and pin memory only when needed by each group
-    cpu_param_dict = {} if stream is not None else None
+    # We no longer need a CPU parameter dictionary

     # Create module groups for leaf modules and apply group offloading hooks
     modules_with_group_offloading = set()
@@ -561,10 +501,9 @@ def _apply_group_offloading_leaf_level(
             onload_leader=submodule,
             non_blocking=non_blocking,
             stream=stream,
-            cpu_param_dict=cpu_param_dict,
             onload_self=True,
         )
-        _apply_group_offloading_hook(submodule, group, None, unpin_after_use)
+        _apply_group_offloading_hook(submodule, group, None)
         modules_with_group_offloading.add(name)

     # Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
@@ -606,10 +545,9 @@ def _apply_group_offloading_leaf_level(
             buffers=buffers,
             non_blocking=non_blocking,
             stream=stream,
-            cpu_param_dict=cpu_param_dict,
             onload_self=True,
         )
-        _apply_group_offloading_hook(parent_module, group, None, unpin_after_use)
+        _apply_group_offloading_hook(parent_module, group, None)

     if stream is not None:
         # When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
@@ -625,7 +563,6 @@ def _apply_group_offloading_leaf_level(
             buffers=None,
             non_blocking=False,
             stream=None,
-            cpu_param_dict=None,
             onload_self=True,
         )
         _apply_lazy_group_offloading_hook(module, unmatched_group, None)
@@ -635,14 +572,13 @@ def _apply_group_offloading_hook(
     module: torch.nn.Module,
     group: ModuleGroup,
     next_group: Optional[ModuleGroup] = None,
-    unpin_after_use: bool = False,
 ) -> None:
     registry = HookRegistry.check_if_exists_or_initialize(module)

     # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
     # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
     if registry.get_hook(_GROUP_OFFLOADING) is None:
-        hook = GroupOffloadingHook(group, next_group, unpin_after_use)
+        hook = GroupOffloadingHook(group, next_group)
         registry.register_hook(hook, _GROUP_OFFLOADING)

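For completeness, a minimal usage sketch of the entry point after this commit, adapted from the docstring example in the diff above. The tiny `TinyModel` is a hypothetical stand-in for a real diffusers model, and `use_stream=True` assumes a CUDA device is available; the point is simply that `unpin_after_use` is no longer accepted.

```python
import torch
import torch.nn as nn

from diffusers.hooks.group_offloading import apply_group_offloading


# Hypothetical stand-in model; "block_level" offloading groups ModuleList/Sequential blocks.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(64, 64) for _ in range(4)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


model = TinyModel()

apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,
    use_stream=True,
    # unpin_after_use=False,  # removed by this commit; passing it would now raise a TypeError
)
```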