40 | 40 | ]
41 | 41 | },
42 | 42 | {
| 43 | + "attachments": {},
43 | 44 | "cell_type": "markdown",
44 | 45 | "metadata": {
45 | 46 | "id": "IgYKebt871EK"

59 | 60 | ]
60 | 61 | },
61 | 62 | {
| 63 | + "attachments": {},
62 | 64 | "cell_type": "markdown",
63 | 65 | "metadata": {
64 | 66 | "id": "6JTRoM7E71EU"

95 | 97 | ]
96 | 98 | },
97 | 99 | {
| 100 | + "attachments": {},
98 | 101 | "cell_type": "markdown",
99 | 102 | "metadata": {
100 | 103 | "id": "6VKVqLb371EV"

130 | 133 | ]
131 | 134 | },
132 | 135 | {
| 136 | + "attachments": {},
133 | 137 | "cell_type": "markdown",
134 | 138 | "metadata": {
135 | 139 | "id": "cREmhMWJ71EX"

143 | 147 | ]
144 | 148 | },
145 | 149 | {
| 150 | + "attachments": {},
146 | 151 | "cell_type": "markdown",
147 | 152 | "metadata": {
148 | 153 | "id": "1NhotGiT71EY"

199 | 204 | ]
200 | 205 | },
201 | 206 | {
| 207 | + "attachments": {},
202 | 208 | "cell_type": "markdown",
203 | 209 | "metadata": {
204 | 210 | "id": "LgTG6buf71Ea"

256 | 262 | ]
257 | 263 | },
258 | 264 | {
| 265 | + "attachments": {},
259 | 266 | "cell_type": "markdown",
260 | 267 | "metadata": {
261 | 268 | "id": "SzFGcrhv71Ed"

304 | 311 | "test_imgs = test_loader.get_all_faces()\n",
305 | 312 | "\n",
306 | 313 | "# Call the Capsa-wrapped classifier to generate outputs: predictions, uncertainty, and bias!\n",
307 | | - "predictions, uncertainty, bias = wrapped_model.predict(test_imgs, batch_size=512)"
| 314 | + "#predictions, uncertainty, bias = wrapped_model.predict(test_imgs, batch_size=512)\n",
| 315 | + "out = wrapped_model.predict(test_imgs, batch_size=512)\n"
308 | 316 | ]
309 | 317 | },
310 | 318 | {
| 319 | + "attachments": {},
311 | 320 | "cell_type": "markdown",
312 | 321 | "metadata": {
313 | 322 | "id": "629ng-_H6WOk"
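
The change above replaces the three-way tuple unpacking with a single returned object; later hunks in this diff read its fields as out.y_hat, out.bias, and out.epistemic. Below is a minimal sketch of that consumption pattern, using a hypothetical namedtuple stand-in for the wrapped model's output (the real object comes from the Capsa-wrapped classifier, which is not reproduced here):

```python
from collections import namedtuple
import numpy as np

# Hypothetical stand-in for the object returned by wrapped_model.predict();
# the field names mirror the attributes used later in this diff.
RiskOutput = namedtuple("RiskOutput", ["y_hat", "bias", "epistemic"])

# Dummy per-image scores in place of real model outputs
out = RiskOutput(
    y_hat=np.random.rand(8, 1),      # classifier predictions
    bias=np.random.rand(8, 1),       # representation (density) scores
    epistemic=np.random.rand(8, 1),  # epistemic uncertainty estimates
)

# Downstream cells index named fields instead of unpacking a 3-tuple
predictions, bias, uncertainty = out.y_hat, out.bias, out.epistemic
```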

329 | 338 | "### Analyzing representation bias scores ###\n",
330 | 339 | "\n",
331 | 340 | "# Sort according to lowest to highest representation scores\n",
332 | | - "indices = np.argsort(bias, axis=None) # sort the score values themselves\n",
| 341 | + "indices = np.argsort(out.bias, axis=None) # sort the score values themselves\n",
333 | 342 | "sorted_images = test_imgs[indices] # sort images from lowest to highest representations\n",
334 | | - "sorted_biases = bias[indices] # order the representation bias scores\n",
335 | | - "sorted_preds = predictions[indices] # order the prediction values\n",
| 343 | + "sorted_biases = out.bias.numpy()[indices] # order the representation bias scores\n",
| 344 | + "sorted_preds = out.y_hat.numpy()[indices] # order the prediction values\n",
336 | 345 | "\n",
337 | 346 | "\n",
338 | 347 | "# Visualize the 20 images with the lowest and highest representation in the test dataset\n",
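
The updated cell sorts the test faces from least to most represented by flattening the bias scores with np.argsort(..., axis=None) and reusing the resulting indices on the images and predictions. A self-contained sketch of that sort-and-slice pattern, with dummy arrays standing in for test_imgs and out.bias.numpy() (the 20-image split mirrors the comment in the cell):

```python
import numpy as np

# Dummy stand-ins: 100 "images" with one representation score each
images = np.random.rand(100, 64, 64, 3)
bias_scores = np.random.rand(100, 1)

indices = np.argsort(bias_scores, axis=None)    # flatten, then sort ascending
sorted_images = images[indices]                 # least- to most-represented faces
sorted_scores = bias_scores.flatten()[indices]  # scores in the same order

lowest_20, highest_20 = sorted_images[:20], sorted_images[-20:]
```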

345 | 354 | ]
346 | 355 | },
347 | 356 | {
| 357 | + "attachments": {},
348 | 358 | "cell_type": "markdown",
349 | 359 | "metadata": {
350 | 360 | "id": "-JYmGMJF71Ef"

368 | 378 | ]
369 | 379 | },
370 | 380 | {
| 381 | + "attachments": {},
371 | 382 | "cell_type": "markdown",
372 | 383 | "metadata": {
373 | 384 | "id": "i8ERzg2-71Ef"

389 | 400 | ]
390 | 401 | },
391 | 402 | {
| 403 | + "attachments": {},
392 | 404 | "cell_type": "markdown",
393 | 405 | "metadata": {
394 | 406 | "id": "cRNV-3SU71Eg"

404 | 416 | ]
405 | 417 | },
406 | 418 | {
| 419 | + "attachments": {},
407 | 420 | "cell_type": "markdown",
408 | 421 | "metadata": {
409 | 422 | "id": "ww5lx7ue71Eg"

420 | 433 | ]
421 | 434 | },
422 | 435 | {
| 436 | + "attachments": {},
423 | 437 | "cell_type": "markdown",
424 | 438 | "metadata": {
425 | 439 | "id": "NEfeWo2p7wKm"

442 | 456 | "### Analyzing epistemic uncertainty estimates ###\n",
443 | 457 | "\n",
444 | 458 | "# Sort according to epistemic uncertainty estimates\n",
445 | | - "epistemic_indices = np.argsort(uncertainty, axis=None) # sort the uncertainty values\n",
| 459 | + "epistemic_indices = np.argsort(out.epistemic, axis=None) # sort the uncertainty values\n",
446 | 460 | "epistemic_images = test_imgs[epistemic_indices] # sort images from lowest to highest uncertainty\n",
447 | | - "sorted_epistemic = uncertainty[epistemic_indices] # order the uncertainty scores\n",
448 | | - "sorted_epistemic_preds = predictions[epistemic_indices] # order the prediction values\n",
| 461 | + "sorted_epistemic = out.epistemic.numpy()[epistemic_indices] # order the uncertainty scores\n",
| 462 | + "sorted_epistemic_preds = out.y_hat.numpy()[epistemic_indices] # order the prediction values\n",
449 | 463 | "\n",
450 | 464 | "\n",
451 | 465 | "# Visualize the 20 images with the LEAST and MOST epistemic uncertainty\n",
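
This hunk applies the same argsort pattern to out.epistemic rather than out.bias. As an illustrative extension (not part of the notebook), the uncertainty estimates can also be thresholded by percentile to flag predictions worth reviewing by hand; the arrays below stand in for out.epistemic.numpy() and out.y_hat.numpy():

```python
import numpy as np

uncertainty = np.random.rand(100)  # stand-in for out.epistemic.numpy().flatten()
predictions = np.random.rand(100)  # stand-in for out.y_hat.numpy().flatten()

threshold = np.percentile(uncertainty, 90)       # top 10% most uncertain
flagged = np.where(uncertainty >= threshold)[0]  # indices worth inspecting by hand
print(f"{flagged.size} flagged predictions, mean score {predictions[flagged].mean():.3f}")
```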

458 | 472 | ]
459 | 473 | },
460 | 474 | {
| 475 | + "attachments": {},
461 | 476 | "cell_type": "markdown",
462 | 477 | "metadata": {
463 | 478 | "id": "L0dA8EyX71Eh"

481 | 496 | ]
482 | 497 | },
483 | 498 | {
| 499 | + "attachments": {},
484 | 500 | "cell_type": "markdown",
485 | 501 | "metadata": {
486 | 502 | "id": "iyn0IE6x71Eh"

496 | 512 | ]
497 | 513 | },
498 | 514 | {
| 515 | + "attachments": {},
499 | 516 | "cell_type": "markdown",
500 | 517 | "metadata": {
501 | 518 | "id": "XbwRbesM71Eh"

561 | 578 | "\n",
562 | 579 | " # After the epoch is done, recompute data sampling probabilities \n",
563 | 580 | " # according to the inverse of the bias\n",
564 | | - " pred, unc, bias = wrapper(train_imgs)\n",
| 581 | + " out = wrapper(train_imgs)\n",
565 | 582 | "\n",
566 | 583 | " # Increase the probability of sampling under-represented datapoints by setting \n",
567 | 584 | " # the probability to the **inverse** of the biases\n",
568 | | - " inverse_bias = 1.0 / (bias.numpy() + 1e-7)\n",
| 585 | + " inverse_bias = 1.0 / (np.mean(out.bias.numpy(),axis=-1) + 1e-7)\n",
569 | 586 | "\n",
570 | 587 | " # Normalize the inverse biases in order to convert them to probabilities\n",
571 | 588 | " p_faces = inverse_bias / np.sum(inverse_bias)\n",
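
The updated line first averages the per-feature bias scores (np.mean(..., axis=-1)) so each training image gets a single weight, then inverts and normalizes the weights into sampling probabilities. A self-contained sketch of how such probabilities could drive the next epoch's sampling; the np.random.choice step is an assumption about the surrounding training loop, which this hunk does not show:

```python
import numpy as np

# Stand-in for np.mean(out.bias.numpy(), axis=-1): one representation score per training image
bias_scores = np.random.rand(1000)

inverse_bias = 1.0 / (bias_scores + 1e-7)      # under-represented faces get larger weights
p_faces = inverse_bias / np.sum(inverse_bias)  # normalize into a probability distribution

# Assumed usage: draw the next batch, favoring under-represented faces
batch_idx = np.random.choice(len(p_faces), size=32, p=p_faces)
```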

575 | 592 | ]
576 | 593 | },
577 | 594 | {
| 595 | + "attachments": {},
578 | 596 | "cell_type": "markdown",
579 | 597 | "metadata": {
580 | 598 | "id": "SwXrAeBo71Ej"

598 | 616 | "### Evaluation of debiased model ###\n",
599 | 617 | "\n",
600 | 618 | "# Get classification predictions, uncertainties, and representation bias scores\n",
601 | | - "pred, unc, bias = wrapper.predict(test_imgs)\n",
| 619 | + "out = wrapper.predict(test_imgs)\n",
602 | 620 | "\n",
603 | 621 | "# Sort according to lowest to highest representation scores\n",
604 | | - "indices = np.argsort(bias, axis=None)\n",
| 622 | + "indices = np.argsort(out.bias, axis=None)\n",
605 | 623 | "bias_images = test_imgs[indices] # sort the images\n",
606 | | - "sorted_bias = bias[indices] # sort the representation bias scores\n",
607 | | - "sorted_bias_preds = pred[indices] # sort the predictions\n",
| 624 | + "sorted_bias = out.bias.numpy()[indices] # sort the representation bias scores\n",
| 625 | + "sorted_bias_preds = out.y_hat.numpy()[indices] # sort the predictions\n",
608 | 626 | "\n",
609 | 627 | "# Plot the representation bias vs. the accuracy\n",
610 | 628 | "plt.xlabel(\"Density (Representation)\")\n",
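
The evaluation cell ends by plotting accuracy against the sorted representation scores; the plotting lines themselves (new lines 629-630) are collapsed out of this hunk. As an illustrative sketch only (not the notebook's omitted code), one way to build such a curve is to bin the sorted predictions by density percentile and compute accuracy per bin; since test_imgs comes from get_all_faces(), the sketch assumes every test image is a positive example:

```python
import numpy as np

sorted_bias = np.sort(np.random.rand(500))   # stand-in for sorted_bias
sorted_probs = np.random.rand(500)           # stand-in for sorted_bias_preds
labels = np.ones(500, dtype=int)             # assumed: all test images are faces

bins = np.array_split(np.arange(500), 10)    # 10 equal-size density bins
accuracy = [np.mean((sorted_probs[b] > 0.5).astype(int) == labels[b]) for b in bins]
density = [sorted_bias[b].mean() for b in bins]  # x-axis values for the plot
```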

613 | 631 | ]
614 | 632 | },
615 | 633 | {
| 634 | + "attachments": {},
616 | 635 | "cell_type": "markdown",
617 | 636 | "metadata": {
618 | 637 | "id": "d1cEEnII71Ej"

681 | 700 | "name": "python",
682 | 701 | "nbconvert_exporter": "python",
683 | 702 | "pygments_lexer": "ipython3",
684 | | - "version": "3.9.6 (default, Oct 18 2022, 12:41:40) \n[Clang 14.0.0 (clang-1400.0.29.202)]"
| 703 | + "version": "3.9.16"
685 | 704 | },
686 | 705 | "vscode": {
687 | 706 | "interpreter": {