Skip to content

Commit 295d18b

Browse files
authored
Merge pull request #101 from lucasb-eyer/safe-setstate
A whole bunch of sanity-checks in parameter loading.
2 parents 8116441 + b21ef87 commit 295d18b

File tree

3 files changed

+13
-4
lines changed

3 files changed

+13
-4
lines changed

DeepFried2/Container.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ def __getstate__(self):
6464
return [m.__getstate__() for m in self.modules]
6565

6666
def __setstate__(self, state):
67+
if len(self.modules) != len(state):
68+
raise ValueError("{} wants to load params for {} modules but received params for {} modules".format(df.utils.typename(self), len(self.modules), len(state)))
69+
6770
for m, s in zip(self.modules, state):
6871
m.__setstate__(s)
6972

DeepFried2/Module.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,5 +189,12 @@ def __getstate__(self):
189189
return [p.get_value() for p in self.parameters()]
190190

191191
def __setstate__(self, state):
192-
for p, s in zip(self.parameters(), state):
192+
params = self.parameters()
193+
if len(params) != len(state):
194+
raise ValueError("{} wants to load {} params but received {} params".format(df.utils.typename(self), len(params), len(state)))
195+
196+
for p, s in zip(params, state):
197+
if p.get_value().shape != s.shape:
198+
raise ValueError("{} got invalid shape when loading param {}: expecting {} but loading {}".format(df.utils.typename(self), p.param.name, p.get_value().shape, s.shape))
199+
193200
p.set_value(s)

DeepFried2/layers/BatchNormalization.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,6 @@ def __getstate__(self):
107107
return [buf.get_value() for buf in (self.buf_mean, self.buf_var, self.buf_count)] + regular
108108

109109
def __setstate__(self, state):
110-
istate = iter(state)
111-
for buf, val in zip((self.buf_mean, self.buf_var, self.buf_count), istate):
110+
for buf, val in zip((self.buf_mean, self.buf_var, self.buf_count), state):
112111
buf.set_value(val)
113-
df.Module.__setstate__(self, istate)
112+
df.Module.__setstate__(self, state[3:])

0 commit comments

Comments
 (0)