@@ -6,7 +6,285 @@ model originally proposed in (Rao & Ballard, 1999) [1].
66The model code for this exhibit can be found
77[ here] ( https://github.com/NACLab/ngc-museum/tree/main/exhibits/pc_recon ) .
88
9+
10+
11+
12+ ## Predictive Coding with NGC-Learn
13+ -----------------------------------------------------------------------------------------------------
14+
15+ ### PC model for Reconstruction Task
16+
17+ For building PC model you first need to define all the components inside the model.
18+ Then you should wire those components together with specific configuration depending
19+ on the task.
20+
21+ 1 . ** Create neural component**
22+ 2 . ** Create synaptic component**
23+ 3 . ** Wire components** – define how the components connect and interact with each others.
24+
25+
26+ -----------------------------------------------------------------------------------------------------
27+
28+ <!-- ################################################################################ -->
29+
30+ ### 1- Make Neural component:
31+
32+ <!-- ################################################################################ -->
33+
34+
35+ ** Responding Neurons**
36+ <br >
37+
38+ We want to build a hierarchical neural network we need neural layers. In predictive coding network with real-valued dynamics
39+ we use ` RateCell ` components ([ RateCell tutorial] ( https://ngc-learn.readthedocs.io/en/latest/tutorials/neurocog/rate_cell.html ) ).
40+ Here, we want 3-layer network (3-hidden layers) so we define 3 components, each with ` n_units ` size for hidden representatins.
41+
42+ ``` python
43+ z3 = RateCell(" z3" , n_units = h3_dim, tau_m = tau_m, act_fx = act_fx, prior = (prior_type, lmbda))
44+ z2 = RateCell(" z2" , n_units = h2_dim, tau_m = tau_m, act_fx = act_fx, prior = (prior_type, lmbda))
45+ z1 = RateCell(" z1" , n_units = h1_dim, tau_m = tau_m, act_fx = act_fx, prior = (prior_type, lmbda))
46+ ```
47+
48+
49+
50+
51+ <!-- ################################################################################ -->
52+
53+ <br >
54+ <br >
55+
56+ <img src =" ../images/museum/hgpc/GEC.png " width =" 120 " align =" right " />
57+
58+ ** Error Neurons**
59+ <br >
60+
61+
62+ For each activation layer we have a set of additional neurons with the same size to measure the prediction error for individual
63+ ` RateCell ` components. The error value will later be used to calculate the ** energy** for layers (including hiddens) and the whole model.
64+
65+
66+ ``` python
67+ e2 = GaussianErrorCell(" e2" , n_units = h2_dim) # # e2_size == z2_size
68+ e1 = GaussianErrorCell(" e1" , n_units = h1_dim) # # e1_size == z1_size
69+ e0 = GaussianErrorCell(" e0" , n_units = in_dim) # # e0_size == z0_size (x size)
70+ ```
71+
72+
73+ <br >
74+ <br >
75+
76+ <!-- ################################################################################ -->
77+
78+ ### 2- Make Synaptic component:
79+
80+ <!-- ################################################################################ -->
81+
82+ <br >
83+ <br >
84+
85+ <!-- <img src="images/GEC.png" width="120" align="right"/> -->
86+
87+ ** Forward Synapses**
88+ <br >
89+
90+ To connect layers to each others we create synapstic components. To send infromation in forward pass (from input into deeper layers with a bottom-up stream)
91+ we use ` ForwardSynapse ` components. Check out [ Brain's Information Flow] ( https://github.com/Faezehabibi/pc_tutorial/blob/main/information_flow.md#---information-flow-in-the-brain-- )
92+ for detailed explanation of information flow in brain modeling.
93+
94+
95+ ``` python
96+ E3 = ForwardSynapse(" E3" , shape = (h2_dim, h3_dim)) # # pre-layer size (x) => (h1) post-layer size
97+ E2 = ForwardSynapse(" E2" , shape = (h1_dim, h2_dim)) # # pre-layer size (h1) => (h2) post-layer size
98+ E1 = ForwardSynapse(" E1" , shape = (in_dim, h1_dim)) # # pre-layer size (h2) => (h3) post-layer size
99+ ```
100+
101+ <!-- ################################################################################ -->
102+
103+ <br >
104+ <br >
105+
106+ <!-- <img src="images/GEC.png" width="120" align="right"/> -->
107+
108+ ** Backward Synapses**
109+ <br >
110+
111+ For each ` ForwardSynapse ` components sending infromation upward (bottom-up stream) exist a ` BackwardSynapse ` component to reverse the information flow and
112+ send it back downward (top-down stream -- from top layer to bottom/input). If you are not convinced, check out [ Information Flow] ( https://github.com/Faezehabibi/pc_tutorial/blob/19b0692fa307f2b06676ca93b9b93ba3ba854766/information_flow.md ) .
113+
114+ ``` python
115+ W3 = BackwardSynapse(" W3" ,
116+ shape = (h3_dim, h2_dim), # # pre-layer size (h3) => (h2) post-layer size
117+ optim_type = opt_type, # # optimization method (sgd, adam, ...)
118+ weight_init = w3_init, # # W3[t0]: initial values before training at time[t0]
119+ w_bound = w_bound, # # -1 for deactivating the bouding synaptic value
120+ sign_value = - 1 ., # # -1 means M-step solve minimization problem
121+ eta = eta, # # learning-rate (lr)
122+ )
123+ W2 = BackwardSynapse(" W2" ,
124+ shape = (h2_dim, h1_dim), # # pre-layer size (h2) => (h1) post-layer size
125+ optim_type = opt_type, # # Optimizer
126+ weight_init = w2_init, # # W2[t0]
127+ w_bound = w_bound, # # -1: deactivate the bouding
128+ sign_value = - 1 ., # # Minimization
129+ eta = eta, # # lr
130+ )
131+ W1 = BackwardSynapse(" W1" ,
132+ shape = (h1_dim, in_dim), # # pre-layer size (h1) => (x) post-layer size
133+ optim_type = opt_type, # # Optimizer
134+ weight_init = w1_init, # # W1[t0]
135+ w_bound = w_bound, # # -1: deactivate the bouding
136+ sign_value = - 1 ., # # Minimization
137+ eta = eta, # # lr
138+ )
139+ ```
140+
141+
142+
143+
144+
145+
146+ <br >
147+ <br >
148+ <!-- ----------------------------------------------------------------------------------------------------- -->
149+
150+ ### Wire Component:
151+
152+
153+ The signal pathway is according to Rao & Ballard 1999.
154+ Error is information goes from buttom to up in the forward pass.
155+ Corrected prediction comes back from top to the down in the backward pass.
156+
157+
158+ ``` python
159+ # ######## feedback (Top-down) #########
160+ # ## actual neural activation
161+ e2.target << z2.z
162+ e1.target << z1.z
163+
164+ # ## Top-down prediction
165+ e2.mu << W3.outputs
166+ e1.mu << W2.outputs
167+ e0.mu << W1.outputs
168+
169+ # ## Top-down prediction errors
170+ z1.j_td << e1.dtarget
171+ z2.j_td << e2.dtarget
172+
173+ W3.inputs << z3.zF
174+ W2.inputs << z2.zF
175+ W1.inputs << z1.zF
176+ ```
177+
178+
179+ ``` python
180+ # ######## forward (Bottom-up) #########
181+ # # feedforward the errors via synapses
182+ E3.inputs << e2.dmu
183+ E2.inputs << e1.dmu
184+ E1.inputs << e0.dmu
185+
186+ # # Bottom-up modulated errors
187+ z3.j << E3.outputs
188+ z2.j << E2.outputs
189+ z1.j << E1.outputs
190+ ```
191+
192+
193+ ``` python
194+ # ####### Hebbian learning #########
195+ # # Pre Synaptic Activation
196+ W3.pre << z3.zF
197+ W2.pre << z2.zF
198+ W1.pre << z1.zF
199+
200+ # # Post Synaptic residual error
201+ W3.post << e2.dmu
202+ W2.post << e1.dmu
203+ W1.post << e0.dmu
204+ ```
205+
206+
207+
208+
209+ <br >
210+ <br >
211+ <!-- ----------------------------------------------------------------------------------------------------- -->
212+
213+ ##### Process Dynamics:
214+
215+
216+ ``` python
217+ # ######## Process #########
218+
219+ # ########## reset/set all components to their resting values / initial conditions
220+ circuit.reset()
221+
222+ circuit.clamp_input(obs) # # clamp the signal to the lowest layer activation
223+ z0.z.set(obs) # # or directly put obs in e0.target.set(obs)
224+
225+ # ########## pin/tie feedback synapses to transpose of forward ones
226+ E1.weights.set(jnp.transpose(W1.weights.value))
227+ E2.weights.set(jnp.transpose(W2.weights.value))
228+ E3.weights.set(jnp.transpose(W3.weights.value))
229+
230+ circuit.process(jnp.array([[dt * i, dt] for i in range (T)])) # # Perform several E-steps
231+
232+ circuit.evolve(t = T, dt = 1 .) # # Perform M-step (scheduled synaptic updates)
233+
234+ obs_mu = e0.mu.value # # get reconstructed signal
235+ L0 = e0.L.value # # calculate reconstruction loss
236+ ```
237+
238+
239+
240+
241+ <br >
242+ <br >
243+ <br >
244+ <br >
245+ <!-- ----------------------------------------------------------------------------------------------------- -->
246+ <!-- ----------------------------------------------------------------------------------------------------- -->
247+
248+
249+ ### Train PC model for reconstructing the patched image
250+
251+ <img src =" ../images/museum/hgpc/patch_input.png " width =" 300 " align =" right " />
252+
253+ <br >
254+
255+ This time, the input image is not the full scene while it is locally patched. This changes the processing
256+ units among the network where local features are now important. The original models in Rao & ballard 1999
257+ are also in patch format where similar to retina the processing units are localized. This also is results in
258+ similar filters or receptive fields as in convolutional neural networks (CNNs).
259+
260+
261+ <br >
262+
263+ ``` python
264+ for nb in range (n_batches):
265+ Xb = X[nb * images_per_batch: (nb + 1 ) * images_per_batch, :] # # shape: (mb_size, 784)
266+ Xb = generate_patch_set(Xb, patch_shape, center = True )
267+
268+ Xmu, Lb = model.process(Xb)
269+ ```
270+
271+
272+
273+
274+
275+
276+ <!-- -------------------------------------------------------------------------------------
277+ ### Train PC model for reconstructing the full image
278+
279+ ```python
280+ for nb in range(n_batches):
281+ Xb = X[nb * mb_size: (nb + 1) * mb_size, :] ## shape: (mb_size, 784)
282+ Xmu, Lb = model.process(Xb)
283+ ```
284+ ------------------------------------------------------------------------------------- -->
285+
286+
9287<!-- references -->
10288## References
11289<b >[ 1] </b > Rao, Rajesh PN, and Dana H. Ballard. "Predictive coding in the visual cortex: a functional interpretation of
12- some extra-classical receptive-field effects." Nature neuroscience 2.1 (1999): 79-87.
290+ some extra-classical receptive-field effects." Nature neuroscience 2.1 (1999): 79-87.
0 commit comments