Skip to content

Commit a95b334

Browse files
Fix Minibatch alignment in Bayesian Neural Network example
1 parent 498eef5 commit a95b334

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

examples/variational_inference/bayesian_neural_network_advi.ipynb

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@
190190
},
191191
"outputs": [],
192192
"source": [
193-
"def construct_nn(ann_input, ann_output):\n",
193+
"def construct_nn():\n",
194194
" n_hidden = 5\n",
195195
"\n",
196196
" # Initialize random weights between each layer\n",
@@ -204,9 +204,14 @@
204204
" \"train_cols\": np.arange(X_train.shape[1]),\n",
205205
" \"obs_id\": np.arange(X_train.shape[0]),\n",
206206
" }\n",
207+
" \n",
207208
" with pm.Model(coords=coords) as neural_network:\n",
208-
" ann_input = pm.Data(\"ann_input\", X_train, dims=(\"obs_id\", \"train_cols\"))\n",
209-
" ann_output = pm.Data(\"ann_output\", Y_train, dims=\"obs_id\")\n",
209+
" # Define minibatch variables\n",
210+
" minibatch_x, minibatch_y = pm.Minibatch(X_train, Y_train, batch_size=50)\n",
211+
" \n",
212+
" # Define data variables using minibatches\n",
213+
" ann_input = pm.Data(\"ann_input\", minibatch_x, mutable=True, dims=(\"obs_id\", \"train_cols\"))\n",
214+
" ann_output = pm.Data(\"ann_output\", minibatch_y, mutable=True, dims=\"obs_id\")\n",
210215
"\n",
211216
" # Weights from input to hidden layer\n",
212217
" weights_in_1 = pm.Normal(\n",
@@ -231,13 +236,13 @@
231236
" \"out\",\n",
232237
" act_out,\n",
233238
" observed=ann_output,\n",
234-
" total_size=Y_train.shape[0], # IMPORTANT for minibatches\n",
239+
" total_size=X_train.shape[0], # IMPORTANT for minibatches\n",
235240
" dims=\"obs_id\",\n",
236241
" )\n",
237242
" return neural_network\n",
238243
"\n",
239-
"\n",
240-
"neural_network = construct_nn(X_train, Y_train)"
244+
"# Create the neural network model\n",
245+
"neural_network = construct_nn()\n"
241246
]
242247
},
243248
{

0 commit comments

Comments
 (0)