
Commit 2210d28 (parent: 3be6706)

update

File tree: 2 files changed, +84 −51 lines

src/diffusers/hooks/group_offloading.py

Lines changed: 75 additions & 50 deletions
```diff
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from contextlib import nullcontext
+from contextlib import contextmanager, nullcontext
 from typing import Dict, List, Optional, Set, Tuple
 
 import torch
```
```diff
@@ -56,23 +56,50 @@ def __init__(
         buffers: Optional[List[torch.Tensor]] = None,
         non_blocking: bool = False,
         stream: Optional[torch.cuda.Stream] = None,
-        cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
+        low_cpu_mem_usage=False,
         onload_self: bool = True,
     ) -> None:
         self.modules = modules
         self.offload_device = offload_device
         self.onload_device = onload_device
         self.offload_leader = offload_leader
         self.onload_leader = onload_leader
-        self.parameters = parameters
-        self.buffers = buffers
+        self.parameters = parameters or []
+        self.buffers = buffers or []
         self.non_blocking = non_blocking or stream is not None
         self.stream = stream
-        self.cpu_param_dict = cpu_param_dict
         self.onload_self = onload_self
+        self.low_cpu_mem_usage = low_cpu_mem_usage
 
-        if self.stream is not None and self.cpu_param_dict is None:
-            raise ValueError("cpu_param_dict must be provided when using stream for data transfer.")
+        self.cpu_param_dict = {}
+        for module in self.modules:
+            for param in module.parameters():
+                self.cpu_param_dict[param] = (
+                    param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
+                )
+
+        for param in self.parameters:
+            self.cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
+
+        for buffer in self.buffers:
+            self.cpu_param_dict[buffer] = (
+                buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
+            )
+
+    @contextmanager
+    def _pinned_memory_tensors(self):
+        pinned_dict = {}
+        try:
+            for param, tensor in self.cpu_param_dict.items():
+                if not tensor.is_pinned():
+                    pinned_dict[param] = tensor.pin_memory()
+                else:
+                    pinned_dict[param] = tensor
+
+            yield pinned_dict
+
+        finally:
+            pinned_dict = None
 
     def onload_(self):
         r"""Onloads the group of modules to the onload_device."""
```
```diff
@@ -82,17 +109,32 @@ def onload_(self):
             self.stream.synchronize()
 
         with context:
-            for group_module in self.modules:
-                for param in group_module.parameters():
-                    param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-                for buffer in group_module.buffers():
-                    buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
-            if self.parameters is not None:
-                for param in self.parameters:
-                    param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-            if self.buffers is not None:
-                for buffer in self.buffers:
-                    buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
+            if self.stream is not None:
+                with self._pinned_memory_tensors() as pinned_memory:
+                    for group_module in self.modules:
+                        for param in group_module.parameters():
+                            param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
+
+                    if self.parameters is not None:
+                        for param in self.parameters:
+                            param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
+
+                    if self.buffers is not None:
+                        for buffer in self.buffers:
+                            buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
+
+            else:
+                for group_module in self.modules:
+                    for param in group_module.parameters():
+                        param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+
+                if self.parameters is not None:
+                    for param in self.parameters:
+                        param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+
+                if self.buffers is not None:
+                    for buffer in self.buffers:
+                        buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
 
     def offload_(self):
         r"""Offloads the group of modules to the offload_device."""
```
```diff
@@ -108,12 +150,12 @@ def offload_(self):
             for buffer in self.buffers:
                 buffer.data = self.cpu_param_dict[buffer]
         else:
-            for group_module in self.modules:
-                group_module.to(self.offload_device, non_blocking=self.non_blocking)
-            if self.parameters is not None:
+            for module in self.modules:
+                module.to(self.offload_device, non_blocking=self.non_blocking)
+            if self.parameters:
                 for param in self.parameters:
                     param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
-            if self.buffers is not None:
+            if self.buffers:
                 for buffer in self.buffers:
                     buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
 
```

```diff
@@ -284,6 +326,7 @@ def apply_group_offloading(
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
     use_stream: bool = False,
+    low_cpu_mem_usage=False,
 ) -> None:
     r"""
     Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
```
```diff
@@ -365,10 +408,12 @@
             raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
 
         _apply_group_offloading_block_level(
-            module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream
+            module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
         )
     elif offload_type == "leaf_level":
-        _apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream)
+        _apply_group_offloading_leaf_level(
+            module, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
+        )
     else:
         raise ValueError(f"Unsupported offload_type: {offload_type}")
 
```
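Putting the threading together, the flag is accepted at the public entry point and forwarded to both offload strategies. A hedged usage sketch (the toy `Sequential` stands in for a real diffusers model; argument names follow this commit and may differ in other versions):

```python
import torch
from diffusers.hooks.group_offloading import apply_group_offloading

# Stand-in module; in practice this would be a diffusers transformer or UNet.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 64))

if torch.cuda.is_available():
    apply_group_offloading(
        model,
        onload_device=torch.device("cuda"),
        offload_device=torch.device("cpu"),
        offload_type="leaf_level",
        use_stream=True,         # prefetch weights on a CUDA stream
        low_cpu_mem_usage=True,  # defer pinning to transfer time
    )
    out = model(torch.randn(1, 64, device="cuda"))
```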

```diff
@@ -380,6 +425,7 @@ def _apply_group_offloading_block_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
+    low_cpu_mem_usage: bool = False,
 ) -> None:
     r"""
     This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
```
```diff
@@ -400,11 +446,6 @@
         for overlapping computation and data transfer.
     """
 
-    # Create a pinned CPU parameter dict for async data transfer if streams are to be used
-    cpu_param_dict = None
-    if stream is not None:
-        cpu_param_dict = _get_pinned_cpu_param_dict(module)
-
     # Create module groups for ModuleList and Sequential blocks
     modules_with_group_offloading = set()
     unmatched_modules = []
```
```diff
@@ -425,7 +466,7 @@
             onload_leader=current_modules[0],
             non_blocking=non_blocking,
             stream=stream,
-            cpu_param_dict=cpu_param_dict,
+            low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=stream is None,
         )
         matched_module_groups.append(group)
```
```diff
@@ -462,7 +503,6 @@
         buffers=buffers,
         non_blocking=False,
         stream=None,
-        cpu_param_dict=None,
         onload_self=True,
     )
     next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
```
```diff
@@ -475,6 +515,7 @@ def _apply_group_offloading_leaf_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
+    low_cpu_mem_usage: bool = False,
 ) -> None:
     r"""
     This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
```
```diff
@@ -497,11 +538,6 @@
         for overlapping computation and data transfer.
     """
 
-    # Create a pinned CPU parameter dict for async data transfer if streams are to be used
-    cpu_param_dict = None
-    if stream is not None:
-        cpu_param_dict = _get_pinned_cpu_param_dict(module)
-
     # Create module groups for leaf modules and apply group offloading hooks
     modules_with_group_offloading = set()
     for name, submodule in module.named_modules():
```
```diff
@@ -515,7 +551,7 @@
             onload_leader=submodule,
             non_blocking=non_blocking,
             stream=stream,
-            cpu_param_dict=cpu_param_dict,
+            low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
         )
         _apply_group_offloading_hook(submodule, group, None)
```
```diff
@@ -560,7 +596,7 @@
             buffers=buffers,
             non_blocking=non_blocking,
             stream=stream,
-            cpu_param_dict=cpu_param_dict,
+            low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
         )
         _apply_group_offloading_hook(parent_module, group, None)
```
```diff
@@ -579,7 +615,7 @@
         buffers=None,
         non_blocking=False,
         stream=None,
-        cpu_param_dict=None,
+        low_cpu_mem_usage=low_cpu_mem_usage,
         onload_self=True,
     )
     _apply_lazy_group_offloading_hook(module, unmatched_group, None)
```
```diff
@@ -616,17 +652,6 @@ def _apply_lazy_group_offloading_hook(
     registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)
 
 
-def _get_pinned_cpu_param_dict(module: torch.nn.Module) -> Dict[torch.nn.Parameter, torch.Tensor]:
-    cpu_param_dict = {}
-    for param in module.parameters():
-        param.data = param.data.cpu().pin_memory()
-        cpu_param_dict[param] = param.data
-    for buffer in module.buffers():
-        buffer.data = buffer.data.cpu().pin_memory()
-        cpu_param_dict[buffer] = buffer.data
-    return cpu_param_dict
-
-
 def _gather_parameters_with_no_group_offloading_parent(
     module: torch.nn.Module, modules_with_group_offloading: Set[str]
 ) -> List[torch.nn.Parameter]:
```
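Worth noting about the deleted helper: `_get_pinned_cpu_param_dict` pinned every parameter and buffer eagerly and repointed `param.data` at the pinned copy; ownership of the CPU copies now lives in each `ModuleGroup`, which is what makes the deferred-pinning mode possible. The difference in host-memory behavior, sketched with plain tensors:

```python
import torch

param = torch.nn.Parameter(torch.randn(1024))

if torch.cuda.is_available():  # pin_memory() needs an accelerator backend
    # Old behavior: a pinned allocation per parameter, held for the model's lifetime.
    eager = param.data.cpu().pin_memory()
    assert eager.is_pinned()

    # New low_cpu_mem_usage behavior: keep the pageable CPU tensor (a no-op
    # .cpu() if already on CPU) and pin a temporary copy only around a
    # transfer, as _pinned_memory_tensors does.
    pageable = param.data.cpu()
    assert not pageable.is_pinned()
    temporary = pageable.pin_memory()  # dropped once the transfer completes
```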

src/diffusers/models/modeling_utils.py

Lines changed: 9 additions & 1 deletion
```diff
@@ -546,6 +546,7 @@ def enable_group_offload(
         num_blocks_per_group: Optional[int] = None,
         non_blocking: bool = False,
         use_stream: bool = False,
+        low_cpu_mem_usage=False,
     ) -> None:
         r"""
         Activates group offloading for the current model.
```
```diff
@@ -584,7 +585,14 @@
                 f"open an issue at https://github.com/huggingface/diffusers/issues."
             )
         apply_group_offloading(
-            self, onload_device, offload_device, offload_type, num_blocks_per_group, non_blocking, use_stream
+            self,
+            onload_device,
+            offload_device,
+            offload_type,
+            num_blocks_per_group,
+            non_blocking,
+            use_stream,
+            low_cpu_mem_usage=low_cpu_mem_usage,
         )
 
     def save_pretrained(
```
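For end users, the same flag surfaces through `ModelMixin.enable_group_offload`. An illustrative call (the checkpoint is an example; any diffusers model component should work the same way):

```python
import torch
from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
    low_cpu_mem_usage=True,  # avoid pinning the full CPU copy up front
)
```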
