
Commit 60bcc74

update

1 parent 60f468a

2 files changed (+19 -26 lines changed)

src/diffusers/hooks/group_offloading.py

Lines changed: 10 additions & 24 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from contextlib import nullcontext, contextmanager
+from contextlib import contextmanager, nullcontext
 from typing import Dict, List, Optional, Set, Tuple
 
 import torch
@@ -102,9 +102,7 @@ def onload_(self):
                 with self._pinned_memory_tensors() as pinned_memory:
                     for module in self.modules:
                         for param in module.parameters():
-                            param.data = pinned_memory[param].to(
-                                self.onload_device, non_blocking=self.non_blocking
-                            )
+                            param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
             else:
                 for group_module in self.modules:
                     for param in group_module.parameters():
@@ -392,7 +390,9 @@ def apply_group_offloading(
             module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
         )
     elif offload_type == "leaf_level":
-        _apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage)
+        _apply_group_offloading_leaf_level(
+            module, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
+        )
     else:
         raise ValueError(f"Unsupported offload_type: {offload_type}")
 

@@ -425,11 +425,6 @@ def _apply_group_offloading_block_level(
         for overlapping computation and data transfer.
     """
 
-    # Create a pinned CPU parameter dict for async data transfer if streams are to be used
-    cpu_param_dict = None
-    if stream is not None:
-        cpu_param_dict = _get_pinned_cpu_param_dict(module)
-
     # Create module groups for ModuleList and Sequential blocks
     modules_with_group_offloading = set()
     unmatched_modules = []
@@ -522,11 +517,6 @@ def _apply_group_offloading_leaf_level(
         for overlapping computation and data transfer.
     """
 
-    # Create a pinned CPU parameter dict for async data transfer if streams are to be used
-    cpu_param_dict = None
-    if stream is not None:
-        cpu_param_dict = _get_pinned_cpu_param_dict(module)
-
     # Create module groups for leaf modules and apply group offloading hooks
     modules_with_group_offloading = set()
    for name, submodule in module.named_modules():
@@ -641,19 +631,15 @@ def _apply_lazy_group_offloading_hook(
     registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)
 
 
-def _get_cpu_param_dict(module: torch.nn.Module, low_cpu_mem_usage: bool = False) -> Dict[torch.nn.Parameter, torch.Tensor]:
+def _get_cpu_param_dict(
+    module: torch.nn.Module, low_cpu_mem_usage: bool = False
+) -> Dict[torch.nn.Parameter, torch.Tensor]:
     cpu_param_dict = {}
     for param in module.parameters():
-        if low_cpu_mem_usage:
-            cpu_param_dict[param] = param.data.cpu()
-        else:
-            cpu_param_dict[param] = param.data.cpu().pin_memory()
+        cpu_param_dict[param] = param.data.cpu() if low_cpu_mem_usage else param.data.cpu().pin_memory()
 
     for buffer in module.buffers():
-        if low_cpu_mem_usage:
-            cpu_param_dict[buffer] = buffer.data.cpu()
-        else:
-            cpu_param_dict[buffer] = buffer.data.cpu().pin_memory()
+        cpu_param_dict[buffer] = buffer.data.cpu() if low_cpu_mem_usage else buffer.data.cpu().pin_memory()
 
     return cpu_param_dict
 

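The _get_cpu_param_dict change above folds the low_cpu_mem_usage branch into a conditional expression and covers buffers as well as parameters: unless the caller opts into lower host-RAM usage, each CPU copy is page-locked with pin_memory(), which is what lets a later .to(device, non_blocking=True) onload overlap with compute. A minimal self-contained sketch of that pattern, assuming CUDA is available; the build_cpu_param_dict name, the Linear module, and the stream handling are illustrative, not taken from this commit:

import torch

def build_cpu_param_dict(module: torch.nn.Module, low_cpu_mem_usage: bool = False):
    # Map every parameter and buffer to a CPU copy; pin the memory unless we
    # are trading transfer speed for a smaller host-RAM footprint.
    cpu_copies = {}
    for param in module.parameters():
        cpu_copies[param] = param.data.cpu() if low_cpu_mem_usage else param.data.cpu().pin_memory()
    for buffer in module.buffers():
        cpu_copies[buffer] = buffer.data.cpu() if low_cpu_mem_usage else buffer.data.cpu().pin_memory()
    return cpu_copies

if torch.cuda.is_available():
    model = torch.nn.Linear(1024, 1024)
    cpu_copies = build_cpu_param_dict(model)  # pinned copies by default
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        for param in model.parameters():
            # A pinned source plus non_blocking=True makes the host-to-device
            # copy asynchronous, so it can overlap with other GPU work.
            param.data = cpu_copies[param].to("cuda", non_blocking=True)
    torch.cuda.current_stream().wait_stream(stream)  # sync before using the params
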
src/diffusers/models/modeling_utils.py

Lines changed: 9 additions & 2 deletions
@@ -546,7 +546,7 @@ def enable_group_offload(
         num_blocks_per_group: Optional[int] = None,
         non_blocking: bool = False,
         use_stream: bool = False,
-        low_cpu_mem_usage=False
+        low_cpu_mem_usage=False,
     ) -> None:
         r"""
         Activates group offloading for the current model.
@@ -585,7 +585,14 @@
                 f"open an issue at https://github.com/huggingface/diffusers/issues."
             )
         apply_group_offloading(
-            self, onload_device, offload_device, offload_type, num_blocks_per_group, non_blocking, use_stream, low_cpu_mem_usage=low_cpu_mem_usage
+            self,
+            onload_device,
+            offload_device,
+            offload_type,
+            num_blocks_per_group,
+            non_blocking,
+            use_stream,
+            low_cpu_mem_usage=low_cpu_mem_usage,
         )
 
     def save_pretrained(
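For reference, the call reformatted above lives in ModelMixin.enable_group_offload, the public entry point for group offloading. A hedged usage sketch, assuming a recent diffusers install where this method is available; the AutoencoderKL checkpoint is illustrative, and any model inheriting from ModelMixin should work the same way:

import torch
from diffusers import AutoencoderKL

# Illustrative checkpoint; substitute any diffusers model.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
vae.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",  # or "block_level" together with num_blocks_per_group
    use_stream=True,            # prefetch on a separate CUDA stream
    low_cpu_mem_usage=False,    # False keeps pinned CPU copies for faster transfers
)

With low_cpu_mem_usage=True, the CPU copies are left unpinned up front, saving host RAM at some transfer-speed cost, as shown in _get_cpu_param_dict above.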
