@@ -36,6 +36,12 @@ def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None:
36
36
else :
37
37
_set_nested_attr (getattr (obj , names [0 ]), names [1 :], value )
38
38
39
+ def _get_nested_attr (obj : nn .Module , names : List [str ]) -> None :
40
+ if len (names ) == 1 :
41
+ return getattr (obj , names [0 ])
42
+ else :
43
+ _get_nested_attr (getattr (obj , names [0 ]), names [1 :])
44
+
39
45
def extract_weights (mod : nn .Module ) -> Tuple [Tuple [Tensor , ...], List [str ]]:
40
46
"""
41
47
This function removes all the Parameters from the model and
@@ -69,6 +75,14 @@ def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...], a
69
75
_del_nested_attr (mod , name .split ("." ))
70
76
_set_nested_attr (mod , name .split ("." ), p )
71
77
78
def _swap_state(mod: nn.Module, split_names: List[str], elems):
    """Install ``elems`` on ``mod`` and return the values they replaced.

    ``split_names`` and ``elems`` are walked in lockstep. For each pair the
    current attribute at that path is read, deleted, and then set to the new
    element. The values that were read are returned in the same order.
    """
    displaced = []
    for path, replacement in zip(split_names, elems):
        previous = _get_nested_attr(mod, path)
        # Delete first so the new value is set on a clean slot.
        _del_nested_attr(mod, path)
        _set_nested_attr(mod, path, replacement)
        displaced.append(previous)
    return displaced
85
+
72
86
def extract_buffers (mod : nn .Module ) -> Tuple [Tuple [Tensor , ...], List [str ]]:
73
87
orig_params = tuple (mod .buffers ())
74
88
# Remove all the parameters in the model
@@ -181,13 +195,16 @@ def fun(weights, buffers, data):
181
195
182
196
return weights , buffers , fun , weight_descriptors , buf_descriptors
183
197
198
def make_split_names(lst):
    """Split every dotted name in ``lst`` into its list of components."""
    split_names = []
    for dotted_name in lst:
        split_names.append(dotted_name.split('.'))
    return split_names
184
200
185
201
class FunctionalModuleWithBuffers (nn .Module ):
186
202
def __init__(self, stateless_model, param_names, buffer_names):
    """Store the model along with its parameter and buffer names.

    The dotted names are also pre-split once into ``self.split_names``
    (parameters first, then buffers).
    """
    super().__init__()
    self.stateless_model = stateless_model
    self.param_names = param_names
    self.buffer_names = buffer_names
    all_names = param_names + buffer_names
    self.split_names = make_split_names(all_names)
191
208
192
209
@staticmethod
193
210
def _create_from (model ):
@@ -201,22 +218,24 @@ def _create_from(model):
201
218
buffers ,
202
219
)
203
220
204
- def with_state (self , params , buffers ):
205
- stateful_model = copy .deepcopy (self .stateless_model )
206
- load_weights (stateful_model , self .param_names , params )
207
- load_buffers (stateful_model , self .buffer_names , buffers )
208
- return stateful_model
209
-
210
221
def forward(self, params, buffers, *args, **kwargs):
    """Call the wrapped model with ``params`` and ``buffers`` installed.

    The given values are swapped onto ``self.stateless_model`` before the
    call. The values they replaced are swapped back afterwards, whether the
    call returns or raises.
    """
    incoming_state = list(params) + list(buffers)
    saved_state = _swap_state(
        self.stateless_model, self.split_names, incoming_state)
    try:
        return self.stateless_model(*args, **kwargs)
    finally:
        # Put back whatever was on the model before this call.
        _swap_state(self.stateless_model, self.split_names, saved_state)
214
232
215
233
class FunctionalModule (nn .Module ):
216
234
def __init__(self, stateless_model, param_names):
    """Store the model along with its parameter names.

    The dotted names are also pre-split once into ``self.split_names``.
    """
    super().__init__()
    self.stateless_model = stateless_model
    self.param_names = param_names
    self.split_names = make_split_names(param_names)
220
239
221
240
@staticmethod
222
241
def _create_from (model ):
@@ -225,15 +244,14 @@ def _create_from(model):
225
244
params , param_names = extract_weights (model_copy )
226
245
return FunctionalModule (model_copy , param_names ), params
227
246
228
- def with_state (self , params ):
229
- stateful_model = copy .deepcopy (self .stateless_model )
230
- load_weights (stateful_model , self .param_names , params )
231
- return stateful_model
232
-
233
247
def forward(self, params, *args, **kwargs):
    """Call the wrapped model with ``params`` installed.

    The given values are swapped onto ``self.stateless_model`` before the
    call. The values they replaced are swapped back afterwards, whether the
    call returns or raises.
    """
    saved_state = _swap_state(
        self.stateless_model, self.split_names, params)
    try:
        return self.stateless_model(*args, **kwargs)
    finally:
        # Put back whatever was on the model before this call.
        _swap_state(self.stateless_model, self.split_names, saved_state)
237
255
238
256
def make_functional (model : nn .Module ):
239
257
"""make_functional(model) -> func, weights
0 commit comments