from jax import random, numpy as jnp
import numpy as np

from ngcsimlib.compilers.process import transition
from ngcsimlib.compartment import Compartment

from ngclearn.components.synapses import DenseSynapse
from ngclearn.utils import tensorstats
from ngclearn.utils.model_utils import create_function


class REINFORCESynapse(DenseSynapse):
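    """
    A dense synaptic cable whose efficacies are adapted with the REINFORCE
    (score-function / policy-gradient) estimator. The stored weight matrix is
    split column-wise into two halves, W_mu and W_logstd, which parameterize a
    diagonal Gaussian over this synapse's output: with a = act_fx(inputs),

    | mu = a @ W_mu,  sigma = exp(clip(a @ W_logstd)),  outputs = mu + sigma * eps,  eps ~ N(0, I)

    and `evolve` adjusts both halves using the analytic gradient of the scaled
    negative REINFORCE objective L = -mean[(r - r_hat) * log N(outputs; mu, sigma^2)]:

    | dL/dW_mu = a^T [ -(r - r_hat) (outputs - mu) / sigma^2 ]
    | dL/dW_logstd = a^T [ -(r - r_hat) (((outputs - mu) / sigma)^2 - 1) ]

    (both gradient halves carry the same 1e-2 scaling applied to the objective).

    Args:
        name: the string name of this synapse

        shape: tuple of (number of input features, number of output features);
            internally the weight matrix holds output_dim * 2 columns

        eta: global learning rate (default: 1e-4)

        weight_init: a kernel to drive initialization of this synaptic cable's values

        resist_scale: a fixed (resistance) scaling factor applied to the synaptic transform

        act_fx: activation applied to inputs before the linear transform (default: identity)

        p_conn: probability of a connection existing

        w_bound: soft weight constraint used when clipping evolved weights

        batch_size: batch size dimension of this component
    """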

    # Define Functions
    def __init__(
        self, name, shape, eta=1e-4, weight_init=None, resist_scale=1., act_fx=None,
        p_conn=1., w_bound=1., batch_size=1, **kwargs
    ):
        ## Double the output columns: the stored matrix stacks W_mu and W_logstd
        input_dim, output_dim = shape
        super().__init__(name, (input_dim, output_dim * 2), weight_init, None, resist_scale,
                         p_conn, batch_size=batch_size, **kwargs)

        ## Synaptic hyper-parameters
        self.shape = shape ## shape of synaptic efficacy matrix (input_dim, output_dim)
        self.Rscale = resist_scale ## post-transformation scale factor
        self.w_bound = w_bound ## soft weight constraint
        self.eta = eta ## learning rate

        ## Compartment setup
        self.dWeights = Compartment(self.weights.value * 0)
        # self.eta = Compartment(jnp.ones((1, 1)) * eta) ## global learning rate (for eligibility traces later)
        self.objective = Compartment(jnp.zeros(())) ## scalar (negative) REINFORCE objective
        self.outputs = Compartment(jnp.zeros((batch_size, output_dim))) ## sampled outputs/actions
        self.rewards = Compartment(jnp.zeros((batch_size,))) ## normalized reward (r - r_hat), input compartment
        self.act_fx, self.dact_fx = create_function(act_fx if act_fx is not None else "identity")

    @staticmethod
    def _compute_update(dt, inputs, rewards, act_fx, weights):
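        """
        Compute the REINFORCE update for one batch: run the forward pass, sample
        outputs from the induced Gaussian (via the reparameterization trick), and
        return the loss gradient w.r.t. the stacked [W_mu | W_logstd] matrix along
        with the scalar objective and the sampled outputs.
        """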
        W_mu, W_logstd = jnp.split(weights, 2, axis=-1)  # (input_dim, output_dim * 2) => 2 x (input_dim, output_dim)
        # Forward pass
        activation = act_fx(inputs)
        mean = activation @ W_mu
        logstd = activation @ W_logstd
        std = jnp.exp(logstd.clip(-10.0, 2.0))
        # Sample using the reparameterization trick (note: this draw uses host-side
        # NumPy RNG, so it is not traced by jit and is not tied to a JAX PRNG key)
        epsilon = jnp.asarray(np.random.normal(0, 1, mean.shape))
        sample = epsilon * std + mean
        outputs = sample  # the actual action that is taken
        # Log probability density of the diagonal Gaussian, summed over output dims
        log_prob = -0.5 * jnp.log(2 * jnp.pi) - logstd - 0.5 * ((sample - mean) / std) ** 2
        log_prob = log_prob.sum(-1)
        # Scalar loss: negative REINFORCE objective (scaled by 1e-2)
        objective = (-log_prob * rewards).mean() * 1e-2
        # Backward pass: gradients computed analytically from the derivation
        # dL/dmu = -(r - r_hat) * dlog_prob/dmu, with dlog_prob/dmu = (sample - mu) / sigma^2
        dlog_prob_dmean = (sample - mean) / (std ** 2)
        # dL/dlog(sigma) = -(r - r_hat) * dlog_prob/dlog(sigma),
        # with dlog_prob/dlog(sigma) = ((sample - mu) / sigma)^2 - 1
        dlog_prob_dlogstd = ((sample - mean) / std) ** 2 - 1.0
        # Chain rule through the linear map: dL/dW_mu = activation^T @ dL/dmu
        # (and likewise for W_logstd), using the same 1e-2 scaling as the objective
        dL_dWmu = activation.T @ (-rewards[:, None] * dlog_prob_dmean) * 1e-2
        dL_dWlstd = activation.T @ (-rewards[:, None] * dlog_prob_dlogstd) * 1e-2
        # Stack the two gradient halves to match the stored weight layout
        dW = jnp.concatenate([dL_dWmu, dL_dWlstd], axis=-1)
        return dW, objective, outputs

    @transition(output_compartments=["weights", "dWeights", "objective", "outputs"])
    @staticmethod
    def evolve(dt, w_bound, inputs, rewards, act_fx, weights, eta):
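        """
        Evolve (adjust) this synapse's efficacies: compute the REINFORCE update,
        step against the loss gradient, and clip the result to [eps, w_bound - eps].
        """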
        dWeights, objective, outputs = REINFORCESynapse._compute_update(
            dt, inputs, rewards, act_fx, weights
        )
        ## step against the loss gradient (i.e., ascend the REINFORCE objective)
        weights = weights - dWeights * eta
        ## clip weights to the soft bounds [eps, w_bound - eps]
        eps = 0.01
        weights = jnp.clip(weights, eps, w_bound - eps)
        return weights, dWeights, objective, outputs

    @transition(output_compartments=["inputs", "outputs", "objective", "rewards", "dWeights"])
    @staticmethod
    def reset(batch_size, shape):
        preVals = jnp.zeros((batch_size, shape[0]))
        postVals = jnp.zeros((batch_size, shape[1]))
        inputs = preVals
        outputs = postVals
        objective = jnp.zeros(())
        rewards = jnp.zeros((batch_size,))
        ## the stored weight matrix stacks W_mu and W_logstd, so dWeights has 2x columns
        dWeights = jnp.zeros((shape[0], shape[1] * 2))
        return inputs, outputs, objective, rewards, dWeights

    @classmethod
    def help(cls): ## component help function
        properties = {
            "synapse_type": "REINFORCESynapse - a dense synaptic cable that adapts its "
                            "efficacies with the REINFORCE (policy-gradient) estimator"
        }
        compartment_props = {
            "inputs": {"inputs": "Takes in external input signal values",
                       "rewards": "Takes in the normalized reward signal (r - r_hat)"},
            "states": {"weights": "Synapse efficacy/strength matrix (stacked [W_mu | W_logstd])"},
            "analytics": {"dWeights": "Synaptic weight adjustment matrix produced at each evolve step",
                          "objective": "Scalar value of the (negative) REINFORCE objective"},
            "outputs": {"outputs": "Sampled output/action values"},
        }
        hyperparams = {
            "shape": "Shape of synaptic weight matrix; number of inputs x number of outputs",
            "eta": "Global learning rate",
            "weight_init": "Initialization conditions for synaptic weight (W) values",
            "resist_scale": "Resistance level scaling factor applied to the synaptic transform",
            "act_fx": "Activation function applied to inputs before the linear transform",
            "p_conn": "Probability of a connection existing",
            "w_bound": "Soft weight bound used when clipping evolved weights",
            "batch_size": "Batch size dimension of this component",
        }
        info = {cls.__name__: properties,
                "compartments": compartment_props,
                "dynamics": "outputs ~ N(mu, sigma^2), mu = act_fx(inputs) @ W_mu, "
                            "sigma = exp(act_fx(inputs) @ W_logstd); "
                            "dW = grad_W[ -(r - r_hat) * log N(outputs; mu, sigma^2) ]",
                "hyperparameters": hyperparams}
        return info

    def __repr__(self):
        comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
        maxlen = max(len(c) for c in comps) + 5
        lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
        for c in comps:
            stats = tensorstats(getattr(self, c).value)
            if stats is not None:
                line = [f"{k}: {v}" for k, v in stats.items()]
                line = ", ".join(line)
            else:
                line = "None"
            lines += f"  {f'({c})'.ljust(maxlen)}{line}\n"
        return lines
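

if __name__ == "__main__":
    ## Minimal standalone sketch (an illustrative assumption, not part of ngc-learn's
    ## Context/Process workflow): exercise the static update math with dummy data.
    input_dim, output_dim, batch_size = 4, 2, 3
    key = random.PRNGKey(0)
    ## stacked [W_mu | W_logstd] matrix, initialized inside the (0, 1) soft bounds
    weights = random.uniform(key, (input_dim, output_dim * 2), minval=0.05, maxval=0.95)
    inputs = jnp.ones((batch_size, input_dim))
    rewards = jnp.array([1.0, -0.5, 0.2])  ## baseline-subtracted rewards (r - r_hat)
    dW, objective, outputs = REINFORCESynapse._compute_update(
        1., inputs, rewards, lambda x: x, weights
    )
    print("dW:", dW.shape, "objective:", float(objective), "outputs:", outputs.shape)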