# limitations under the License.

from abc import ABC, abstractmethod
-from collections import defaultdict
-from typing import List, Optional, Set, Tuple
+from typing import List, Optional

import torch
import torch.nn.utils.parametrize as P
@@ -57,7 +56,6 @@ def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = Non
        self.name = name
        self.scheme = scheme
        self.generator = torch.Generator()
-        self.transforms = list()
        if seed is not None:
            self.generator.manual_seed(seed)
@@ -101,8 +99,6 @@ def apply_to_model(self, model: Module, use_tqdm=True):
        for module, arg in tqdm.tqdm(modules_args, desc=desc, disable=(not use_tqdm)):
            self._apply_to_module(module, arg)

-        self._update_tied_weights()
-
    def _apply_to_module(self, module: Module, args: TransformArgs):
        """
        Create transforms and apply them to the module
@@ -120,7 +116,6 @@ def _apply_to_module(self, module: Module, args: TransformArgs):
        # create transform as submodule
        transform_name = f"{self.name}_{args.location}"
        transform = self.create_transform(module, args)
-        self.transforms.append(transform)
        register_offload_module(module, transform_name, transform)

        # register input transformation hook
@@ -165,31 +160,6 @@ def output_hook(_, _input, output):
        else:
            raise NotImplementedError()

-    def _update_tied_weights(self):
-        """
-        Populate the `_dynamic_tied_weights_keys` attribute of transforms,
-        which is used by transformers to detect and remove shared pointers
-        during saving
-        """
-        # map from data_ptrs to keys
-        ptr_to_keys: dict[int, List[Tuple[TransformBase, str]]] = defaultdict(list)
-        for transform in self.transforms:
-            for name, param in transform.named_parameters(recurse=False):
-                # NOTE: previously asserted that parent._hf_hook.place_submodules=False
-                if has_offloaded_params(transform):
-                    param = transform._hf_hook.weights_map[name]
-                ptr_to_keys[param.data_ptr()].append((transform, name))
-
-        # populate `_dynamic_tied_weights_keys` if there is more than one key
-        # and ensure that they share tensors
-        for shared_keys in ptr_to_keys.values():
-            if len(shared_keys) > 1:
-                tensor = getattr(shared_keys[0][0], shared_keys[0][1])
-
-                for transform, name in shared_keys:
-                    transform._dynamic_tied_weights_keys.add(name)
-                    setattr(transform, name, tensor)
-

class TransformBase(InternalModule, ABC):
    """
@@ -198,11 +168,7 @@ class TransformBase(InternalModule, ABC):

    args: TransformArgs
    weight: Parameter
-    _dynamic_tied_weights_keys: Set[str]
-
-    def __init__(self):
-        super().__init__()
-        self._dynamic_tied_weights_keys = set()
+    _dynamic_tied_weights_keys: List[str] = ["weight"]

    @abstractmethod
    def forward(self, value: Tensor) -> Tensor:
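For readers unfamiliar with the tied-weights machinery that the removed _update_tied_weights relied on, the sketch below shows how parameters that share the same storage can be grouped by data_ptr(); this is the detection idea the removed docstring refers to. It is illustrative only and not part of this change: the helper name find_shared_parameters is hypothetical, and it assumes PyTorch's named_parameters(remove_duplicate=False) so tied parameters stay visible under every name.

# Illustrative sketch only (not part of this diff). The helper name
# `find_shared_parameters` is hypothetical; it groups parameters that share
# the same underlying storage, the data_ptr() trick used by the removed code.
from collections import defaultdict
from typing import Dict, List

from torch.nn import Linear, Module, Sequential


def find_shared_parameters(module: Module) -> List[List[str]]:
    """Return groups of parameter names that point at the same storage."""
    ptr_to_names: Dict[int, List[str]] = defaultdict(list)
    # remove_duplicate=False keeps tied parameters visible under every name
    for name, param in module.named_parameters(remove_duplicate=False):
        ptr_to_names[param.data_ptr()].append(name)
    return [names for names in ptr_to_names.values() if len(names) > 1]


lin_a = Linear(4, 4, bias=False)
lin_b = Linear(4, 4, bias=False)
lin_b.weight = lin_a.weight  # tie the two weights
print(find_shared_parameters(Sequential(lin_a, lin_b)))  # [['0.weight', '1.weight']]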