Skip to content

Commit 48503dc

Browse files
committed
Update notebook to work with new version of capsa
1 parent 3d7b511 commit 48503dc

File tree

1 file changed

+33
-14
lines changed

1 file changed

+33
-14
lines changed

lab3/Part2_BiasAndUncertainty.ipynb

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
]
4141
},
4242
{
43+
"attachments": {},
4344
"cell_type": "markdown",
4445
"metadata": {
4546
"id": "IgYKebt871EK"
@@ -59,6 +60,7 @@
5960
]
6061
},
6162
{
63+
"attachments": {},
6264
"cell_type": "markdown",
6365
"metadata": {
6466
"id": "6JTRoM7E71EU"
@@ -95,6 +97,7 @@
9597
]
9698
},
9799
{
100+
"attachments": {},
98101
"cell_type": "markdown",
99102
"metadata": {
100103
"id": "6VKVqLb371EV"
@@ -130,6 +133,7 @@
130133
]
131134
},
132135
{
136+
"attachments": {},
133137
"cell_type": "markdown",
134138
"metadata": {
135139
"id": "cREmhMWJ71EX"
@@ -143,6 +147,7 @@
143147
]
144148
},
145149
{
150+
"attachments": {},
146151
"cell_type": "markdown",
147152
"metadata": {
148153
"id": "1NhotGiT71EY"
@@ -199,6 +204,7 @@
199204
]
200205
},
201206
{
207+
"attachments": {},
202208
"cell_type": "markdown",
203209
"metadata": {
204210
"id": "LgTG6buf71Ea"
@@ -256,6 +262,7 @@
256262
]
257263
},
258264
{
265+
"attachments": {},
259266
"cell_type": "markdown",
260267
"metadata": {
261268
"id": "SzFGcrhv71Ed"
@@ -304,10 +311,12 @@
304311
"test_imgs = test_loader.get_all_faces()\n",
305312
"\n",
306313
"# Call the Capsa-wrapped classifier to generate outputs: predictions, uncertainty, and bias!\n",
307-
"predictions, uncertainty, bias = wrapped_model.predict(test_imgs, batch_size=512)"
314+
"#predictions, uncertainty, bias = wrapped_model.predict(test_imgs, batch_size=512)\n",
315+
"out = wrapped_model.predict(test_imgs, batch_size=512)\n"
308316
]
309317
},
310318
{
319+
"attachments": {},
311320
"cell_type": "markdown",
312321
"metadata": {
313322
"id": "629ng-_H6WOk"
@@ -329,10 +338,10 @@
329338
"### Analyzing representation bias scores ###\n",
330339
"\n",
331340
"# Sort according to lowest to highest representation scores\n",
332-
"indices = np.argsort(bias, axis=None) # sort the score values themselves\n",
341+
"indices = np.argsort(out.bias, axis=None) # sort the score values themselves\n",
333342
"sorted_images = test_imgs[indices] # sort images from lowest to highest representations\n",
334-
"sorted_biases = bias[indices] # order the representation bias scores\n",
335-
"sorted_preds = predictions[indices] # order the prediction values\n",
343+
"sorted_biases = out.bias.numpy()[indices] # order the representation bias scores\n",
344+
"sorted_preds = out.y_hat.numpy()[indices] # order the prediction values\n",
336345
"\n",
337346
"\n",
338347
"# Visualize the 20 images with the lowest and highest representation in the test dataset\n",
@@ -345,6 +354,7 @@
345354
]
346355
},
347356
{
357+
"attachments": {},
348358
"cell_type": "markdown",
349359
"metadata": {
350360
"id": "-JYmGMJF71Ef"
@@ -368,6 +378,7 @@
368378
]
369379
},
370380
{
381+
"attachments": {},
371382
"cell_type": "markdown",
372383
"metadata": {
373384
"id": "i8ERzg2-71Ef"
@@ -389,6 +400,7 @@
389400
]
390401
},
391402
{
403+
"attachments": {},
392404
"cell_type": "markdown",
393405
"metadata": {
394406
"id": "cRNV-3SU71Eg"
@@ -404,6 +416,7 @@
404416
]
405417
},
406418
{
419+
"attachments": {},
407420
"cell_type": "markdown",
408421
"metadata": {
409422
"id": "ww5lx7ue71Eg"
@@ -420,6 +433,7 @@
420433
]
421434
},
422435
{
436+
"attachments": {},
423437
"cell_type": "markdown",
424438
"metadata": {
425439
"id": "NEfeWo2p7wKm"
@@ -442,10 +456,10 @@
442456
"### Analyzing epistemic uncertainty estimates ###\n",
443457
"\n",
444458
"# Sort according to epistemic uncertainty estimates\n",
445-
"epistemic_indices = np.argsort(uncertainty, axis=None) # sort the uncertainty values\n",
459+
"epistemic_indices = np.argsort(out.epistemic, axis=None) # sort the uncertainty values\n",
446460
"epistemic_images = test_imgs[epistemic_indices] # sort images from lowest to highest uncertainty\n",
447-
"sorted_epistemic = uncertainty[epistemic_indices] # order the uncertainty scores\n",
448-
"sorted_epistemic_preds = predictions[epistemic_indices] # order the prediction values\n",
461+
"sorted_epistemic = out.epistemic.numpy()[epistemic_indices] # order the uncertainty scores\n",
462+
"sorted_epistemic_preds = out.y_hat.numpy()[epistemic_indices] # order the prediction values\n",
449463
"\n",
450464
"\n",
451465
"# Visualize the 20 images with the LEAST and MOST epistemic uncertainty\n",
@@ -458,6 +472,7 @@
458472
]
459473
},
460474
{
475+
"attachments": {},
461476
"cell_type": "markdown",
462477
"metadata": {
463478
"id": "L0dA8EyX71Eh"
@@ -481,6 +496,7 @@
481496
]
482497
},
483498
{
499+
"attachments": {},
484500
"cell_type": "markdown",
485501
"metadata": {
486502
"id": "iyn0IE6x71Eh"
@@ -496,6 +512,7 @@
496512
]
497513
},
498514
{
515+
"attachments": {},
499516
"cell_type": "markdown",
500517
"metadata": {
501518
"id": "XbwRbesM71Eh"
@@ -561,11 +578,11 @@
561578
"\n",
562579
" # After the epoch is done, recompute data sampling proabilities \n",
563580
" # according to the inverse of the bias\n",
564-
" pred, unc, bias = wrapper(train_imgs)\n",
581+
" out = wrapper(train_imgs)\n",
565582
"\n",
566583
" # Increase the probability of sampling under-represented datapoints by setting \n",
567584
" # the probability to the **inverse** of the biases\n",
568-
" inverse_bias = 1.0 / (bias.numpy() + 1e-7)\n",
585+
" inverse_bias = 1.0 / (np.mean(out.bias.numpy(),axis=-1) + 1e-7)\n",
569586
"\n",
570587
" # Normalize the inverse biases in order to convert them to probabilities\n",
571588
" p_faces = inverse_bias / np.sum(inverse_bias)\n",
@@ -575,6 +592,7 @@
575592
]
576593
},
577594
{
595+
"attachments": {},
578596
"cell_type": "markdown",
579597
"metadata": {
580598
"id": "SwXrAeBo71Ej"
@@ -598,13 +616,13 @@
598616
"### Evaluation of debiased model ###\n",
599617
"\n",
600618
"# Get classification predictions, uncertainties, and representation bias scores\n",
601-
"pred, unc, bias = wrapper.predict(test_imgs)\n",
619+
"out = wrapper.predict(test_imgs)\n",
602620
"\n",
603621
"# Sort according to lowest to highest representation scores\n",
604-
"indices = np.argsort(bias, axis=None)\n",
622+
"indices = np.argsort(out.bias, axis=None)\n",
605623
"bias_images = test_imgs[indices] # sort the images\n",
606-
"sorted_bias = bias[indices] # sort the representation bias scores\n",
607-
"sorted_bias_preds = pred[indices] # sort the predictions\n",
624+
"sorted_bias = out.bias.numpy()[indices] # sort the representation bias scores\n",
625+
"sorted_bias_preds = out.y_hat.numpy()[indices] # sort the predictions\n",
608626
"\n",
609627
"# Plot the representation bias vs. the accuracy\n",
610628
"plt.xlabel(\"Density (Representation)\")\n",
@@ -613,6 +631,7 @@
613631
]
614632
},
615633
{
634+
"attachments": {},
616635
"cell_type": "markdown",
617636
"metadata": {
618637
"id": "d1cEEnII71Ej"
@@ -681,7 +700,7 @@
681700
"name": "python",
682701
"nbconvert_exporter": "python",
683702
"pygments_lexer": "ipython3",
684-
"version": "3.9.6 (default, Oct 18 2022, 12:41:40) \n[Clang 14.0.0 (clang-1400.0.29.202)]"
703+
"version": "3.9.16"
685704
},
686705
"vscode": {
687706
"interpreter": {

0 commit comments

Comments
 (0)