Skip to content

Commit 1905a2c

Browse files
authored
Remove deepcopy from FunctionalModule (#228)
This changes the semantics, but it makes in more in line with nn._stateless.
1 parent dbd2bb9 commit 1905a2c

File tree

1 file changed

+35
-17
lines changed

1 file changed

+35
-17
lines changed

functorch/_src/make_functional.py

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None:
3636
else:
3737
_set_nested_attr(getattr(obj, names[0]), names[1:], value)
3838

39+
def _get_nested_attr(obj: nn.Module, names: List[str]) -> None:
40+
if len(names) == 1:
41+
return getattr(obj, names[0])
42+
else:
43+
_get_nested_attr(getattr(obj, names[0]), names[1:])
44+
3945
def extract_weights(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
4046
"""
4147
This function removes all the Parameters from the model and
@@ -69,6 +75,14 @@ def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...], a
6975
_del_nested_attr(mod, name.split("."))
7076
_set_nested_attr(mod, name.split("."), p)
7177

78+
def _swap_state(mod: nn.Module, split_names: List[str], elems):
79+
result = []
80+
for split_name, elem in zip(split_names, elems):
81+
result.append(_get_nested_attr(mod, split_name))
82+
_del_nested_attr(mod, split_name)
83+
_set_nested_attr(mod, split_name, elem)
84+
return result
85+
7286
def extract_buffers(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
7387
orig_params = tuple(mod.buffers())
7488
# Remove all the parameters in the model
@@ -181,13 +195,16 @@ def fun(weights, buffers, data):
181195

182196
return weights, buffers, fun, weight_descriptors, buf_descriptors
183197

198+
def make_split_names(lst):
199+
return [name.split('.') for name in lst]
184200

185201
class FunctionalModuleWithBuffers(nn.Module):
186202
def __init__(self, stateless_model, param_names, buffer_names):
187203
super(FunctionalModuleWithBuffers, self).__init__()
188204
self.stateless_model = stateless_model
189205
self.param_names = param_names
190206
self.buffer_names = buffer_names
207+
self.split_names = make_split_names(param_names + buffer_names)
191208

192209
@staticmethod
193210
def _create_from(model):
@@ -201,22 +218,24 @@ def _create_from(model):
201218
buffers,
202219
)
203220

204-
def with_state(self, params, buffers):
205-
stateful_model = copy.deepcopy(self.stateless_model)
206-
load_weights(stateful_model, self.param_names, params)
207-
load_buffers(stateful_model, self.buffer_names, buffers)
208-
return stateful_model
209-
210221
def forward(self, params, buffers, *args, **kwargs):
211-
stateful_model = self.with_state(params, buffers)
212-
return stateful_model(*args, **kwargs)
213-
222+
# Temporarily load the state back onto self.stateless_model
223+
old_state = _swap_state(
224+
self.stateless_model,
225+
self.split_names,
226+
list(params) + list(buffers))
227+
try:
228+
return self.stateless_model(*args, **kwargs)
229+
finally:
230+
# Remove the loaded state on self.stateless_model
231+
_swap_state(self.stateless_model, self.split_names, old_state)
214232

215233
class FunctionalModule(nn.Module):
216234
def __init__(self, stateless_model, param_names):
217235
super(FunctionalModule, self).__init__()
218236
self.stateless_model = stateless_model
219237
self.param_names = param_names
238+
self.split_names = make_split_names(param_names)
220239

221240
@staticmethod
222241
def _create_from(model):
@@ -225,15 +244,14 @@ def _create_from(model):
225244
params, param_names = extract_weights(model_copy)
226245
return FunctionalModule(model_copy, param_names), params
227246

228-
def with_state(self, params):
229-
stateful_model = copy.deepcopy(self.stateless_model)
230-
load_weights(stateful_model, self.param_names, params)
231-
return stateful_model
232-
233247
def forward(self, params, *args, **kwargs):
234-
stateful_model = self.with_state(params)
235-
return stateful_model(*args, **kwargs)
236-
248+
# Temporarily load the state back onto self.stateless_model
249+
old_state = _swap_state(self.stateless_model, self.split_names, params)
250+
try:
251+
return self.stateless_model(*args, **kwargs)
252+
finally:
253+
# Remove the loaded state on self.stateless_model
254+
_swap_state(self.stateless_model, self.split_names, old_state)
237255

238256
def make_functional(model: nn.Module):
239257
"""make_functional(model) -> func, weights

0 commit comments

Comments
 (0)