From ad8443fe9ea06feb890efcc1b7f3b673ac140669 Mon Sep 17 00:00:00 2001 From: lucasb-eyer Date: Thu, 23 Jul 2015 13:39:13 +0200 Subject: [PATCH] Add utility to save/load parameters, i.e. models. Also adds a utility to compute the number of parameters, because that's always interesting and often reported in papers. --- beacon8/utils.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/beacon8/utils.py b/beacon8/utils.py index 2cad645..9f540d8 100644 --- a/beacon8/utils.py +++ b/beacon8/utils.py @@ -19,3 +19,20 @@ def create_param_state_as(other, initial_value=0): broadcastable=other.broadcastable, name='state_for_' + str(other.name) ) + + +def count_params(module): + params, _ = module.parameters() + return sum(p.get_value().size for p in params) + + +def save_params(module, where): + params, _ = module.parameters() + _np.savez_compressed(where, params=[p.get_value() for p in params]) + + +def load_params(module, fromwhere): + params, _ = module.parameters() + with _np.load(fromwhere) as f: + for p, v in zip(params, f['params']): + p.set_value(v)