|
1 | 1 | # %% |
2 | 2 |
|
| 3 | +import jax |
| 4 | +import pickle |
3 | 5 | from jax import random, numpy as jnp, jit |
4 | 6 | from functools import partial |
5 | 7 | from ngclearn.utils.optim import get_opt_init_fn, get_opt_step_fn |
@@ -206,10 +208,26 @@ def __init__( |
206 | 208 | self.dBiases = Compartment(jnp.zeros(shape[1])) |
207 | 209 |
|
208 | 210 | #key, subkey = random.split(self.key.value) |
| 211 | + # NOTE: we don't save this compartment directly because it is a tuple can cannot be saved directly by numpy |
209 | 212 | self.opt_params = Compartment( |
210 | | - get_opt_init_fn(optim_type)([self.weights.get(), self.biases.get()] if bias_init else [self.weights.get()]) |
| 213 | + get_opt_init_fn(optim_type)([self.weights.get(), self.biases.get()] if bias_init else [self.weights.get()]), |
| 214 | + auto_save=False |
211 | 215 | ) |
212 | 216 |
|
| 217 | + def save(self, directory: str): |
| 218 | + super().save(directory) |
| 219 | + # Also save the optimizer parameters |
| 220 | + file_name = directory + "/" + self.name + "_opt_params" + ".pkl" |
| 221 | + with open(file_name, 'wb') as f: |
| 222 | + pickle.dump(self.opt_params.get(), f) |
| 223 | + |
| 224 | + def load(self, directory: str): |
| 225 | + super().load(directory) |
| 226 | + file_name = directory + "/" + self.name + "_opt_params" + ".pkl" |
| 227 | + with open(file_name, 'rb') as f: |
| 228 | + data = pickle.load(f) |
| 229 | + self.opt_params.set(data) |
| 230 | + |
213 | 231 | @staticmethod |
214 | 232 | def _compute_update( |
215 | 233 | w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght, post_wght, pre, post, weights |
@@ -332,3 +350,4 @@ def help(cls): ## component help function |
332 | 350 | Wab = HebbianSynapse("Wab", (2, 3), 0.0004, optim_type='adam', |
333 | 351 | sign_value=-1.0, prior=("l1l2", 0.001)) |
334 | 352 | print(Wab) |
| 353 | + print(Wab.opt_params.get()) |
0 commit comments