14
14
15
15
from abc import ABC , abstractmethod
16
16
from collections import defaultdict
17
- from typing import List , Optional , Tuple
17
+ from typing import List , Optional , Tuple , Set
18
18
19
19
import torch
20
20
import torch .nn .utils .parametrize as P
@@ -164,10 +164,6 @@ def _update_tied_weights(self):
164
164
which is used by transformers to detect and remove shared pointers
165
165
during saving
166
166
"""
167
- # avoid issues with this method being called twice
168
- for transform in self .transforms :
169
- transform ._dynamic_tied_weights_keys = list ()
170
-
171
167
# map from data_ptrs to keys
172
168
ptr_to_keys : dict [int , List [Tuple [TransformBase , str ]]] = defaultdict (list )
173
169
for transform in self .transforms :
@@ -184,7 +180,7 @@ def _update_tied_weights(self):
184
180
tensor = getattr (shared_keys [0 ][0 ], shared_keys [0 ][1 ])
185
181
186
182
for transform , name in shared_keys :
187
- transform ._dynamic_tied_weights_keys .append (name )
183
+ transform ._dynamic_tied_weights_keys .add (name )
188
184
setattr (transform , name , tensor )
189
185
190
186
@@ -195,11 +191,11 @@ class TransformBase(InternalModule, ABC):
195
191
196
192
args : TransformArgs
197
193
weight : Parameter
198
- _dynamic_tied_weights_keys : List [str ]
194
+ _dynamic_tied_weights_keys : Set [str ]
199
195
200
196
def __init__ (self ):
201
197
super ().__init__ ()
202
- self ._dynamic_tied_weights_keys = list ()
198
+ self ._dynamic_tied_weights_keys = set ()
203
199
204
200
@abstractmethod
205
201
def forward (self , value : Tensor ) -> Tensor :
0 commit comments