|
| 1 | +from typing import Optional, Any |
| 2 | + |
| 3 | +from brainpy import math as bm |
| 4 | +from brainpy._src.dynsys import Dynamic |
| 5 | +from brainpy._src.mixin import SupportAutoDelay |
| 6 | +from brainpy.types import Shape |
| 7 | + |
| 8 | +__all__ = [ |
| 9 | + 'InputVar', |
| 10 | +] |
| 11 | + |
| 12 | + |
| 13 | +class InputVar(Dynamic, SupportAutoDelay): |
| 14 | + """Define an input variable. |
| 15 | +
|
| 16 | + Example:: |
| 17 | +
|
| 18 | + import brainpy as bp |
| 19 | +
|
| 20 | +
|
| 21 | + class Exponential(bp.Projection): |
| 22 | + def __init__(self, pre, post, prob, g_max, tau, E=0.): |
| 23 | + super().__init__() |
| 24 | + self.proj = bp.dyn.ProjAlignPostMg2( |
| 25 | + pre=pre, |
| 26 | + delay=None, |
| 27 | + comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max), |
| 28 | + syn=bp.dyn.Expon.desc(post.num, tau=tau), |
| 29 | + out=bp.dyn.COBA.desc(E=E), |
| 30 | + post=post, |
| 31 | + ) |
| 32 | +
|
| 33 | +
|
| 34 | + class EINet(bp.DynSysGroup): |
| 35 | + def __init__(self, num_exc, num_inh, method='exp_auto'): |
| 36 | + super(EINet, self).__init__() |
| 37 | +
|
| 38 | + # neurons |
| 39 | + pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., |
| 40 | + V_initializer=bp.init.Normal(-55., 2.), method=method) |
| 41 | + self.E = bp.dyn.LifRef(num_exc, **pars) |
| 42 | + self.I = bp.dyn.LifRef(num_inh, **pars) |
| 43 | +
|
| 44 | + # synapses |
| 45 | + w_e = 0.6 # excitatory synaptic weight |
| 46 | + w_i = 6.7 # inhibitory synaptic weight |
| 47 | +
|
| 48 | + # Neurons connect to each other randomly with a connection probability of 2% |
| 49 | + self.E2E = Exponential(self.E, self.E, 0.02, g_max=w_e, tau=5., E=0.) |
| 50 | + self.E2I = Exponential(self.E, self.I, 0.02, g_max=w_e, tau=5., E=0.) |
| 51 | + self.I2E = Exponential(self.I, self.E, 0.02, g_max=w_i, tau=10., E=-80.) |
| 52 | + self.I2I = Exponential(self.I, self.I, 0.02, g_max=w_i, tau=10., E=-80.) |
| 53 | +
|
| 54 | + # define input variables given to E/I populations |
| 55 | + self.Ein = bp.dyn.InputVar(self.E.varshape) |
| 56 | + self.Iin = bp.dyn.InputVar(self.I.varshape) |
| 57 | + self.E.add_inp_fun('', self.Ein) |
| 58 | + self.I.add_inp_fun('', self.Iin) |
| 59 | +
|
| 60 | +
|
| 61 | + net = EINet(3200, 800, method='exp_auto') # "method": the numerical integrator method |
| 62 | + runner = bp.DSRunner(net, monitors=['E.spike', 'I.spike'], inputs=[('Ein.input', 20.), ('Iin.input', 20.)]) |
| 63 | + runner.run(100.) |
| 64 | +
|
| 65 | + # visualization |
| 66 | + bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], |
| 67 | + title='Spikes of Excitatory Neurons', show=True) |
| 68 | + bp.visualize.raster_plot(runner.mon.ts, runner.mon['I.spike'], |
| 69 | + title='Spikes of Inhibitory Neurons', show=True) |
| 70 | +
|
| 71 | +
|
| 72 | + """ |
| 73 | + def __init__( |
| 74 | + self, |
| 75 | + size: Shape, |
| 76 | + keep_size: bool = False, |
| 77 | + sharding: Optional[Any] = None, |
| 78 | + name: Optional[str] = None, |
| 79 | + mode: Optional[bm.Mode] = None, |
| 80 | + method: str = 'exp_auto' |
| 81 | + ): |
| 82 | + super().__init__(size=size, keep_size=keep_size, sharding=sharding, name=name, mode=mode, method=method) |
| 83 | + |
| 84 | + self.reset_state(self.mode) |
| 85 | + |
| 86 | + def reset_state(self, batch_or_mode=None): |
| 87 | + self.input = self.init_variable(bm.zeros, batch_or_mode) |
| 88 | + |
| 89 | + def update(self, *args, **kwargs): |
| 90 | + return self.input.value |
| 91 | + |
| 92 | + def return_info(self): |
| 93 | + return self.input |
| 94 | + |
| 95 | + def clear_input(self, *args, **kwargs): |
| 96 | + self.reset_state(self.mode) |
0 commit comments