16
16
from typing import List , Tuple
17
17
18
18
import torch
19
+ from accelerate .utils import has_offloaded_params
19
20
from compressed_tensors import TRANSFORM_CONFIG_NAME
20
21
from compressed_tensors .transform import TransformConfig , TransformFactory
21
22
@@ -46,13 +47,19 @@ def apply_transform_config(model: torch.nn.Module, config: TransformConfig):
46
47
47
48
def _tie_offloaded_tensors(model: torch.nn.Module):
    """
    When accelerate replaces tensors with meta tensors during offloading, the meta
    tensors may not be identical, even if the offloaded values are identical.

    However, transformers can only serialize correctly if meta tensors are identical
    (see transformers#39263).

    This function collects all meta tensors which have shared offloaded values and sets
    those tensors to be identical so that they can be removed during serialization

    :param model: model potentially containing offloaded meta tensors to fix
    """
    # map from offloaded tensor pointers to module-key locations
    offloaded_ptrs: dict[int, List[Tuple[torch.nn.Module, str]]] = defaultdict(list)
    for module in model.modules():
        # NOTE: previously asserted that parent._hf_hook.place_submodules=False
        # NOTE(review): the guard and key loop below were elided by a diff hunk in
        # the pasted source; reconstructed from the visible inner statements —
        # confirm against the full file
        if has_offloaded_params(module):
            for key in module._hf_hook.weights_map:
                param = module._hf_hook.weights_map[key]
                offloaded_ptrs[id(param)].append((module, key))

    # ensure that if locations share an offloaded tensor pointer, the meta
    # tensor is also identical (assigned to the first element of the set)
    for shared_keys in offloaded_ptrs.values():
        assert len(shared_keys) >= 1
        first_tensor = getattr(shared_keys[0][0], shared_keys[0][1])
        assert first_tensor.device.type == "meta"
        for module, key in shared_keys:
            # every alias of a shared offloaded value must itself be a meta
            # tensor; re-point it at the canonical first tensor so transformers
            # detects the shared pointer during save
            assert getattr(module, key).device.type == "meta"
            setattr(module, key, first_tensor)
0 commit comments