|
15 | 15 | import hashlib |
16 | 16 | import os |
17 | 17 | from contextlib import contextmanager, nullcontext |
18 | | -from dataclasses import dataclass |
| 18 | +from dataclasses import dataclass, replace |
19 | 19 | from enum import Enum |
20 | 20 | from typing import Callable, Dict, List, Optional, Set, Tuple, Union |
21 | 21 |
|
@@ -751,10 +751,14 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf |
751 | 751 | for name, submodule in module.named_children(): |
752 | 752 | # Check if this is an explicitly defined block module |
753 | 753 | if block_modules and name in block_modules: |
754 | | - # Apply block offloading to the specified submodule |
755 | | - _apply_block_offloading_to_submodule( |
756 | | - submodule, name, config, modules_with_group_offloading, matched_module_groups |
757 | | - ) |
| 754 | + # Track submodule using a prefix to avoid filename collisions during disk offload. |
| 755 | + # Without this, submodules sharing the same model class would be assigned identical |
| 756 | + # filenames (derived from the class name). |
| 757 | + prefix = f"{config.module_prefix}{name}." if config.module_prefix else f"{name}." |
| 758 | + submodule_config = replace(config, module_prefix=prefix) |
| 759 | + |
| 760 | + _apply_group_offloading_block_level(submodule, submodule_config) |
| 761 | + modules_with_group_offloading.add(name) |
758 | 762 | elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): |
759 | 763 | # Handle ModuleList and Sequential blocks as before |
760 | 764 | for i in range(0, len(submodule), config.num_blocks_per_group): |
|
0 commit comments