|
2 | 2 | from ngcsimlib.context import Context |
3 | 3 | import numpy as np |
4 | 4 | np.random.seed(42) |
5 | | -from ngclearn.components import IFCell |
6 | | -from ngcsimlib.compilers import compile_command, wrap_command |
7 | | -from numpy.testing import assert_array_equal |
8 | 5 |
|
9 | | -from ngcsimlib.compilers.process import Process, transition |
10 | | -from ngcsimlib.component import Component |
11 | | -from ngcsimlib.compartment import Compartment |
12 | | -from ngcsimlib.context import Context |
13 | | -from ngcsimlib.utils.compartment import Get_Compartment_Batch |
| 6 | +from ngclearn import Context, MethodProcess |
| 7 | +from ngclearn.components.neurons.spiking.IFCell import IFCell |
| 8 | +from numpy.testing import assert_array_equal |
14 | 9 |
|
15 | 10 | def test_IFCell1(): |
16 | 11 | name = "if_ctx" |
17 | 12 | ## create seeding keys |
18 | 13 | dkey = random.PRNGKey(1234) |
19 | 14 | dkey, *subkeys = random.split(dkey, 6) |
20 | 15 | dt = 1. # ms |
21 | | - trace_increment = 0.1 |
22 | 16 | # ---- build a simple Poisson cell system ---- |
23 | 17 | with Context(name) as ctx: |
24 | 18 | a = IFCell( |
25 | 19 | name="a", n_units=1, tau_m=5., resist_m=10., key=subkeys[0] |
26 | 20 | ) |
27 | 21 |
|
28 | | - #""" |
29 | | - advance_process = (Process("advance_proc") |
| 22 | + # """ |
| 23 | + advance_process = (MethodProcess("advance_proc") |
30 | 24 | >> a.advance_state) |
31 | | - #ctx.wrap_and_add_command(advance_process.pure, name="run") |
32 | | - ctx.wrap_and_add_command(jit(advance_process.pure), name="run") |
| 25 | + # ctx.wrap_and_add_command(jit(advance_process.pure), name="run") |
33 | 26 |
|
34 | | - reset_process = (Process("reset_proc") |
| 27 | + reset_process = (MethodProcess("reset_proc") |
35 | 28 | >> a.reset) |
36 | | - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") |
37 | | - #""" |
38 | | - |
39 | | - """ |
40 | | - reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset") |
41 | | - ctx.add_command(wrap_command(jit(ctx.reset)), name="reset") |
42 | | - advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state") |
43 | | - ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run") |
44 | | - """ |
45 | | - |
| 29 | + # ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") |
| 30 | + # """ |
46 | 31 | ## set up non-compiled utility commands |
47 | | - @Context.dynamicCommand |
48 | | - def clamp(x): |
49 | | - a.j.set(x) |
| 32 | + # @Context.dynamicCommand |
| 33 | + # def clamp(x): |
| 34 | + # a.j.set(x) |
| 35 | + |
| 36 | + def clamp(x): |
| 37 | + a.j.set(x) |
50 | 38 |
|
51 | 39 | ## input spike train |
52 | 40 | x_seq = jnp.asarray([[1., 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.]], dtype=jnp.float32) |
53 | 41 | ## desired output/epsp pulses |
54 | 42 | y_seq = jnp.asarray([[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.]], dtype=jnp.float32) |
55 | 43 |
|
56 | 44 | outs = [] |
57 | | - ctx.reset() |
| 45 | + reset_process.run() # ctx.reset() |
58 | 46 | for ts in range(x_seq.shape[1]): |
59 | 47 | x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t |
60 | | - ctx.clamp(x_t) |
61 | | - ctx.run(t=ts * 1., dt=dt) |
62 | | - outs.append(a.s.value) |
| 48 | + clamp(x_t) # ctx.clamp(x_t) |
| 49 | + advance_process.run(t=ts * 1., dt=dt) # ctx.run(t=ts * 1., dt=dt) |
| 50 | + outs.append(a.s.get()) |
63 | 51 | outs = jnp.concatenate(outs, axis=1) |
64 | | - print(outs) |
65 | | - |
| 52 | + # print(outs) |
| 53 | + # print(y_seq) |
| 54 | + |
66 | 55 | ## output should equal input |
67 | 56 | assert_array_equal(outs, y_seq) |
68 | 57 |
|
69 | | -#test_IFCell1() |
| 58 | +test_IFCell1() |
0 commit comments