Skip to content

Commit dd21357

Browse files
committed
update
1 parent 60bcc74 commit dd21357

File tree

1 file changed

+31
-10
lines changed

1 file changed

+31
-10
lines changed

src/diffusers/hooks/group_offloading.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,18 @@ def __init__(
7373

7474
self.cpu_param_dict = {}
7575
for module in self.modules:
76-
self.cpu_param_dict.update(_get_cpu_param_dict(module, self.low_cpu_mem_usage))
76+
for param in module.parameters():
77+
self.cpu_param_dict[param] = (
78+
param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
79+
)
80+
81+
for param in self.parameters:
82+
self.cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
83+
84+
for buffer in self.buffers:
85+
self.cpu_param_dict[buffer] = (
86+
buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
87+
)
7788

7889
@contextmanager
7990
def _pinned_memory_tensors(self):
@@ -100,20 +111,30 @@ def onload_(self):
100111
with context:
101112
if self.stream is not None:
102113
with self._pinned_memory_tensors() as pinned_memory:
103-
for module in self.modules:
104-
for param in module.parameters():
114+
for group_module in self.modules:
115+
for param in group_module.parameters():
105116
param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
117+
118+
if self.parameters is not None:
119+
for param in self.parameters:
120+
param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
121+
122+
if self.buffers is not None:
123+
for buffer in self.buffers:
124+
buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
125+
106126
else:
107127
for group_module in self.modules:
108128
for param in group_module.parameters():
109129
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
110130

111-
if self.parameters is not None:
112-
for param in self.parameters:
113-
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
114-
if self.buffers is not None:
115-
for buffer in self.buffers:
116-
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
131+
if self.parameters is not None:
132+
for param in self.parameters:
133+
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
134+
135+
if self.buffers is not None:
136+
for buffer in self.buffers:
137+
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
117138

118139
def offload_(self):
119140
r"""Offloads the group of modules to the offload_device."""
@@ -631,7 +652,7 @@ def _apply_lazy_group_offloading_hook(
631652
registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)
632653

633654

634-
def _get_cpu_param_dict(
655+
def _assign_cpu_param_dict(
635656
module: torch.nn.Module, low_cpu_mem_usage: bool = False
636657
) -> Dict[torch.nn.Parameter, torch.Tensor]:
637658
cpu_param_dict = {}

0 commit comments

Comments (0)