|
1 | 1 | from jax import random, numpy as jnp, jit |
2 | | -from ngcsimlib.compilers.process import transition |
3 | | -from ngcsimlib.component import Component |
4 | 2 | from ngcsimlib.compartment import Compartment |
5 | | - |
| 3 | +from ngcsimlib.parser import compilable |
6 | 4 | from ngclearn.utils.weight_distribution import initialize_params |
7 | | -from ngcsimlib.logger import info |
| 5 | + |
8 | 6 | from ngclearn.components.synapses.hebbian import TraceSTDPSynapse |
9 | | -from ngclearn.utils import tensorstats |
10 | 7 |
|
11 | 8 | class MSTDPETSynapse(TraceSTDPSynapse): # modulated trace-based STDP w/ eligility traces |
12 | 9 | """ |
@@ -72,78 +69,69 @@ class MSTDPETSynapse(TraceSTDPSynapse): # modulated trace-based STDP w/ eligilit |
72 | 69 |
|
73 | 70 | p_conn: probability of a connection existing (default: 1.); setting |
74 | 71 | this to < 1. will result in a sparser synaptic structure |
| 72 | +
|
| 73 | + w_bound: maximum value/magnitude any synaptic efficacy can be (default: 1) |
75 | 74 | """ |
76 | 75 |
|
77 | | - # Define Functions |
78 | 76 | def __init__( |
79 | 77 | self, name, shape, A_plus, A_minus, eta=1., mu=0., pretrace_target=0., tau_elg=0., elg_decay=1., tau_w=0., |
80 | 78 | weight_init=None, resist_scale=1., p_conn=1., w_bound=1., batch_size=1, **kwargs |
81 | 79 | ): |
82 | | - super().__init__( |
| 80 | + super().__init__( # call to parent trace-stdp component |
83 | 81 | name, shape, A_plus, A_minus, eta=eta, mu=mu, pretrace_target=pretrace_target, weight_init=weight_init, |
84 | 82 | resist_scale=resist_scale, p_conn=p_conn, w_bound=w_bound, batch_size=batch_size, **kwargs |
85 | 83 | ) |
86 | 84 | self.w_eps = 0. |
87 | 85 | self.tau_w = tau_w |
88 | 86 | ## MSTDP/MSTDP-ET meta-parameters |
89 | | - self.tau_elg = tau_elg |
90 | | - self.elg_decay = elg_decay |
| 87 | + self.tau_elg = tau_elg ## time constant for eligibility trace |
| 88 | + self.elg_decay = elg_decay ## decay factor eligibility trace |
91 | 89 | ## MSTDP/MSTDP-ET compartments |
92 | 90 | self.modulator = Compartment(jnp.zeros((self.batch_size, 1))) |
93 | 91 | self.eligibility = Compartment(jnp.zeros(shape)) |
94 | 92 | self.outmask = Compartment(jnp.zeros((1, shape[1]))) |
95 | 93 |
|
96 | | - @transition(output_compartments=["weights", "dWeights", "eligibility"]) |
97 | | - @staticmethod |
98 | | - def evolve( |
99 | | - dt, w_bound, w_eps, preTrace_target, mu, Aplus, Aminus, tau_elg, elg_decay, tau_w, preSpike, postSpike, |
100 | | - preTrace, postTrace, weights, dWeights, eta, modulator, eligibility, outmask |
101 | | - ): |
102 | | - # dW_dt = TraceSTDPSynapse._compute_update( ## use Hebbian/STDP rule to obtain a non-modulated update |
103 | | - # dt, w_bound, preTrace_target, mu, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights |
104 | | - # ) |
| 94 | + @compilable |
| 95 | + def evolve(self, dt, t): |
| 96 | + # dW_dt = self._compute_update() |
105 | 97 | # dWeights = dW_dt ## can think of this as eligibility at time t |
106 | 98 |
|
107 | | - if tau_elg > 0.: ## perform dynamics of M-STDP-ET |
108 | | - eligibility = eligibility * jnp.exp(-dt / tau_elg) * elg_decay + dWeights/tau_elg |
| 99 | + if self.tau_elg > 0.: ## perform dynamics of M-STDP-ET |
| 100 | + eligibility = self.eligibility.get() * jnp.exp(-dt / self.tau_elg) * self.elg_decay + self.dWeights.get()/self.tau_elg |
109 | 101 | else: ## otherwise, just do M-STDP |
110 | | - eligibility = dWeights ## dynamics of M-STDP had no eligibility tracing |
| 102 | + eligibility = self.dWeights.get() ## dynamics of M-STDP had no eligibility tracing |
111 | 103 | ## do a gradient ascent update/shift |
112 | 104 | decayTerm = 0. |
113 | | - if tau_w > 0.: |
114 | | - decayTerm = weights * (1. / tau_w) |
115 | | - weights = weights + (eligibility * modulator * eta) * outmask - decayTerm ## do modulated update |
| 105 | + if self.tau_w > 0.: |
| 106 | + decayTerm = self.weights.get() * (1. / self.tau_w) |
| 107 | + ## do modulated update |
| 108 | + weights = self.weights.get() + (eligibility * self.modulator.get() * self.eta) * self.outmask.get() - decayTerm |
116 | 109 |
|
117 | | - dW_dt = TraceSTDPSynapse._compute_update( ## use Hebbian/STDP rule to obtain a non-modulated update |
118 | | - dt, w_bound, preTrace_target, mu, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights |
119 | | - ) |
| 110 | + dW_dt = self._compute_update() ## apply a Hebbian/STDP rule to obtain a non-modulated update |
120 | 111 | dWeights = dW_dt ## can think of this as eligibility at time t |
121 | 112 |
|
122 | 113 | #w_eps = 0.01 |
123 | | - weights = jnp.clip(weights, w_eps, w_bound - w_eps) # jnp.abs(w_bound)) |
124 | | - |
125 | | - return weights, dWeights, eligibility |
126 | | - |
127 | | - @transition( |
128 | | - output_compartments=[ |
129 | | - "inputs", "outputs", "preSpike", "postSpike", "preTrace", "postTrace", "dWeights", "eligibility", "outmask" |
130 | | - ] |
131 | | - ) |
132 | | - @staticmethod |
133 | | - def reset(batch_size, shape): |
134 | | - preVals = jnp.zeros((batch_size, shape[0])) |
135 | | - postVals = jnp.zeros((batch_size, shape[1])) |
136 | | - synVals = jnp.zeros(shape) |
137 | | - inputs = preVals |
138 | | - outputs = postVals |
139 | | - preSpike = preVals |
140 | | - postSpike = postVals |
141 | | - preTrace = preVals |
142 | | - postTrace = postVals |
143 | | - dWeights = synVals |
144 | | - eligibility = synVals |
145 | | - outmask = postVals + 1. |
146 | | - return inputs, outputs, preSpike, postSpike, preTrace, postTrace, dWeights, eligibility, outmask |
| 114 | + weights = jnp.clip(weights, self.w_eps, self.w_bound - self.w_eps) # jnp.abs(w_bound)) |
| 115 | + self.weights.set(weights) |
| 116 | + self.dWeights.set(dWeights) |
| 117 | + self.eligibility.set(eligibility) |
| 118 | + |
| 119 | + @compilable |
| 120 | + def reset(self): |
| 121 | + preVals = jnp.zeros((self.batch_size.get(), self.shape.get()[0])) |
| 122 | + postVals = jnp.zeros((self.batch_size.get(), self.shape.get()[1])) |
| 123 | + synVals = jnp.zeros(self.shape.get()) |
| 124 | + |
| 125 | + if not self.inputs.targeted: |
| 126 | + self.inputs.set(preVals) |
| 127 | + self.outputs.set(postVals) |
| 128 | + self.preSpike.set(preVals) |
| 129 | + self.postSpike.set(postVals) |
| 130 | + self.preTrace.set(preVals) |
| 131 | + self.postTrace.set(postVals) |
| 132 | + self.dWeights.set(synVals) |
| 133 | + self.eligibility.set(synVals) |
| 134 | + self.outmask.set(postVals + 1.) |
147 | 135 |
|
148 | 136 | @classmethod |
149 | 137 | def help(cls): ## component help function |
@@ -195,17 +183,3 @@ def help(cls): ## component help function |
195 | 183 | "dW^{stdp}_{ij}/dt = A_plus * (z_j - x_tar) * s_i - A_minus * s_j * z_i", |
196 | 184 | "hyperparameters": hyperparams} |
197 | 185 | return info |
198 | | - |
199 | | - def __repr__(self): |
200 | | - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] |
201 | | - maxlen = max(len(c) for c in comps) + 5 |
202 | | - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" |
203 | | - for c in comps: |
204 | | - stats = tensorstats(getattr(self, c).value) |
205 | | - if stats is not None: |
206 | | - line = [f"{k}: {v}" for k, v in stats.items()] |
207 | | - line = ", ".join(line) |
208 | | - else: |
209 | | - line = "None" |
210 | | - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" |
211 | | - return lines |
0 commit comments