diff --git a/beacon8/utils.py b/beacon8/utils.py index 2cad645..9f540d8 100644 --- a/beacon8/utils.py +++ b/beacon8/utils.py @@ -19,3 +19,20 @@ def create_param_state_as(other, initial_value=0): broadcastable=other.broadcastable, name='state_for_' + str(other.name) ) + + +def count_params(module): + params, _ = module.parameters() + return sum(p.get_value().size for p in params) + + +def save_params(module, where): + params, _ = module.parameters() + _np.savez_compressed(where, params=[p.get_value() for p in params]) + + +def load_params(module, fromwhere): + params, _ = module.parameters() + with _np.load(fromwhere) as f: + for p, v in zip(params, f['params']): + p.set_value(v)