@@ -6,10 +6,8 @@ model originally proposed in (Rao & Ballard, 1999) [1].
 The model code for this exhibit can be found
 [here](https://github.com/NACLab/ngc-museum/tree/main/exhibits/pc_recon).
 
-
 ## Setting Up Hierarchical Predictive Coding (HPC) with NGC-Learn
 
-
 ### The HPC Model for Reconstruction Tasks
 
 To build an HPC model, you will first need to define all of the components inside of the model.
@@ -210,28 +208,81 @@ Finally, to enable learning, we will need to set up simple 2-term/factor Hebbian
 
 #### Specifying the HPC Model's Process Dynamics:
 
+The only remaining thing to do for the above model is to specify its core simulation functions
+(known in NGC-Learn as `MethodProcess` mechanisms). For an HPC model, we want to make sure
+we define how its full message-passing is carried out as well as how learning (synaptic plasticity)
+occurs. Ultimately, this will follow the (dynamic) expectation-maximization (E-M) scheme we have
+discussed in other model exhibits, e.g., the [sparse coding and dictionary learning exhibit](sparse_coding.md).
+
+The method-processes for inference (expectation) and adaptation (maximization) can be written out under
+your model context as follows:
+
+```python
+reset_process = (MethodProcess(name="reset_process") ## reset-to-baseline
+                 >> z3.reset
+                 >> z2.reset
+                 >> z1.reset
+                 >> e2.reset
+                 >> e1.reset
+                 >> e0.reset
+                 >> W3.reset
+                 >> W2.reset
+                 >> W1.reset
+                 >> E3.reset
+                 >> E2.reset
+                 >> E1.reset)
+advance_process = (MethodProcess(name="advance_process") ## E-step
+                   >> E1.advance_state
+                   >> E2.advance_state
+                   >> E3.advance_state
+                   >> z3.advance_state
+                   >> z2.advance_state
+                   >> z1.advance_state
+                   >> W3.advance_state
+                   >> W2.advance_state
+                   >> W1.advance_state
+                   >> e2.advance_state
+                   >> e1.advance_state
+                   >> e0.advance_state)
+evolve_process = (MethodProcess(name="evolve_process") ## M-step
+                  >> W1.evolve
+                  >> W2.evolve
+                  >> W3.evolve)
+```
+
+Below we show a code-snippet depicting how the HPC model's processing of a stimulus input
+(or batch of inputs) `obs` -- or observation -- is carried out in practice:
 
 ```python
-######### Process #########
+######### Process #########
 
-########### reset/set all components to their resting values / initial conditions
-circuit.reset()
+#### reset/set all neuronal components to their resting values / initial conditions
+circuit.reset.run()
 
-circuit.clamp_input(obs) ## clamp the signal to the lowest layer activation
-z0.z.set(obs) ## or directly put obs in e0.target.set(obs)
+#### clamp the observation/signal obs to the lowest layer activation
+e0.target.set(obs) ## e0 contains the place where our stimulus target goes
 
-########### pin/tie feedback synapses to transpose of forward ones
-E1.weights.set(jnp.transpose(W1.weights.value))
-E2.weights.set(jnp.transpose(W2.weights.value))
-E3.weights.set(jnp.transpose(W3.weights.value))
-
-circuit.process(jnp.array([[dt * i, dt] for i in range(T)])) ## Perform several E-steps
-circuit.evolve(t=T, dt=1.) ## Perform M-step (scheduled synaptic updates)
+#### pin/tie feedback synapses to transpose of forward ones
+E1.weights.set(jnp.transpose(W1.weights.value))
+E2.weights.set(jnp.transpose(W2.weights.value))
+E3.weights.set(jnp.transpose(W3.weights.value))
+
+#### apply the dynamic E-M algorithm on the HPC model given obs
+inputs = jnp.array(self.advance_proc.pack_rows(T, t=lambda x: x, dt=dt))
+stateManager.state, outputs = self.process.scan(inputs) ## Perform several (T) E-steps
+circuit.evolve.run(t=T, dt=1.) ## Perform M-step (scheduled synaptic updates)
 
-obs_mu = e0.mu.value ## get reconstructed signal
-L0 = e0.L.value ## calculate reconstruction loss
+#### extract some statistics for downstream analysis
+obs_mu = e0.mu.value ## get reconstructed signal
+L0 = e0.L.value ## calculate reconstruction loss
+free_energy = e0.L.value + e1.L.value + e2.L.value ## F = Sum_l Sum_j [e^l_j]^2
 ```
 
+Note that we make use of NGC-Learn's backend state-manager (`ngcsimlib.global_state.StateManager`) to
+roll out the `T` E-steps carried out above efficiently (effectively using JAX's scan utilities;
+see the NGC-Learn configuration documents, such as the one related to the
+[global state](../tutorials/configuration/global_state.md), for more information).
+
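As a rough illustration of what this scan-style roll-out does under the hood, here is a minimal sketch using plain `jax.lax.scan` with a toy one-variable leaky update (the dynamics and shapes here are hypothetical stand-ins, not NGC-Learn's actual state-manager API):

```python
import jax
import jax.numpy as jnp

# Pack one (t, dt) row per E-step, mirroring the `inputs` array built above
T, dt = 5, 1.0
inputs = jnp.array([[dt * i, dt] for i in range(T)])

def e_step(state, t_dt):
    # Stand-in for one advance_process step: leaky drift toward a target of 1
    t, step_size = t_dt
    new_state = state + step_size * (1.0 - state) * 0.1
    return new_state, new_state  # (carried state, per-step output)

final_state, trajectory = jax.lax.scan(e_step, jnp.zeros(3), inputs)
```

Each row of `inputs` supplies the `(t, dt)` pair for one E-step, and the carried `state` plays the role that the compiled neuronal state plays in the snippet above; `scan` avoids an unrolled Python loop over the `T` steps.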
 <br>
 <br>
 <br>
@@ -240,17 +291,19 @@ Finally, to enable learning, we will need to set up simple 2-term/factor Hebbian
 <!-- ----------------------------------------------------------------------------------------------------- -->
 
 
-### Train the PC model for Reconstructing the "Patched" Image
+### Train the PC model for Reconstructing Image Patches
 
 <img src="../images/museum/hgpc/patch_input.png" width="300" align="right" />
 
 <br>
 
-This time, the input image is not the full scene while it is locally patched. This changes the processing
-units among the network where local features are now important. The original models in Rao & ballard 1999
-are also in patch format where similar to retina the processing units are localized. This also is results in
-similar filters or receptive fields as in convolutional neural networks (CNNs).
-
+In this scenario, the input image is not the full scene (or complete set of pixels that fully describe an image);
+instead, the input is locally "patched", which means that it has been broken down into smaller $K \times K$
+blocks/grids. This input patch extraction scheme changes the information processing of the neuronal
+units within the network, i.e., local features are now important. The original model(s) of Rao and Ballard's
+1999 work <b>[1]</b> are also in a patched format, modeling how retinal processing units are localized in nature.
+Setting up the input stimulus in this manner also results in models that acquire filters (or receptive fields)
+similar to those acquired in convolutional neural networks (CNNs).
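To make the patching scheme concrete, here is a minimal sketch (a hypothetical helper, not part of the exhibit code) that chops an image into non-overlapping $K \times K$ blocks using standard `jax.numpy` reshapes:

```python
import jax.numpy as jnp

def extract_patches(image, K):
    """Split an (H, W) image into non-overlapping K x K patches.

    Returns an array of shape (num_patches, K * K), one flattened
    patch per row (assumes H and W are divisible by K).
    """
    H, W = image.shape
    patches = image.reshape(H // K, K, W // K, K)  # (row-blocks, K, col-blocks, K)
    patches = patches.transpose(0, 2, 1, 3)        # (row-blocks, col-blocks, K, K)
    return patches.reshape(-1, K * K)              # flatten each block into a row

image = jnp.arange(16.0).reshape(4, 4)  # toy 4 x 4 "image"
patches = extract_patches(image, 2)     # four 2 x 2 patches, each flattened
```

Each row of the result can then stand in for `obs` in the processing loop above, so the model only ever sees small local windows of the scene.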
 
 <br>
 