Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions src/diffusers/hooks/group_offloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,10 @@ def onload_(self):

with context:
for group_module in self.modules:
group_module.to(self.onload_device, non_blocking=self.non_blocking)
for param in group_module.parameters():
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
for buffer in group_module.buffers():
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
if self.parameters is not None:
for param in self.parameters:
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
Expand All @@ -98,6 +101,12 @@ def offload_(self):
for group_module in self.modules:
for param in group_module.parameters():
param.data = self.cpu_param_dict[param]
if self.parameters is not None:
for param in self.parameters:
param.data = self.cpu_param_dict[param]
if self.buffers is not None:
for buffer in self.buffers:
buffer.data = self.cpu_param_dict[buffer]
else:
for group_module in self.modules:
group_module.to(self.offload_device, non_blocking=self.non_blocking)
Expand Down Expand Up @@ -387,9 +396,7 @@ def _apply_group_offloading_block_level(
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
cpu_param_dict = None
if stream is not None:
for param in module.parameters():
param.data = param.data.cpu().pin_memory()
cpu_param_dict = {param: param.data for param in module.parameters()}
cpu_param_dict = _get_pinned_cpu_param_dict(module)

# Create module groups for ModuleList and Sequential blocks
modules_with_group_offloading = set()
Expand Down Expand Up @@ -486,9 +493,7 @@ def _apply_group_offloading_leaf_level(
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
cpu_param_dict = None
if stream is not None:
for param in module.parameters():
param.data = param.data.cpu().pin_memory()
cpu_param_dict = {param: param.data for param in module.parameters()}
cpu_param_dict = _get_pinned_cpu_param_dict(module)

# Create module groups for leaf modules and apply group offloading hooks
modules_with_group_offloading = set()
Expand Down Expand Up @@ -604,6 +609,17 @@ def _apply_lazy_group_offloading_hook(
registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)


def _get_pinned_cpu_param_dict(module: torch.nn.Module) -> Dict[torch.nn.Parameter, torch.Tensor]:
    """Pin every parameter and buffer of ``module`` in page-locked CPU memory.

    Each tensor's ``.data`` is replaced in place by its pinned CPU copy, and a
    mapping from the parameter/buffer object to that pinned tensor is returned,
    so the offload path can later restore ``.data`` to the pinned copy (enabling
    async, non-blocking host<->device transfers via streams).
    """
    pinned = {}
    # Parameters and buffers receive identical treatment: pin once, remember the copy.
    for tensor in (*module.parameters(), *module.buffers()):
        tensor.data = tensor.data.cpu().pin_memory()
        pinned[tensor] = tensor.data
    return pinned


def _gather_parameters_with_no_group_offloading_parent(
module: torch.nn.Module, modules_with_group_offloading: Set[str]
) -> List[torch.nn.Parameter]:
Expand Down
Loading