Skip to content

Commit 2bf33c8

Browse files
committed
better comments
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 7aec12c commit 2bf33c8

File tree

1 file changed

+20
-13
lines changed
  • src/compressed_tensors/transform

1 file changed

+20
-13
lines changed

src/compressed_tensors/transform/apply.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from typing import List, Tuple
1717

1818
import torch
19+
from accelerate.utils import has_offloaded_params
1920
from compressed_tensors import TRANSFORM_CONFIG_NAME
2021
from compressed_tensors.transform import TransformConfig, TransformFactory
2122

@@ -46,13 +47,19 @@ def apply_transform_config(model: torch.nn.Module, config: TransformConfig):
4647

4748
def _tie_offloaded_tensors(model: torch.nn.Module):
4849
"""
49-
Populate the `_dynamic_tied_weights_keys` attribute of transforms,
50-
which is used by transformers to detect and remove shared pointers
51-
during saving
50+
When accelerate replaces tensors with meta tensors during offloading, the meta
51+
tensors may not be identical, even if the offloaded values are identical.
52+
53+
However, transformers can only serialize correctly if meta tensors are identical
54+
(see transformers#39263).
55+
56+
This function collects all meta tensors which have shared offloaded values and sets
57+
those tensors to be identical so that they can be removed during serialization
58+
59+
:param model: model potentially containing offloaded meta tensors to fix
5260
"""
53-
from compressed_tensors.utils import has_offloaded_params
5461

55-
# map from to keys
62+
# map from offloaded tensor pointers to module-key locations
5663
offloaded_ptrs: dict[int, List[Tuple[torch.nn.Module, str]]] = defaultdict(list)
5764
for module in model.modules():
5865
# NOTE: previously asserted that parent._hf_hook.place_submodules=False
@@ -61,12 +68,12 @@ def _tie_offloaded_tensors(model: torch.nn.Module):
6168
param = module._hf_hook.weights_map[key]
6269
offloaded_ptrs[id(param)].append((module, key))
6370

64-
# populate `_dynamic_tied_weights_keys` if there is more than one key
65-
# and ensure that they share tensors. In the case of offloading, this
71+
# ensure that if locations share an offloaded tensor pointer, the
72+
# meta tensor is also identical (assigned to the first element of the set)
6673
for shared_keys in offloaded_ptrs.values():
67-
if len(shared_keys) > 1:
68-
first_tensor = getattr(shared_keys[0][0], shared_keys[0][1])
69-
assert first_tensor.device.type == "meta"
70-
for module, key in shared_keys:
71-
assert getattr(module, key).device.type == "meta"
72-
setattr(module, key, first_tensor)
74+
assert len(shared_keys) >= 1
75+
first_tensor = getattr(shared_keys[0][0], shared_keys[0][1])
76+
assert first_tensor.device.type == "meta"
77+
for module, key in shared_keys:
78+
assert getattr(module, key).device.type == "meta"
79+
setattr(module, key, first_tensor)

0 commit comments

Comments
 (0)