|
11 | 11 | }, |
12 | 12 | { |
13 | 13 | "cell_type": "code", |
14 | | - "execution_count": 5, |
| 14 | + "execution_count": 1, |
15 | 15 | "metadata": {}, |
16 | 16 | "outputs": [], |
17 | 17 | "source": [ |
|
21 | 21 | "\n", |
22 | 22 | "from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor\n", |
23 | 23 | "from pytorch_widedeep.models import Wide, DeepDense, WideDeep\n", |
24 | | - "from pytorch_widedeep.metrics import BinaryAccuracy" |
| 24 | + "from pytorch_widedeep.metrics import Accuracy, Precision" |
25 | 25 | ] |
26 | 26 | }, |
27 | 27 | { |
28 | 28 | "cell_type": "code", |
29 | | - "execution_count": 6, |
| 29 | + "execution_count": 2, |
30 | 30 | "metadata": {}, |
31 | 31 | "outputs": [ |
32 | 32 | { |
|
185 | 185 | "4 30 United-States <=50K " |
186 | 186 | ] |
187 | 187 | }, |
188 | | - "execution_count": 6, |
| 188 | + "execution_count": 2, |
189 | 189 | "metadata": {}, |
190 | 190 | "output_type": "execute_result" |
191 | 191 | } |
|
197 | 197 | }, |
198 | 198 | { |
199 | 199 | "cell_type": "code", |
200 | | - "execution_count": 7, |
| 200 | + "execution_count": 3, |
201 | 201 | "metadata": {}, |
202 | 202 | "outputs": [ |
203 | 203 | { |
|
356 | 356 | "4 30 United-States 0 " |
357 | 357 | ] |
358 | 358 | }, |
359 | | - "execution_count": 7, |
| 359 | + "execution_count": 3, |
360 | 360 | "metadata": {}, |
361 | 361 | "output_type": "execute_result" |
362 | 362 | } |
|
381 | 381 | }, |
382 | 382 | { |
383 | 383 | "cell_type": "code", |
384 | | - "execution_count": 8, |
| 384 | + "execution_count": 4, |
385 | 385 | "metadata": {}, |
386 | 386 | "outputs": [], |
387 | 387 | "source": [ |
|
394 | 394 | }, |
395 | 395 | { |
396 | 396 | "cell_type": "code", |
397 | | - "execution_count": 9, |
| 397 | + "execution_count": 5, |
398 | 398 | "metadata": {}, |
399 | 399 | "outputs": [], |
400 | 400 | "source": [ |
|
412 | 412 | }, |
413 | 413 | { |
414 | 414 | "cell_type": "code", |
415 | | - "execution_count": 10, |
| 415 | + "execution_count": 6, |
416 | 416 | "metadata": {}, |
417 | 417 | "outputs": [ |
418 | 418 | { |
|
437 | 437 | }, |
438 | 438 | { |
439 | 439 | "cell_type": "code", |
440 | | - "execution_count": 11, |
| 440 | + "execution_count": 7, |
441 | 441 | "metadata": {}, |
442 | 442 | "outputs": [ |
443 | 443 | { |
|
475 | 475 | }, |
476 | 476 | { |
477 | 477 | "cell_type": "code", |
478 | | - "execution_count": 14, |
| 478 | + "execution_count": 8, |
479 | 479 | "metadata": {}, |
480 | 480 | "outputs": [], |
481 | 481 | "source": [ |
|
489 | 489 | }, |
490 | 490 | { |
491 | 491 | "cell_type": "code", |
492 | | - "execution_count": 15, |
| 492 | + "execution_count": 9, |
493 | 493 | "metadata": {}, |
494 | 494 | "outputs": [ |
495 | 495 | { |
|
527 | 527 | ")" |
528 | 528 | ] |
529 | 529 | }, |
530 | | - "execution_count": 15, |
| 530 | + "execution_count": 9, |
531 | 531 | "metadata": {}, |
532 | 532 | "output_type": "execute_result" |
533 | 533 | } |
|
560 | 560 | }, |
561 | 561 | { |
562 | 562 | "cell_type": "code", |
563 | | - "execution_count": 16, |
| 563 | + "execution_count": 10, |
564 | 564 | "metadata": {}, |
565 | 565 | "outputs": [], |
566 | 566 | "source": [ |
567 | | - "model.compile(method='binary', metrics=[BinaryAccuracy])" |
| 567 | + "model.compile(method='binary', metrics=[Accuracy, Precision])" |
568 | 568 | ] |
569 | 569 | }, |
570 | 570 | { |
571 | 571 | "cell_type": "code", |
572 | | - "execution_count": 17, |
| 572 | + "execution_count": 11, |
573 | 573 | "metadata": {}, |
574 | 574 | "outputs": [ |
575 | 575 | { |
|
591 | 591 | "name": "stderr", |
592 | 592 | "output_type": "stream", |
593 | 593 | "text": [ |
594 | | - "epoch 1: 100%|██████████| 153/153 [00:02<00:00, 64.79it/s, loss=0.435, metrics={'acc': 0.7901}]\n", |
595 | | - "valid: 100%|██████████| 39/39 [00:00<00:00, 124.97it/s, loss=0.358, metrics={'acc': 0.799}]\n", |
596 | | - "epoch 2: 100%|██████████| 153/153 [00:02<00:00, 71.36it/s, loss=0.352, metrics={'acc': 0.8352}]\n", |
597 | | - "valid: 100%|██████████| 39/39 [00:00<00:00, 124.33it/s, loss=0.349, metrics={'acc': 0.8358}]\n", |
598 | | - "epoch 3: 100%|██████████| 153/153 [00:02<00:00, 72.24it/s, loss=0.345, metrics={'acc': 0.8383}]\n", |
599 | | - "valid: 100%|██████████| 39/39 [00:00<00:00, 121.07it/s, loss=0.345, metrics={'acc': 0.8389}]\n", |
600 | | - "epoch 4: 100%|██████████| 153/153 [00:02<00:00, 70.39it/s, loss=0.341, metrics={'acc': 0.8404}]\n", |
601 | | - "valid: 100%|██████████| 39/39 [00:00<00:00, 123.29it/s, loss=0.343, metrics={'acc': 0.8406}]\n", |
602 | | - "epoch 5: 100%|██████████| 153/153 [00:02<00:00, 71.14it/s, loss=0.339, metrics={'acc': 0.8423}]\n", |
603 | | - "valid: 100%|██████████| 39/39 [00:00<00:00, 121.12it/s, loss=0.342, metrics={'acc': 0.8426}]\n" |
| 594 | + "epoch 1: 100%|██████████| 153/153 [00:01<00:00, 102.41it/s, loss=0.585, metrics={'acc': 0.7512, 'prec': 0.1818}]\n", |
| 595 | + "valid: 100%|██████████| 39/39 [00:00<00:00, 98.78it/s, loss=0.513, metrics={'acc': 0.754, 'prec': 0.2429}] \n", |
| 596 | + "epoch 2: 100%|██████████| 153/153 [00:01<00:00, 117.30it/s, loss=0.481, metrics={'acc': 0.782, 'prec': 0.8287}] \n", |
| 597 | + "valid: 100%|██████████| 39/39 [00:00<00:00, 106.49it/s, loss=0.454, metrics={'acc': 0.7866, 'prec': 0.8245}]\n", |
| 598 | + "epoch 3: 100%|██████████| 153/153 [00:01<00:00, 124.78it/s, loss=0.44, metrics={'acc': 0.8055, 'prec': 0.781}] \n", |
| 599 | + "valid: 100%|██████████| 39/39 [00:00<00:00, 115.36it/s, loss=0.425, metrics={'acc': 0.8077, 'prec': 0.7818}]\n", |
| 600 | + "epoch 4: 100%|██████████| 153/153 [00:01<00:00, 125.01it/s, loss=0.418, metrics={'acc': 0.814, 'prec': 0.7661}] \n", |
| 601 | + "valid: 100%|██████████| 39/39 [00:00<00:00, 114.92it/s, loss=0.408, metrics={'acc': 0.8149, 'prec': 0.7671}]\n", |
| 602 | + "epoch 5: 100%|██████████| 153/153 [00:01<00:00, 116.57it/s, loss=0.404, metrics={'acc': 0.819, 'prec': 0.7527}]\n", |
| 603 | + "valid: 100%|██████████| 39/39 [00:00<00:00, 108.89it/s, loss=0.397, metrics={'acc': 0.8203, 'prec': 0.7547}]\n" |
604 | 604 | ] |
605 | 605 | } |
606 | 606 | ], |
|
0 commit comments