Skip to content

Commit 1bd4539

Browse files
committed
Fix disk offload block_modules recursion to avoid extra files
1 parent 005e51b commit 1bd4539

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

src/diffusers/hooks/group_offloading.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import hashlib
1616
import os
1717
from contextlib import contextmanager, nullcontext
18-
from dataclasses import dataclass
18+
from dataclasses import dataclass, replace
1919
from enum import Enum
2020
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
2121

@@ -751,10 +751,14 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
751751
for name, submodule in module.named_children():
752752
# Check if this is an explicitly defined block module
753753
if block_modules and name in block_modules:
754-
# Apply block offloading to the specified submodule
755-
_apply_block_offloading_to_submodule(
756-
submodule, name, config, modules_with_group_offloading, matched_module_groups
757-
)
754+
# Track submodule using a prefix to avoid filename collisions during disk offload.
755+
# Without this, submodules sharing the same model class would be assigned identical
756+
# filenames (derived from the class name).
757+
prefix = f"{config.module_prefix}{name}." if config.module_prefix else f"{name}."
758+
submodule_config = replace(config, module_prefix=prefix)
759+
760+
_apply_group_offloading_block_level(submodule, submodule_config)
761+
modules_with_group_offloading.add(name)
758762
elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
759763
# Handle ModuleList and Sequential blocks as before
760764
for i in range(0, len(submodule), config.num_blocks_per_group):

0 commit comments

Comments
 (0)