Skip to content

Commit 043b0a8

Browse files
author
Alexander Ororbia
committed
revised adex and error-cell neurocog tutorials
1 parent c51b83c commit 043b0a8

File tree

2 files changed

+25
-48
lines changed

2 files changed

+25
-48
lines changed

docs/tutorials/neurocog/adex_cell.md

Lines changed: 9 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ for ts in range(T):
108108
x_t = data
109109
## pass in t and dt and run step forward of simulation
110110
clamp(x_t)
111-
advance_process.run(t=t, dt=dt) #
111+
advance_process.run(t=t, dt=dt) # run one step of dynamics
112112
t = t + dt
113113

114114
## naively extract simple statistics at time ts and print them to I/O
@@ -143,26 +143,27 @@ recov_rec = np.squeeze(np.asarray(recov_rec))
143143
spk_rec = np.squeeze(np.asarray(spk_rec))
144144

145145
# Plot the AdEx cell trajectory
146-
cell_tag = "RS"
147146
n_plots = 1
148147
fig, ax = plt.subplots(1, n_plots, figsize=(5*n_plots,5))
149148
ax_ptr = ax
150-
ax_ptr.set(xlabel='Time', ylabel='Voltage (v)',
151-
title="AdEx ({}) Voltage Dynamics".format(cell_tag))
149+
ax_ptr.set(
150+
xlabel='Time', ylabel='Voltage (v)', title="AdEx Voltage Dynamics"
151+
)
152152

153153
v = ax_ptr.plot(time_span, mem_rec, color='C0')
154154
ax_ptr.legend([v[0]],['v'])
155155
plt.tight_layout()
156-
plt.savefig("{0}".format("adex_v_plot.jpg".format(cell_tag.lower())))
156+
plt.savefig("{0}".format("adex_v_plot.jpg"))
157157

158158
fig, ax = plt.subplots(1, n_plots, figsize=(5*n_plots,5))
159159
ax_ptr = ax
160-
ax_ptr.set(xlabel='Time', ylabel='Recovery (w)',
161-
title="AdEx ({}) Recovery Dynamics".format(cell_tag))
160+
ax_ptr.set(
161+
xlabel='Time', ylabel='Recovery (w)', title="AdEx Recovery Dynamics"
162+
)
162163
w = ax_ptr.plot(time_span, recov_rec, color='C1', alpha=.5)
163164
ax_ptr.legend([w[0]],['w'])
164165
plt.tight_layout()
165-
plt.savefig("{0}".format("adex_w_plot.jpg".format(cell_tag.lower())))
166+
plt.savefig("{0}".format("adex_w_plot.jpg"))
166167
plt.close()
167168
```
168169

@@ -187,27 +188,6 @@ however, one could configure it to use the midpoint method for integration
187188
by setting its argument `integration_type = rk2` in cases where more
188189
accuracy in the dynamics is needed (at the cost of additional computational time).
189190

190-
## Optional: Setting Up The Components with a JSON Configuration
191-
192-
While you are not required to create a JSON configuration file for ngc-learn,
193-
to get rid of the warning that ngc-learn will throw at the start of your
194-
program's execution (indicating that you do not have a configuration set up yet),
195-
all you need to do is create a sub-directory for your JSON configuration
196-
inside of your project code's directory, i.e., `json_files/modules.json`.
197-
Inside the JSON file, you would write the following:
198-
199-
```json
200-
[
201-
{"absolute_path": "ngclearn.components",
202-
"attributes": [
203-
{"name": "AdExCell"}]
204-
},
205-
{"absolute_path": "ngcsimlib.operations",
206-
"attributes": [
207-
{"name": "overwrite"}]
208-
}
209-
]
210-
```
211191

212192
## References
213193

docs/tutorials/neurocog/error_cell.md

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ The code you would write amounts to the below:
6060

6161
```python
6262
from jax import numpy as jnp, jit
63-
from ngcsimlib.context import Context
64-
from ngclearn.utils import JaxProcess
63+
64+
from ngclearn import Context, MethodProcess
6565
## import model-specific mechanisms
6666
from ngclearn.components.neurons.graded.gaussianErrorCell import GaussianErrorCell
6767

@@ -71,32 +71,29 @@ T = 5 ## number time steps to simulate
7171
with Context("Model") as model:
7272
cell = GaussianErrorCell("z0", n_units=3)
7373

74-
advance_process = (JaxProcess()
74+
advance_process = (MethodProcess("advance_proc")
7575
>> cell.advance_state)
76-
model.wrap_and_add_command(jit(advance_process.pure), name="advance")
77-
78-
reset_process = (JaxProcess()
76+
reset_process = (MethodProcess("reset_proc")
7977
>> cell.reset)
80-
model.wrap_and_add_command(jit(reset_process.pure), name="reset")
81-
8278

83-
@Context.dynamicCommand
84-
def clamp(x, y):
85-
## error cells have two key input compartments; a "mu" and a "target"
86-
cell.mu.set(x)
87-
cell.target.set(y)
79+
## set up non-compiled utility commands
80+
def clamp(x, y):
81+
## error cells have two key input compartments; a "mu" and a "target"
82+
cell.mu.set(x)
83+
cell.target.set(y)
84+
8885

8986
guess = jnp.asarray([[-1., 1., 1.]], jnp.float32) ## the produced guess or prediction
9087
answer = jnp.asarray([[1., -1., 1.]], jnp.float32) ## what we wish the guess had been
9188

92-
model.reset()
89+
reset_process.run()
9390
for ts in range(T):
94-
model.clamp(guess, answer)
95-
model.advance(t=ts * 1., dt=dt)
91+
clamp(guess, answer)
92+
advance_process.run(t=ts * 1., dt=dt)
9693
## extract compartment values of interest
97-
dmu = cell.dmu.value
98-
dtarget = cell.dtarget.value
99-
loss = cell.L.value
94+
dmu = cell.dmu.get()
95+
dtarget = cell.dtarget.get()
96+
loss = cell.L.get()
10097
## print compartment values to I/O
10198
print("{} | dmu: {} dtarget: {} loss: {} ".format(ts, dmu, dtarget, loss))
10299
```

0 commit comments

Comments
 (0)