Skip to content

Commit 40b8e28

Browse files
committed
Merge pull request #3 from lucasb-eyer/save-load-params
[WIP] Add utility to save/load parameters, i.e. models.
2 parents b835835 + 4761d35 commit 40b8e28

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

DeepFried2/utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,22 @@ def create_param_state_as(other, initial_value=0, prefix='state_for_'):
1919
broadcastable=other.broadcastable,
2020
name=prefix + str(other.name)
2121
)
22+
23+
24+
def count_params(module):
25+
params, _ = module.parameters()
26+
return sum(p.get_value().size for p in params)
27+
28+
29+
def save_params(module, where, compress=False):
30+
params, _ = module.parameters()
31+
32+
savefn = _np.savez_compressed if compress else _np.savez
33+
savefn(where, params=[p.get_value() for p in params])
34+
35+
36+
def load_params(module, fromwhere):
37+
params, _ = module.parameters()
38+
with _np.load(fromwhere) as f:
39+
for p, v in zip(params, f['params']):
40+
p.set_value(v)

0 commit comments

Comments
 (0)