Fix unique memory address when doing group-offloading with disk #11767
Changes from all commits
b6c0f20
7c8fc64
6639f25
a9b7abe
e0bfef9
24ac17f
e37d2b0
99d5ad5
4f081dc
ab2eff7
9710bbc
6901ef4
72d76a8
e75ef18
59d07e5
b572234
9553b79
e8fef13
54a299c
2749d4b
27d41ac
260a834
e6d8779
a891f9b
0946d96
```diff
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import hashlib
 import os
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
```
```diff
@@ -37,7 +38,7 @@
 _GROUP_OFFLOADING = "group_offloading"
 _LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
 _LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
+_GROUP_ID_LAZY_LEAF = "lazy_leafs"
 _SUPPORTED_PYTORCH_LAYERS = (
     torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
     torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
```
```diff
@@ -82,6 +83,7 @@ def __init__(
         low_cpu_mem_usage: bool = False,
         onload_self: bool = True,
         offload_to_disk_path: Optional[str] = None,
+        group_id: Optional[int] = None,
     ) -> None:
         self.modules = modules
         self.offload_device = offload_device
```
```diff
@@ -100,7 +102,10 @@
         self._is_offloaded_to_disk = False
 
         if self.offload_to_disk_path:
-            self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.safetensors")
+            # Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well.
+            self.group_id = group_id if group_id is not None else str(id(self))
+            short_hash = _compute_group_hash(self.group_id)
+            self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{short_hash}.safetensors")
 
         all_tensors = []
         for module in self.modules:
```
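The in-diff comment about `group_id or str(id(self))` is worth unpacking. Below is a minimal standalone sketch (not part of the PR, with a made-up sentinel object) of the pitfall it guards against: an empty string is falsy, so `or` would silently discard an explicitly passed `group_id` of `""`.

```python
# Sketch of the falsy-empty-string pitfall the diff comment refers to.
group_id = ""        # hypothetical: caller explicitly passed an empty ID
sentinel = object()  # stands in for `self`

# With `or`, "" is falsy, so the fallback fires even though an ID was given:
wrong = group_id or str(id(sentinel))  # -> e.g. "1403962..." (fallback)

# With an explicit None check, the empty string is preserved:
right = group_id if group_id is not None else str(id(sentinel))  # -> ""

print(repr(wrong), repr(right))
```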
```diff
@@ -609,6 +614,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
     for i in range(0, len(submodule), config.num_blocks_per_group):
         current_modules = submodule[i : i + config.num_blocks_per_group]
+        group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
         group = ModuleGroup(
             modules=current_modules,
             offload_device=config.offload_device,
```
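To see what IDs this naming scheme produces, here is a small sketch mimicking the loop above; the values of `name`, `submodule`, and `num_blocks_per_group` are illustrative stand-ins, not diffusers code.

```python
# Sketch of the block-level group-ID scheme from the hunk above.
name = "transformer_blocks"  # assumed submodule name
submodule = list(range(5))   # stand-in for a ModuleList of 5 blocks
num_blocks_per_group = 2     # stand-in for config.num_blocks_per_group

for i in range(0, len(submodule), num_blocks_per_group):
    current_modules = submodule[i : i + num_blocks_per_group]
    group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
    print(group_id)
# Prints:
#   transformer_blocks_0_1
#   transformer_blocks_2_3
#   transformer_blocks_4_4
# Each block range gets a distinct, deterministic ID.
```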
```diff
@@ -621,6 +627,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
             record_stream=config.record_stream,
             low_cpu_mem_usage=config.low_cpu_mem_usage,
             onload_self=True,
+            group_id=group_id,
         )
         matched_module_groups.append(group)
         for j in range(i, i + len(current_modules)):
```
```diff
@@ -655,6 +662,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
         stream=None,
         record_stream=False,
         onload_self=True,
+        group_id=f"{module.__class__.__name__}_unmatched_group",
     )
     if config.stream is None:
         _apply_group_offloading_hook(module, unmatched_group, None, config=config)
```
```diff
@@ -686,6 +694,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
             record_stream=config.record_stream,
             low_cpu_mem_usage=config.low_cpu_mem_usage,
             onload_self=True,
+            group_id=name,
         )
         _apply_group_offloading_hook(submodule, group, None, config=config)
         modules_with_group_offloading.add(name)
```
```diff
@@ -732,6 +741,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
             record_stream=config.record_stream,
             low_cpu_mem_usage=config.low_cpu_mem_usage,
             onload_self=True,
+            group_id=name,
         )
         _apply_group_offloading_hook(parent_module, group, None, config=config)
```
```diff
@@ -753,6 +763,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
             record_stream=False,
             low_cpu_mem_usage=config.low_cpu_mem_usage,
             onload_self=True,
+            group_id=_GROUP_ID_LAZY_LEAF,
         )
         _apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
```
```diff
@@ -873,6 +884,12 @@ def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
     raise ValueError("Group offloading is not enabled for the provided module.")
 
 
+def _compute_group_hash(group_id):
+    hashed_id = hashlib.sha256(group_id.encode("utf-8")).hexdigest()
+    # first 16 characters for a reasonably short but unique name
+    return hashed_id[:16]
+
+
 def _maybe_remove_and_reapply_group_offloading(module: torch.nn.Module) -> None:
     r"""
     Removes the group offloading hook from the module and re-applies it. This is useful when the module has been
```

Review discussion on `_compute_group_hash`:

> Don't think we need to hash the group id strings, they should be unique already.

> The passed-in `group_id` argument should be unique, no? I don't think we need to compute a hash.

> I don't see any edge cases, either. But having a hash is a bit more future-proof to me.
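For context on the fix itself: the old file names were derived from `id(self)`, a memory address that is neither stable across runs nor guaranteed unique over time, since Python may reuse an address after an object is garbage-collected. The sketch below reproduces the helper from this diff with a small usage example; the example group IDs are made up.

```python
import hashlib

def _compute_group_hash(group_id):
    hashed_id = hashlib.sha256(group_id.encode("utf-8")).hexdigest()
    # first 16 characters for a reasonably short but unique name
    return hashed_id[:16]

# The same group_id always yields the same file name, run after run:
print(_compute_group_hash("transformer_blocks_0_1"))  # stable 16-char name
print(_compute_group_hash(""))  # even an empty group_id maps to a valid name

# By contrast, the old scheme keyed files on a memory address, which can
# change between runs and be reused within one:
#   f"group_{id(self)}.safetensors"
```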