Skip to content

Commit ee69971

Browse files
committed
Add utility to save/load parameters, i.e. models.
Also adds a utility to compute the number of parameters, because that's always interesting and often reported in papers.
1 parent dd6bbf0 commit ee69971

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

DeepFried2/utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,20 @@ def create_param_state_as(other, initial_value=0, prefix='state_for_'):
1919
broadcastable=other.broadcastable,
2020
name=prefix + str(other.name)
2121
)
22+
23+
24+
def count_params(module):
25+
params, _ = module.parameters()
26+
return sum(p.get_value().size for p in params)
27+
28+
29+
def save_params(module, where):
30+
params, _ = module.parameters()
31+
_np.savez_compressed(where, params=[p.get_value() for p in params])
32+
33+
34+
def load_params(module, fromwhere):
35+
params, _ = module.parameters()
36+
with _np.load(fromwhere) as f:
37+
for p, v in zip(params, f['params']):
38+
p.set_value(v)

0 commit comments

Comments
 (0)