Skip to content

Commit 6281f1a

Browse files
committed
update saving and loading utils, making hebbian synapse use these utils for custom optimizer params saving and loading
1 parent eb534c4 commit 6281f1a

File tree

2 files changed

+17
-7
lines changed

2 files changed

+17
-7
lines changed

ngclearn/components/synapses/hebbian/hebbianSynapse.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from ngclearn.components.synapses import DenseSynapse
1212
from ngclearn.utils import tensorstats
1313
from ngcsimlib import deprecate_args
14+
from ngclearn.utils.io_utils import save_pkl, load_pkl
1415

1516
@partial(jit, static_argnums=[3, 4, 5, 6, 7, 8, 9])
1617
def _calc_update(
@@ -217,16 +218,12 @@ def __init__(
217218
def save(self, directory: str):
218219
super().save(directory)
219220
# Also save the optimizer parameters
220-
file_name = directory + "/" + self.name + "_opt_params" + ".pkl"
221-
with open(file_name, 'wb') as f:
222-
pickle.dump(self.opt_params.get(), f)
221+
save_pkl(directory, self.name + "_opt_params", self.opt_params.get())
223222

224223
def load(self, directory: str):
225224
super().load(directory)
226-
file_name = directory + "/" + self.name + "_opt_params" + ".pkl"
227-
with open(file_name, 'rb') as f:
228-
data = pickle.load(f)
229-
self.opt_params.set(data)
225+
# load the optimizer parameters in a custom way
226+
self.opt_params.set(load_pkl(directory, self.name + "_opt_params"))
230227

231228
@staticmethod
232229
def _compute_update(

ngclearn/utils/io_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# import jax
55
# from jax import numpy as jnp, grad, jit, vmap, random, lax
66
import os, sys, pickle
7+
from typing import Any
78

89
def serialize(fname, object): ## object "saving" routine
910
"""
@@ -65,3 +66,15 @@ def makedirs(directories):
6566
"""
6667
for dir in directories:
6768
makedir(dir)
69+
70+
71+
def save_pkl(directory: str, name: str, value: Any) -> None:
72+
file_name = directory + "/" + name + ".pkl"
73+
with open(file_name, 'wb') as f:
74+
pickle.dump(value, f)
75+
76+
def load_pkl(directory: str, name: str) -> Any:
77+
file_name = directory + "/" + name + ".pkl"
78+
with open(file_name, 'rb') as f:
79+
data = pickle.load(f)
80+
return data

0 commit comments

Comments
 (0)