Commit 40ac018

Author: Alexander Ororbia
Commit message: mod to pc-rao doc
1 parent d21e1a5 · commit 40ac018

1 file changed: docs/museum/pc_rao_ballard1999.md (+75 additions, −22 deletions)
@@ -6,10 +6,8 @@ model originally proposed in (Rao & Ballard, 1999) [1].
The model code for this exhibit can be found
[here](https://github.com/NACLab/ngc-museum/tree/main/exhibits/pc_recon).

## Setting Up Hierarchical Predictive Coding (HPC) with NGC-Learn

### The HPC Model for Reconstruction Tasks

To build an HPC model, you will first need to define all of the components inside of the model.
@@ -210,28 +208,81 @@ Finally, to enable learning, we will need to set up simple 2-term/factor Hebbian

#### Specifying the HPC Model's Process Dynamics:

The only remaining thing to do for the above model is to specify its core simulation functions
(known in NGC-Learn as `MethodProcess` mechanisms). For an HPC model, we want to make sure
we define how its full message-passing scheme is carried out as well as how learning (synaptic plasticity)
occurs. Ultimately, this will follow the (dynamic) expectation-maximization (E-M) scheme we have
discussed in other model exhibits, e.g., the [sparse coding and dictionary learning exhibit](sparse_coding.md).

The method-processes for inference (expectation) and adaptation (maximization) can be written out under
your model context as follows:

```python
reset_process = (MethodProcess(name="reset_process") ## reset-to-baseline
                 >> z3.reset
                 >> z2.reset
                 >> z1.reset
                 >> e2.reset
                 >> e1.reset
                 >> e0.reset
                 >> W3.reset
                 >> W2.reset
                 >> W1.reset
                 >> E3.reset
                 >> E2.reset
                 >> E1.reset)
advance_process = (MethodProcess(name="advance_process") ## E-step
                   >> E1.advance_state
                   >> E2.advance_state
                   >> E3.advance_state
                   >> z3.advance_state
                   >> z2.advance_state
                   >> z1.advance_state
                   >> W3.advance_state
                   >> W2.advance_state
                   >> W1.advance_state
                   >> e2.advance_state
                   >> e1.advance_state
                   >> e0.advance_state)
evolve_process = (MethodProcess(name="evolve_process") ## M-step
                  >> W1.evolve
                  >> W2.evolve
                  >> W3.evolve)
```
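
The `evolve_process` above schedules the simple 2-term/factor Hebbian adaptation mentioned earlier. As a rough numerical sketch of what a generic two-factor rule computes (the names `eta`, `pre`, and `err` are hypothetical and this is not NGC-Learn's actual update code), consider:

```python
import numpy as np

eta = 0.01  ## learning rate (hypothetical value)
rng = np.random.default_rng(0)
pre = rng.standard_normal((8, 1))  ## presynaptic (higher-layer) activity
err = rng.standard_normal((5, 1))  ## postsynaptic error-unit signal
W = np.zeros((5, 8))               ## forward synapses, shape (post, pre)

## two-factor Hebbian update: outer product of the post-synaptic error
## and the pre-synaptic activity, scaled by the learning rate
dW = err @ pre.T
W = W + eta * dW
```

The update is "2-term" in the sense that only the two locally available factors, `err` and `pre`, enter the product.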

Below we show a code-snippet depicting how the HPC model processes a stimulus input
(or batch of inputs) `obs` -- i.e., an observation -- in practice:

```python
######### Process #########

#### reset/set all neuronal components to their resting values / initial conditions
circuit.reset.run()

#### clamp the observation/signal obs to the lowest layer activation
e0.target.set(obs) ## e0 contains the place where our stimulus target goes

#### pin/tie feedback synapses to transpose of forward ones
E1.weights.set(jnp.transpose(W1.weights.value))
E2.weights.set(jnp.transpose(W2.weights.value))
E3.weights.set(jnp.transpose(W3.weights.value))

#### apply the dynamic E-M algorithm on the HPC model given obs
inputs = jnp.array(advance_process.pack_rows(T, t=lambda x: x, dt=dt))
stateManager.state, outputs = advance_process.scan(inputs) ## Perform several (T) E-steps
circuit.evolve.run(t=T, dt=1.) ## Perform M-step (scheduled synaptic updates)

#### extract some statistics for downstream analysis
obs_mu = e0.mu.value ## get reconstructed signal
L0 = e0.L.value ## calculate reconstruction loss
free_energy = e0.L.value + e1.L.value + e2.L.value ## F = Sum_l Sum_j [e^l_j]^2
```
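
The `free_energy` readout above is simply the sum of squared prediction errors across layers, $F = \sum_l \sum_j [e^l_j]^2$. A minimal standalone check of that identity (using toy error vectors, not values produced by the model) might look like:

```python
import numpy as np

## toy per-layer error vectors e^l (hypothetical values)
e0_err = np.array([1.0, 2.0])
e1_err = np.array([3.0])
e2_err = np.array([0.0, -1.0])

## per-layer losses L_l = Sum_j [e^l_j]^2
L0_val = float(np.sum(e0_err ** 2))
L1_val = float(np.sum(e1_err ** 2))
L2_val = float(np.sum(e2_err ** 2))

## total free energy F = Sum_l Sum_j [e^l_j]^2
F = L0_val + L1_val + L2_val  ## 5 + 9 + 1 = 15
```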

Note that we make use of NGC-Learn's backend state manager (`ngcsimlib.global_state.StateManager`) to
roll out the `T` E-steps carried out above efficiently (effectively using JAX's scan utilities;
see the NGC-Learn configuration documents, such as the one related to the
[global state](../tutorials/configuration/global_state.md), for more information).
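
For readers unfamiliar with scan-style roll-outs, a scan utility behaves, in spirit, like the following pure-Python reference (a sketch only; the toy `step` dynamics below are hypothetical and merely stand in for one E-step of the model):

```python
def scan(f, carry, xs):
    """Pure-Python reference for the semantics of a scan utility:
    thread a carry through f over xs, collecting per-step outputs."""
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys

T, dt = 5, 1.0

def step(z, t):
    ## toy leaky dynamics: pull the state z toward 1.0 each step
    z_next = z + dt * 0.5 * (1.0 - z)
    return z_next, z_next

final_z, trajectory = scan(step, 0.0, [dt * i for i in range(T)])
```

The compiled JAX version avoids the Python loop entirely, which is why packing all `T` step inputs up front (as `pack_rows` does above) pays off.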
<br>
<br>
<br>
@@ -240,17 +291,19 @@ Finally, to enable learning, we will need to set up simple 2-term/factor Hebbian
<!-- ----------------------------------------------------------------------------------------------------- -->

### Train the PC model for Reconstructing Image Patches

<img src="../images/museum/hgpc/patch_input.png" width="300" align="right"/>

<br>

In this scenario, the input image is not the full scene (or complete set of pixels that fully describe an image);
instead, the input is locally "patched", which means that it has been broken down into smaller $K \times K$
blocks/grids. This input patch-extraction scheme changes the information processing of the neuronal
units within the network, i.e., local features are now important. The original model(s) of Rao and Ballard's
1999 work <b>[1]</b> are also in a patched format, modeling how retinal processing units are localized in nature.
Setting up the input stimulus in this manner also results in models that acquire filters (or receptive fields)
similar to those acquired in convolutional neural networks (CNNs).
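
To make the patching concrete, here is one common way to break an image into non-overlapping $K \times K$ blocks (a generic NumPy sketch, not the exhibit's actual data pipeline; `K` and the image shape are illustrative):

```python
import numpy as np

K = 2
img = np.arange(16).reshape(4, 4)  ## toy 4x4 "image"

h, w = img.shape
## carve the image into (h//K) * (w//K) non-overlapping K x K patches:
## reshape splits rows/columns into (block, within-block) pairs, and
## swapaxes groups the two within-block axes together per patch
patches = (img.reshape(h // K, K, w // K, K)
              .swapaxes(1, 2)
              .reshape(-1, K, K))
```

Each row of `patches` can then be fed to the model as an independent stimulus, just as `obs` was above.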

<br>
