|
@@ -1,12 +1,10 @@
 from jax import random, numpy as jnp, jit
-from ngcsimlib.compilers.process import transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-
 from ngclearn.utils.weight_distribution import initialize_params
 from ngcsimlib.logger import info
+
 from ngclearn.components.synapses import DenseSynapse
-from ngclearn.utils import tensorstats
+from ngcsimlib.compartment import Compartment
+from ngcsimlib.parser import compilable
 
 class STPDenseSynapse(DenseSynapse): ## short-term plastic synaptic cable
     """
@@ -56,80 +54,82 @@ class STPDenseSynapse(DenseSynapse): ## short-term plastic synaptic cable |
         resources_init: initialization kernel for synaptic resources matrix
     """
 
-    # Define Functions
-    def __init__(self, name, shape, weight_init=None, bias_init=None,
-                 resist_scale=1., p_conn=1., tau_f=750., tau_d=50.,
-                 resources_init=None, **kwargs):
+    def __init__(
+            self, name, shape, weight_init=None, bias_init=None, resist_scale=1., p_conn=1., tau_f=750., tau_d=50.,
+            resources_init=None, **kwargs
+    ):
         super().__init__(name, shape, weight_init, bias_init, resist_scale, p_conn, **kwargs)
         ## STP meta-parameters
         self.resources_init = resources_init
         self.tau_f = tau_f
         self.tau_d = tau_d
 
         ## Set up short-term plasticity / dynamic synapse compartment values
-        tmp_key, *subkeys = random.split(self.key.value, 4)
+        tmp_key, *subkeys = random.split(self.key.get(), 4)
         preVals = jnp.zeros((self.batch_size, shape[0]))
         self.u = Compartment(preVals) ## release prob variables
         self.x = Compartment(preVals + 1) ## resource availability variables
-        self.Wdyn = Compartment(self.weights.value * 0) ## dynamic synapse values
+        self.Wdyn = Compartment(self.weights.get() * 0) ## dynamic synapse values
         if self.resources_init is None:
             info(self.name, "is using default resources value initializer!")
             self.resources_init = {"dist": "uniform", "amin": 0.125, "amax": 0.175} # 0.15
         self.resources = Compartment(
             initialize_params(subkeys[2], self.resources_init, shape)
         ) ## matrix U - synaptic resources matrix
 
-    @transition(output_compartments=["outputs", "u", "x", "Wdyn"])
-    @staticmethod
-    def advance_state(
-            tau_f, tau_d, Rscale, inputs, weights, biases, resources, u, x, Wdyn
-    ):
-        s = inputs
+    @compilable
+    def advance_state(self, t, dt):
+        s = self.inputs.get()
         ## compute short-term facilitation
         #u = u - u * (1./tau_f) + (resources * (1. - u)) * s
-        if tau_f > 0.: ## compute short-term facilitation
-            u = u - u * (1./tau_f) + (resources * (1. - u)) * s
+        if self.tau_f > 0.: ## compute short-term facilitation
+            u = self.u.get() - self.u.get() * (1./self.tau_f) + (self.resources.get() * (1. - self.u.get())) * s
         else:
-            u = resources ## disabling STF yields fixed resource u variables
+            u = self.resources.get() ## disabling STF yields fixed resource u variables
         ## compute dynamic synaptic values/conductances
-        Wdyn = (weights * u * x) * s + Wdyn * (1. - s) ## OR: -W/tau_w + W * u * x
-        if tau_d > 0.:
-            ## compute short-term depression
-            x = x + (1. - x) * (1./tau_d) - u * x * s
-        outputs = jnp.matmul(inputs, Wdyn * Rscale) + biases
-        return outputs, u, x, Wdyn
-
-    @transition(output_compartments=["inputs", "outputs", "u", "x", "Wdyn"])
-    @staticmethod
-    def reset(batch_size, shape):
-        preVals = jnp.zeros((batch_size, shape[0]))
-        postVals = jnp.zeros((batch_size, shape[1]))
-        inputs = preVals
-        outputs = postVals
-        u = preVals
-        x = preVals + 1
-        Wdyn = jnp.zeros(shape)
-        return inputs, outputs, u, x, Wdyn
-
-    def save(self, directory, **kwargs):
-        file_name = directory + "/" + self.name + ".npz"
-        if self.bias_init != None:
-            jnp.savez(file_name,
-                      weights=self.weights.value,
-                      biases=self.biases.value,
-                      resources=self.resources.value)
-        else:
-            jnp.savez(file_name,
-                      weights=self.weights.value,
-                      resources=self.resources.value)
-
-    def load(self, directory, **kwargs):
-        file_name = directory + "/" + self.name + ".npz"
-        data = jnp.load(file_name)
-        self.weights.set(data['weights'])
-        self.resources.set(data['resources'])
-        if "biases" in data.keys():
-            self.biases.set(data['biases'])
+        Wdyn = (self.weights.get() * u * self.x.get()) * s + self.Wdyn.get() * (1. - s) ## OR: -W/tau_w + W * u * x
+        ## compute short-term depression
+        x = self.x.get()
+        if self.tau_d > 0.:
+            x = x + (1. - x) * (1./self.tau_d) - u * x * s
+        ## else, do nothing with x (keep it pointing to current x compartment)
+        outputs = jnp.matmul(self.inputs.get(), Wdyn * self.resist_scale) + self.biases.get()
+
+        self.outputs.set(outputs)
+        self.u.set(u)
+        self.x.set(x)
+        self.Wdyn.set(Wdyn)
+
+    @compilable
+    def reset(self):
+        preVals = jnp.zeros((self.batch_size.get(), self.shape.get()[0]))
+        postVals = jnp.zeros((self.batch_size.get(), self.shape.get()[1]))
+        if not self.inputs.targeted:
+            self.inputs.set(preVals)
+        self.outputs.set(postVals)
+        self.u.set(preVals)
+        self.x.set(preVals + 1)
+        self.Wdyn.set(jnp.zeros(self.shape.get()))
+
+    # def save(self, directory, **kwargs):
+    #     file_name = directory + "/" + self.name + ".npz"
+    #     if self.bias_init != None:
+    #         jnp.savez(file_name,
+    #                   weights=self.weights.value,
+    #                   biases=self.biases.value,
+    #                   resources=self.resources.value)
+    #     else:
+    #         jnp.savez(file_name,
+    #                   weights=self.weights.value,
+    #                   resources=self.resources.value)
+    #
+    # def load(self, directory, **kwargs):
+    #     file_name = directory + "/" + self.name + ".npz"
+    #     data = jnp.load(file_name)
+    #     self.weights.set(data['weights'])
+    #     self.resources.set(data['resources'])
+    #     if "biases" in data.keys():
+    #         self.biases.set(data['biases'])
 
     @classmethod
     def help(cls): ## component help function
@@ -166,17 +166,3 @@ def help(cls): ## component help function |
166 | 166 | "dW/dt = W_full * u * x * inputs", |
167 | 167 | "hyperparameters": hyperparams} |
168 | 168 | return info |
169 | | - |
170 | | - def __repr__(self): |
171 | | - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] |
172 | | - maxlen = max(len(c) for c in comps) + 5 |
173 | | - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" |
174 | | - for c in comps: |
175 | | - stats = tensorstats(getattr(self, c).value) |
176 | | - if stats is not None: |
177 | | - line = [f"{k}: {v}" for k, v in stats.items()] |
178 | | - line = ", ".join(line) |
179 | | - else: |
180 | | - line = "None" |
181 | | - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" |
182 | | - return lines |