 # limitations under the License.

 from abc import ABC, abstractmethod
-from typing import List, Optional
+from collections import defaultdict
+from typing import List, Optional, Set, Tuple

 import torch
 import torch.nn.utils.parametrize as P
@@ -56,6 +57,7 @@ def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = Non
         self.name = name
         self.scheme = scheme
         self.generator = torch.Generator()
+        self.transforms = list()
         if seed is not None:
             self.generator.manual_seed(seed)

@@ -99,6 +101,8 @@ def apply_to_model(self, model: Module, use_tqdm=True):
         for module, arg in tqdm.tqdm(modules_args, desc=desc, disable=(not use_tqdm)):
             self._apply_to_module(module, arg)

+        self._update_tied_weights()
+
     def _apply_to_module(self, module: Module, args: TransformArgs):
         """
         Create transforms and apply them to the module
@@ -116,6 +120,7 @@ def _apply_to_module(self, module: Module, args: TransformArgs):
         # create transform as submodule
         transform_name = f"{self.name}_{args.location}"
         transform = self.create_transform(module, args)
+        self.transforms.append(transform)
         register_offload_module(module, transform_name, transform)

         # register input transformation hook
@@ -160,6 +165,31 @@ def output_hook(_, _input, output):
         else:
             raise NotImplementedError()

+    def _update_tied_weights(self):
+        """
+        Populate the `_dynamic_tied_weights_keys` attribute of transforms,
+        which is used by transformers to detect and remove shared pointers
+        during saving
+        """
+        # map from data_ptrs to keys
+        ptr_to_keys: dict[int, List[Tuple[TransformBase, str]]] = defaultdict(list)
+        for transform in self.transforms:
+            for name, param in transform.named_parameters(recurse=False):
+                # NOTE: previously asserted that parent._hf_hook.place_submodules=False
+                if has_offloaded_params(transform):
+                    param = transform._hf_hook.weights_map[name]
+                ptr_to_keys[param.data_ptr()].append((transform, name))
+
+        # populate `_dynamic_tied_weights_keys` if there is more than one key
+        # and ensure that they share tensors
+        for shared_keys in ptr_to_keys.values():
+            if len(shared_keys) > 1:
+                tensor = getattr(shared_keys[0][0], shared_keys[0][1])
+
+                for transform, name in shared_keys:
+                    transform._dynamic_tied_weights_keys.add(name)
+                    setattr(transform, name, tensor)
+


 class TransformBase(InternalModule, ABC):
     """
@@ -168,7 +198,11 @@ class TransformBase(InternalModule, ABC):

     args: TransformArgs
     weight: Parameter
-    _dynamic_tied_weights_keys: List[str] = ["weight"]
+    _dynamic_tied_weights_keys: Set[str]
+
+    def __init__(self):
+        super().__init__()
+        self._dynamic_tied_weights_keys = set()

     @abstractmethod
     def forward(self, value: Tensor) -> Tensor:
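For context on the new `_update_tied_weights` method: it groups transform parameters by `data_ptr()`, so any parameters backed by the same storage are recorded in `_dynamic_tied_weights_keys` and re-pointed at a single tensor; per the docstring, transformers uses these keys to detect and remove shared pointers during saving. Below is a minimal standalone sketch of the same data_ptr grouping idea, using a hypothetical `find_shared_parameters` helper and toy `nn.Linear` modules rather than the factory's actual classes, and ignoring the offloaded-weights case (`_hf_hook.weights_map`) that the real method handles.

```python
from collections import defaultdict

import torch.nn as nn


def find_shared_parameters(modules):
    """Group (module, param_name) pairs whose parameters share storage."""
    ptr_to_keys = defaultdict(list)
    for module in modules:
        for name, param in module.named_parameters(recurse=False):
            # parameters that alias the same tensor report the same data_ptr
            ptr_to_keys[param.data_ptr()].append((module, name))
    return [keys for keys in ptr_to_keys.values() if len(keys) > 1]


# two linear layers that deliberately share one weight tensor
a = nn.Linear(4, 4, bias=False)
b = nn.Linear(4, 4, bias=False)
b.weight = a.weight  # tie the weights

print(find_shared_parameters([a, b]))  # one group: [(a, "weight"), (b, "weight")]
```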