|
| 1 | +from ngclearn.components.jaxComponent import JaxComponent |
| 2 | +from jax import numpy as jnp, random, jit, nn |
| 3 | +from functools import partial |
| 4 | +from ngclearn.utils import tensorstats |
| 5 | +from ngcsimlib.deprecators import deprecate_args |
| 6 | +from ngcsimlib.logger import info, warn |
| 7 | +from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \ |
| 8 | + step_euler, step_rk2 |
| 9 | + |
| 10 | +from ngcsimlib.compilers.process import transition |
| 11 | +#from ngcsimlib.component import Component |
| 12 | +from ngcsimlib.compartment import Compartment |
| 13 | + |
| 14 | + |
| 15 | +def _calc_biophysical_constants(v): ## computes H-H biophysical constants (which are functions of voltage v) |
| 16 | + alpha_n_of_v = .01 * ((10 - v) / (jnp.exp((10. - v) / 10.) - 1.)) |
| 17 | + beta_n_of_v = .125 * jnp.exp(-v / 80.) |
| 18 | + alpha_m_of_v = .1 * ((25 - v) / (jnp.exp((25. - v) / 10.) - 1.)) |
| 19 | + beta_m_of_v = 4. * jnp.exp(-v / 18.) |
| 20 | + alpha_h_of_v = .07 * jnp.exp(-v / 20.) |
| 21 | + beta_h_of_v = 1. / (jnp.exp((30 - v) / 10.) + 1.) |
| 22 | + return alpha_n_of_v, beta_n_of_v, alpha_m_of_v, beta_m_of_v, alpha_h_of_v, beta_h_of_v |
| 23 | + |
| 24 | +def _dv_dt(t, v, j, m, n, h, tau_v, g_Na, g_K, g_L, v_Na, v_K, v_L): ## ODE for membrane potential/voltage |
| 25 | + ## C dv/dt = j - g_Na * m^3 * h * (v - v_Na) - g_K * n^4 * (v - v_K) - g_L * (v - v_L) |
| 26 | + term1 = g_Na * jnp.power(m, 3) * h * (v - v_Na) |
| 27 | + term2 = g_K * jnp.power(n, 4) * (v - v_K) |
| 28 | + term3 = g_L * (v - v_L) |
| 29 | + return (j - term1 - term2 - term3) * (1. / tau_v) |
| 30 | + |
| 31 | +def dv_dt(t, v, params): |
| 32 | + j, m, n, h, tau_v, g_Na, g_K, g_L, v_Na, v_K, v_L = params |
| 33 | + return _dv_dt(t, v, j, m, n, h, tau_v, g_Na, g_K, g_L, v_Na, v_K, v_L) |
| 34 | + |
| 35 | +def _dx_dt(t, x, alpha_x_of_v, beta_x_of_v): ## ODE for channel/gate |
| 36 | + ## dx/dt = alpha_x(v) * (1 - x) - beta_x(v) * x |
| 37 | + return alpha_x_of_v * (1 - x) - beta_x_of_v * x |
| 38 | + |
| 39 | +def dx_dt(t, x, params): |
| 40 | + alpha_x_of_v, beta_x_of_v = params |
| 41 | + return _dx_dt(t, x, alpha_x_of_v, beta_x_of_v) |
| 42 | + |
| 43 | +class HodgkinHuxleyCell(JaxComponent): ## Hodgkin-Huxley spiking cell |
| 44 | + """ |
| 45 | + A spiking cell based the Hodgkin-Huxley (H-H) 1952 set of dynamics for describing the ionic mechanisms that underwrite |
| 46 | + the initiation and propagation of action potentials within a (giant) squid axon. |
| 47 | +
|
| 48 | + The four differential equations for adjusting this specific cell |
| 49 | + (for adjusting v, given current j, over time) is: |
| 50 | +
|
| 51 | + | tau_v dv/dt = j - g_Na * m^3 * h * (v - v_Na) - g_K * n^4 * (v - v_K) - g_L * (v - v_L) |
| 52 | + | dn/dt = alpha_n(v) * (1 - n) - beta_n(v) * n |
| 53 | + | dm/dt = alpha_m(v) * (1 - m) - beta_m(v) * m |
| 54 | + | dh/dt = alpha_h(v) * (1 - h) - beta_h(v) * h |
| 55 | + | where alpha_x(v) and beta_x(v) are functions that produce relevant biophysical constant values |
| 56 | + | depending on which gate/channel is being probed (i.e., x = n or m or h) |
| 57 | +
|
| 58 | + | --- Cell Input Compartments: --- |
| 59 | + | j - electrical current input (takes in external signals) |
| 60 | + | --- Cell State Compartments: --- |
| 61 | + | v - membrane potential/voltage state |
| 62 | + | n - dimensionless probabilities for potassium channel subunit activation |
| 63 | + | m - dimensionless probabilities for sodium channel subunit activation |
| 64 | + | h - dimensionless probabilities for sodium channel subunit inactivation |
| 65 | + | key - JAX PRNG key |
| 66 | + | --- Cell Output Compartments: --- |
| 67 | + | s - emitted binary spikes/action potentials |
| 68 | + | tols - time-of-last-spike |
| 69 | +
|
| 70 | + | References: |
| 71 | + | Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description of membrane current and its application to |
| 72 | + | conduction and excitation in nerve." The Journal of physiology 117.4 (1952): 500. |
| 73 | + | |
| 74 | + | Kistler, Werner M., Wulfram Gerstner, and J. Leo van Hemmen. "Reduction of the Hodgkin-Huxley equations to a |
| 75 | + | single-variable threshold model." Neural computation 9.5 (1997): 1015-1045. |
| 76 | +
|
| 77 | + Args: |
| 78 | + name: the string name of this cell |
| 79 | +
|
| 80 | + n_units: number of cellular entities (neural population size) |
| 81 | +
|
| 82 | + tau_v: membrane time constant (Default: 1 ms) |
| 83 | +
|
| 84 | + resist_m: membrane resistance value |
| 85 | +
|
| 86 | + v_Na: sodium reversal potential |
| 87 | +
|
| 88 | + v_K: potassium reversal potential |
| 89 | +
|
| 90 | + v_L: leak reversal potential |
| 91 | +
|
| 92 | + g_Na: sodium (Na) conductance per unit area |
| 93 | +
|
| 94 | + g_K: potassium (K) conductance per unit area |
| 95 | +
|
| 96 | + g_L: leak conductance per unit area |
| 97 | +
|
| 98 | + thr: voltage/membrane threshold (to obtain action potentials in terms of binary spikes/pulses) |
| 99 | +
|
| 100 | + spike_reset: if True, once voltage crosses threshold, then dynamics of voltage and recovery are reset/snapped |
| 101 | + to `v_reset` which has a default value of 0 mV (Default: False) |
| 102 | +
|
| 103 | + v_reset: voltage value to reset to after a spike (in mV) |
| 104 | + (Default: 0 mV) |
| 105 | + """ |
| 106 | + |
| 107 | + # Define Functions |
| 108 | + def __init__( |
| 109 | + self, name, n_units, tau_v, resist_m=1., v_Na=115., v_K=-35., v_L=10.6, g_Na=100., g_K=5., g_L=0.3, thr=4., |
| 110 | + spike_reset=False, v_reset=0., **kwargs |
| 111 | + ): |
| 112 | + super().__init__(name, **kwargs) |
| 113 | + |
| 114 | + ## membrane parameter setup (affects ODE integration) |
| 115 | + self.tau_v = tau_v ## membrane time constant |
| 116 | + self.R_m = resist_m ## resistance value |
| 117 | + self.spike_reset = spike_reset |
| 118 | + self.thr = thr # mV ## base value for threshold |
| 119 | + self.v_reset = v_reset ## base value to reset voltage to (if spike_reset = True) |
| 120 | + self.v_Na = v_Na #115. ## ENa |
| 121 | + self.v_K = v_K #-35. #-12. ## EK |
| 122 | + self.v_L = v_L #10.6 ## EKleak |
| 123 | + self.g_Na = g_Na #100. #120. ## gNa |
| 124 | + self.g_K = g_K #5. #36. ## gK |
| 125 | + self.g_L = g_L #0.3 ## gKleak |
| 126 | + |
| 127 | + ## Layer Size Setup |
| 128 | + self.batch_size = 1 |
| 129 | + self.n_units = n_units |
| 130 | + |
| 131 | + ## Compartment setup |
| 132 | + restVals = jnp.zeros((self.batch_size, self.n_units)) |
| 133 | + self.j = Compartment(restVals, display_name="Electrical input current") |
| 134 | + self.v = Compartment(restVals, display_name="Membrane potential/voltage") |
| 135 | + self.n = Compartment(restVals, display_name="Potassium channel subunit activation (probability)") |
| 136 | + self.m = Compartment(restVals, display_name="Sodium channel subunit activation (probability)") |
| 137 | + self.h = Compartment(restVals, display_name="Sodium channel subunit inactivation (probability)") |
| 138 | + self.s = Compartment(restVals, display_name="Spike pulse") |
| 139 | + self.tols = Compartment(restVals, display_name="Time-of-last-spike") ## time-of-last-spike |
| 140 | + |
| 141 | + @transition(output_compartments=["v", "m", "n", "h", "s", "tols"]) |
| 142 | + @staticmethod |
| 143 | + def advance_state(t, dt, spike_reset, v_reset, thr, tau_v, R_m, g_Na, g_K, g_L, v_Na, v_K, v_L, j, v, m, n, h, tols): |
| 144 | + _j = j * R_m |
| 145 | + alpha_n_of_v, beta_n_of_v, alpha_m_of_v, beta_m_of_v, alpha_h_of_v, beta_h_of_v = _calc_biophysical_constants(v) |
| 146 | + ## integrate voltage / membrane potential |
| 147 | + _, _v = step_euler(0., v, dv_dt, dt, (_j, m + 0., n + 0., h + 0., tau_v, g_Na, g_K, g_L, v_Na, v_K, v_L)) |
| 148 | + ## next, integrate different channels |
| 149 | + _, _n = step_euler(0., n, dx_dt, dt, (alpha_n_of_v, beta_n_of_v)) |
| 150 | + _, _m = step_euler(0., m, dx_dt, dt, (alpha_m_of_v, beta_m_of_v)) |
| 151 | + _, _h = step_euler(0., h, dx_dt, dt, (alpha_h_of_v, beta_h_of_v)) |
| 152 | + ## obtain action potentials/spikes/pulses |
| 153 | + s = (_v > thr) * 1. |
| 154 | + if spike_reset: ## if spike-reset used, variables snapped back to initial conditions |
| 155 | + alpha_n_of_v, beta_n_of_v, alpha_m_of_v, beta_m_of_v, alpha_h_of_v, beta_h_of_v = ( |
| 156 | + _calc_biophysical_constants(v * 0 + v_reset)) |
| 157 | + _v = _v * (1. - s) + s * v_reset |
| 158 | + _n = _n * (1. - s) + s * (alpha_n_of_v / (alpha_n_of_v + beta_n_of_v)) |
| 159 | + _m = _m * (1. - s) + s * (alpha_m_of_v / (alpha_m_of_v + beta_m_of_v)) |
| 160 | + _h = _h * (1. - s) + s * (alpha_h_of_v / (alpha_h_of_v + beta_h_of_v)) |
| 161 | + ## transition to new state of (system of) variables |
| 162 | + v = _v |
| 163 | + m = _m |
| 164 | + n = _n |
| 165 | + h = _h |
| 166 | + tols = (1. - s) * tols + (s * t) ## update tols |
| 167 | + |
| 168 | + return v, m, n, h, s, tols |
| 169 | + |
| 170 | + @transition(output_compartments=["j", "v", "m", "n", "h", "s", "tols"]) |
| 171 | + @staticmethod |
| 172 | + def reset(batch_size, n_units): |
| 173 | + restVals = jnp.zeros((batch_size, n_units)) |
| 174 | + v = restVals # + 0 |
| 175 | + alpha_n_of_v, beta_n_of_v, alpha_m_of_v, beta_m_of_v, alpha_h_of_v, beta_h_of_v = _calc_biophysical_constants(v) |
| 176 | + j = restVals #+ 0 |
| 177 | + n = alpha_n_of_v / (alpha_n_of_v + beta_n_of_v) |
| 178 | + m = alpha_m_of_v / (alpha_m_of_v + beta_m_of_v) |
| 179 | + h = alpha_h_of_v / (alpha_h_of_v + beta_h_of_v) |
| 180 | + s = restVals #+ 0 |
| 181 | + tols = restVals #+ 0 |
| 182 | + return j, v, m, n, h, s, tols |
| 183 | + |
| 184 | + def save(self, directory, **kwargs): |
| 185 | + file_name = directory + "/" + self.name + ".npz" |
| 186 | + #jnp.savez(file_name, threshold=self.thr.value) |
| 187 | + |
| 188 | + def load(self, directory, seeded=False, **kwargs): |
| 189 | + file_name = directory + "/" + self.name + ".npz" |
| 190 | + data = jnp.load(file_name) |
| 191 | + #self.thr.set( data['threshold'] ) |
| 192 | + |
| 193 | + @classmethod |
| 194 | + def help(cls): ## component help function |
| 195 | + properties = { |
| 196 | + "cell_type": "WTASCell - evolves neurons according to winner-take-all " |
| 197 | + "spiking dynamics " |
| 198 | + } |
| 199 | + compartment_props = { |
| 200 | + "inputs": |
| 201 | + {"j": "External input electrical current"}, |
| 202 | + "states": |
| 203 | + {"v": "Membrane potential/voltage at time t", |
| 204 | + "n": "Current state of potassium channel subunit activation", |
| 205 | + "m": "Current state of sodium channel subunit activation", |
| 206 | + "h": "Current state of sodium channel subunit inactivation", |
| 207 | + "key": "JAX PRNG key"}, |
| 208 | + "outputs": |
| 209 | + {"s": "Emitted spikes/pulses at time t", |
| 210 | + "tols": "Time-of-last-spike"}, |
| 211 | + } |
| 212 | + hyperparams = { |
| 213 | + "n_units": "Number of neuronal cells to model in this layer", |
| 214 | + "tau_v": "Cell membrane time constant", |
| 215 | + "resist_m": "Membrane resistance value", |
| 216 | + "thr": "Base voltage threshold value", |
| 217 | + "v_Na": "Sodium reversal potential", |
| 218 | + "v_K": "Potassium reversal potential", |
| 219 | + "v_L": "Leak reversal potential", |
| 220 | + "g_Na": "Sodium conductance per unit area", |
| 221 | + "g_K": "Potassium conductance per unit area", |
| 222 | + "g_L": "Leak conductance per unit area", |
| 223 | + "spike_reset": "Should this cell hyperpolarize by snapping to base values or not?", |
| 224 | + "v_reset": "Voltage value to reset to after a spike" |
| 225 | + } |
| 226 | + info = {cls.__name__: properties, |
| 227 | + "compartments": compartment_props, |
| 228 | + "dynamics": "tau_v dv/dt = j - g_Na * m^3 * h * (v - v_Na) - g_K * n^4 * (v - v_K) - g_L * (v - v_L)", |
| 229 | + "hyperparameters": hyperparams} |
| 230 | + return info |
| 231 | + |
| 232 | + def __repr__(self): |
| 233 | + comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] |
| 234 | + maxlen = max(len(c) for c in comps) + 5 |
| 235 | + lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" |
| 236 | + for c in comps: |
| 237 | + stats = tensorstats(getattr(self, c).value) |
| 238 | + if stats is not None: |
| 239 | + line = [f"{k}: {v}" for k, v in stats.items()] |
| 240 | + line = ", ".join(line) |
| 241 | + else: |
| 242 | + line = "None" |
| 243 | + lines += f" {f'({c})'.ljust(maxlen)}{line}\n" |
| 244 | + return lines |
| 245 | + |
| 246 | +if __name__ == '__main__': |
| 247 | + from ngcsimlib.context import Context |
| 248 | + with Context("Bar") as bar: |
| 249 | + X = HodgkinHuxleyCell("X", 1, 1.) |
| 250 | + print(X) |
0 commit comments