Skip to content

Commit eb534c4

Browse files
committed
update hebbian synapse saving
1 parent 5e43ad2 commit eb534c4

File tree

1 file changed

+20
-1
lines changed

1 file changed

+20
-1
lines changed

ngclearn/components/synapses/hebbian/hebbianSynapse.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# %%
22

3+
import jax
4+
import pickle
35
from jax import random, numpy as jnp, jit
46
from functools import partial
57
from ngclearn.utils.optim import get_opt_init_fn, get_opt_step_fn
@@ -206,10 +208,26 @@ def __init__(
206208
self.dBiases = Compartment(jnp.zeros(shape[1]))
207209

208210
#key, subkey = random.split(self.key.value)
211+
# NOTE: we don't save this compartment directly because it is a tuple can cannot be saved directly by numpy
209212
self.opt_params = Compartment(
210-
get_opt_init_fn(optim_type)([self.weights.get(), self.biases.get()] if bias_init else [self.weights.get()])
213+
get_opt_init_fn(optim_type)([self.weights.get(), self.biases.get()] if bias_init else [self.weights.get()]),
214+
auto_save=False
211215
)
212216

217+
def save(self, directory: str):
218+
super().save(directory)
219+
# Also save the optimizer parameters
220+
file_name = directory + "/" + self.name + "_opt_params" + ".pkl"
221+
with open(file_name, 'wb') as f:
222+
pickle.dump(self.opt_params.get(), f)
223+
224+
def load(self, directory: str):
225+
super().load(directory)
226+
file_name = directory + "/" + self.name + "_opt_params" + ".pkl"
227+
with open(file_name, 'rb') as f:
228+
data = pickle.load(f)
229+
self.opt_params.set(data)
230+
213231
@staticmethod
214232
def _compute_update(
215233
w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght, post_wght, pre, post, weights
@@ -332,3 +350,4 @@ def help(cls): ## component help function
332350
Wab = HebbianSynapse("Wab", (2, 3), 0.0004, optim_type='adam',
333351
sign_value=-1.0, prior=("l1l2", 0.001))
334352
print(Wab)
353+
print(Wab.opt_params.get())

0 commit comments

Comments
 (0)