@@ -103,9 +103,17 @@ def _apply_to_module(self, module: Module, args: TransformArgs):
         :param module: target module to apply transforms to
         :param args: defines how the transform will be applied to the target module
         """
+        if has_offloaded_params(module):
+            if module._hf_hook.place_submodules:
+                raise NotImplementedError(
+                    "Applying transforms to offloaded submodules with "
+                    "`place_submodules=True` is not supported"
+                )
+
         # create transform as submodule
         transform_name = f"{self.name}_{args.location.value}"
         transform = self.create_transform(module, args)
+        self.transforms.append(transform)
         register_offload_module(module, transform_name, transform)

         # register input transformation hook
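
For context, a minimal sketch of what the offload guard above checks, assuming accelerate's hook conventions (`_hf_hook`, `offload`, `place_submodules` on an `AlignDevicesHook`); the helper names here are illustrative, not the library's actual implementation.

```python
from torch.nn import Module


def has_offloaded_params_sketch(module: Module) -> bool:
    # a module counts as offloaded when an accelerate hook with offload=True
    # is attached; attribute names mirror accelerate's AlignDevicesHook
    hook = getattr(module, "_hf_hook", None)
    return hook is not None and getattr(hook, "offload", False)


def places_submodules_sketch(module: Module) -> bool:
    # the unsupported case rejected above: a single hook manages the whole
    # subtree, so newly registered transform submodules would not get hooks
    hook = getattr(module, "_hf_hook", None)
    return hook is not None and getattr(hook, "place_submodules", False)
```
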
@@ -136,8 +144,9 @@ def input_hook(_, args):
                     raise ValueError("Offloaded training is not supported")
                 P.register_parametrization(module, "weight", transform)

-            # transform is no longer needed (unfusing is not supported)
-            delete_offload_module(module, transform_name)
+            else:
+                # transform is no longer needed (unfusing is not supported)
+                delete_offload_module(module, transform_name)

         # register output transformation hook
         elif args.location == TransformLocation.OUTPUT:
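
The training branch kept above leans on `torch.nn.utils.parametrize`, which re-applies the transform to the weight on every access so gradients can flow through it. A minimal, self-contained sketch of that mechanism, with `Rotate` as an illustrative stand-in for the factory's transform module:

```python
import torch
import torch.nn.utils.parametrize as P


class Rotate(torch.nn.Module):
    def __init__(self, rotation: torch.Tensor):
        super().__init__()
        self.register_buffer("rotation", rotation)

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        # applied lazily whenever module.weight is accessed
        return weight @ self.rotation


linear = torch.nn.Linear(4, 4, bias=False)
P.register_parametrization(linear, "weight", Rotate(torch.eye(4)))

# accessing .weight now routes through Rotate.forward,
# and the backward pass propagates through the transform
out = linear(torch.randn(2, 4))
out.sum().backward()
```
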
@@ -165,13 +174,20 @@ def _update_tied_weights(self):
         ptr_to_keys: dict[int, List[Tuple[TransformBase, str]]] = defaultdict(list)
         for transform in self.transforms:
             for name, param in transform.named_parameters(recurse=False):
+                # NOTE: previously asserted that parent._hf_hook.place_submodules=False
+                if has_offloaded_params(transform):
+                    param = transform._hf_hook.weights_map[name]
                 ptr_to_keys[param.data_ptr()].append((transform, name))

         # populate `_dynamic_tied_weights_keys` if there is more than one key
+        # and ensure that they share tensors
         for shared_keys in ptr_to_keys.values():
             if len(shared_keys) > 1:
+                tensor = getattr(shared_keys[0][0], shared_keys[0][1])
+
                 for transform, name in shared_keys:
                     transform._dynamic_tied_weights_keys.append(name)
+                    setattr(transform, name, tensor)


 class TransformBase(Module, ABC):
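
As a usage-level illustration of the tied-weight bookkeeping in `_update_tied_weights` above: distinct `Parameter` objects that share storage are grouped by `data_ptr()`, then re-pointed at one shared tensor so serialization sees a single tied weight. `SmallTransform` below is a made-up stand-in, not part of compressed-tensors.

```python
from collections import defaultdict

import torch
from torch.nn import Module, Parameter


class SmallTransform(Module):
    def __init__(self, weight: Parameter):
        super().__init__()
        self.weight = weight
        self._dynamic_tied_weights_keys: list[str] = []


base = torch.eye(4)
# two distinct Parameter objects backed by the same storage
transforms = [SmallTransform(Parameter(base)), SmallTransform(Parameter(base))]

ptr_to_keys = defaultdict(list)
for transform in transforms:
    for name, param in transform.named_parameters(recurse=False):
        ptr_to_keys[param.data_ptr()].append((transform, name))

for shared_keys in ptr_to_keys.values():
    if len(shared_keys) > 1:
        tensor = getattr(shared_keys[0][0], shared_keys[0][1])
        for transform, name in shared_keys:
            transform._dynamic_tied_weights_keys.append(name)
            setattr(transform, name, tensor)

# both transforms now reference the exact same Parameter object
assert transforms[0].weight is transforms[1].weight
```
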