|
591 | 591 | "name": "stderr", |
592 | 592 | "output_type": "stream", |
593 | 593 | "text": [ |
594 | | - "epoch 1: 100%|██████████| 611/611 [00:05<00:00, 115.33it/s, loss=0.743, metrics={'acc': 0.6205, 'prec': 0.2817}]\n", |
595 | | - "valid: 100%|██████████| 153/153 [00:00<00:00, 168.06it/s, loss=0.545, metrics={'acc': 0.6452, 'prec': 0.3014}]\n", |
596 | | - "epoch 2: 100%|██████████| 611/611 [00:04<00:00, 122.57it/s, loss=0.486, metrics={'acc': 0.7765, 'prec': 0.5517}]\n", |
597 | | - "valid: 100%|██████████| 153/153 [00:00<00:00, 158.84it/s, loss=0.44, metrics={'acc': 0.783, 'prec': 0.573}] \n", |
598 | | - "epoch 3: 100%|██████████| 611/611 [00:04<00:00, 124.89it/s, loss=0.419, metrics={'acc': 0.8129, 'prec': 0.6753}]\n", |
599 | | - "valid: 100%|██████████| 153/153 [00:00<00:00, 158.10it/s, loss=0.402, metrics={'acc': 0.815, 'prec': 0.6816}] \n", |
600 | | - "epoch 4: 100%|██████████| 611/611 [00:04<00:00, 126.35it/s, loss=0.393, metrics={'acc': 0.8228, 'prec': 0.7047}]\n", |
601 | | - "valid: 100%|██████████| 153/153 [00:00<00:00, 160.72it/s, loss=0.385, metrics={'acc': 0.8233, 'prec': 0.7024}]\n", |
602 | | - "epoch 5: 100%|██████████| 611/611 [00:04<00:00, 124.33it/s, loss=0.38, metrics={'acc': 0.826, 'prec': 0.702}] \n", |
603 | | - "valid: 100%|██████████| 153/153 [00:00<00:00, 163.43it/s, loss=0.376, metrics={'acc': 0.8264, 'prec': 0.7}] \n" |
| 594 | + "epoch 1: 100%|██████████| 611/611 [00:06<00:00, 101.71it/s, loss=0.448, metrics={'acc': 0.792, 'prec': 0.5728}] \n", |
| 595 | + "valid: 100%|██████████| 153/153 [00:00<00:00, 171.00it/s, loss=0.366, metrics={'acc': 0.7991, 'prec': 0.5907}]\n", |
| 596 | + "epoch 2: 100%|██████████| 611/611 [00:06<00:00, 101.69it/s, loss=0.361, metrics={'acc': 0.8324, 'prec': 0.6817}]\n", |
| 597 | + "valid: 100%|██████████| 153/153 [00:00<00:00, 169.36it/s, loss=0.357, metrics={'acc': 0.8328, 'prec': 0.6807}]\n", |
| 598 | + "epoch 3: 100%|██████████| 611/611 [00:05<00:00, 102.65it/s, loss=0.352, metrics={'acc': 0.8366, 'prec': 0.691}] \n", |
| 599 | + "valid: 100%|██████████| 153/153 [00:00<00:00, 171.49it/s, loss=0.352, metrics={'acc': 0.8361, 'prec': 0.6867}]\n", |
| 600 | + "epoch 4: 100%|██████████| 611/611 [00:06<00:00, 101.52it/s, loss=0.347, metrics={'acc': 0.8389, 'prec': 0.6956}]\n", |
| 601 | + "valid: 100%|██████████| 153/153 [00:00<00:00, 163.49it/s, loss=0.349, metrics={'acc': 0.8383, 'prec': 0.6906}]\n", |
| 602 | + "epoch 5: 100%|██████████| 611/611 [00:07<00:00, 84.91it/s, loss=0.343, metrics={'acc': 0.8405, 'prec': 0.6987}] \n", |
| 603 | + "valid: 100%|██████████| 153/153 [00:01<00:00, 142.83it/s, loss=0.347, metrics={'acc': 0.8399, 'prec': 0.6946}]\n" |
604 | 604 | ] |
605 | 605 | } |
606 | 606 | ], |
|
664 | 664 | "name": "stderr", |
665 | 665 | "output_type": "stream", |
666 | 666 | "text": [ |
667 | | - "epoch 1: 100%|██████████| 611/611 [00:05<00:00, 108.62it/s, loss=0.894, metrics={'acc': 0.5182, 'prec': 0.2037}]\n", |
668 | | - "valid: 100%|██████████| 153/153 [00:00<00:00, 154.44it/s, loss=0.604, metrics={'acc': 0.5542, 'prec': 0.2135}]\n", |
669 | | - "epoch 2: 100%|██████████| 611/611 [00:05<00:00, 106.49it/s, loss=0.51, metrics={'acc': 0.751, 'prec': 0.4614}] \n", |
670 | | - "valid: 100%|██████████| 153/153 [00:00<00:00, 157.79it/s, loss=0.452, metrics={'acc': 0.7581, 'prec': 0.4898}]\n", |
671 | | - "epoch 3: 100%|██████████| 611/611 [00:05<00:00, 106.66it/s, loss=0.425, metrics={'acc': 0.8031, 'prec': 0.6618}]\n", |
672 | | - "valid: 100%|██████████| 153/153 [00:00<00:00, 160.73it/s, loss=0.405, metrics={'acc': 0.806, 'prec': 0.6686}] \n", |
673 | | - "epoch 4: 100%|██████████| 611/611 [00:05<00:00, 106.58it/s, loss=0.394, metrics={'acc': 0.8185, 'prec': 0.6966}]\n", |
674 | | - "valid: 100%|██████████| 153/153 [00:00<00:00, 155.55it/s, loss=0.385, metrics={'acc': 0.8196, 'prec': 0.6994}]\n", |
675 | | - "epoch 5: 100%|██████████| 611/611 [00:05<00:00, 107.28it/s, loss=0.38, metrics={'acc': 0.8236, 'prec': 0.7004}] \n", |
676 | | - "valid: 100%|██████████| 153/153 [00:00<00:00, 155.37it/s, loss=0.375, metrics={'acc': 0.8244, 'prec': 0.7017}]\n" |
| 667 | + "epoch 1: 100%|██████████| 611/611 [00:07<00:00, 77.46it/s, loss=0.387, metrics={'acc': 0.8192, 'prec': 0.6576}]\n", |
| 668 | + "valid: 100%|██████████| 153/153 [00:01<00:00, 147.78it/s, loss=0.36, metrics={'acc': 0.8216, 'prec': 0.6617}] \n", |
| 669 | + "epoch 2: 100%|██████████| 611/611 [00:08<00:00, 74.99it/s, loss=0.358, metrics={'acc': 0.8313, 'prec': 0.6836}]\n", |
| 670 | + "valid: 100%|██████████| 153/153 [00:00<00:00, 158.26it/s, loss=0.355, metrics={'acc': 0.8321, 'prec': 0.6848}]\n", |
| 671 | + "epoch 3: 100%|██████████| 611/611 [00:08<00:00, 76.28it/s, loss=0.351, metrics={'acc': 0.8345, 'prec': 0.6889}]\n", |
| 672 | + "valid: 100%|██████████| 153/153 [00:00<00:00, 154.84it/s, loss=0.354, metrics={'acc': 0.8347, 'prec': 0.6887}]\n", |
| 673 | + "epoch 4: 100%|██████████| 611/611 [00:07<00:00, 76.71it/s, loss=0.346, metrics={'acc': 0.8374, 'prec': 0.6946}]\n", |
| 674 | + "valid: 100%|██████████| 153/153 [00:00<00:00, 157.80it/s, loss=0.353, metrics={'acc': 0.8369, 'prec': 0.6935}]\n", |
| 675 | + "epoch 5: 100%|██████████| 611/611 [00:08<00:00, 73.25it/s, loss=0.343, metrics={'acc': 0.8386, 'prec': 0.6966}]\n", |
| 676 | + "valid: 100%|██████████| 153/153 [00:00<00:00, 157.05it/s, loss=0.352, metrics={'acc': 0.8382, 'prec': 0.6961}]\n" |
677 | 677 | ] |
678 | 678 | } |
679 | 679 | ], |
680 | 680 | "source": [ |
681 | 681 | "model.fit(X_wide=X_wide, X_deep=X_deep, target=target, n_epochs=5, batch_size=64, val_split=0.2)" |
682 | 682 | ] |
| 683 | + }, |
| 684 | + { |
| 685 | + "cell_type": "markdown", |
| 686 | + "metadata": {}, |
| 687 | + "source": [ |
| 688 | + "Also mentioning that one could build a model with the individual components independently. For example, a model comprised only by the `wide` component would be simply a linear model. This could be attained by just:" |
| 689 | + ] |
| 690 | + }, |
| 691 | + { |
| 692 | + "cell_type": "code", |
| 693 | + "execution_count": 15, |
| 694 | + "metadata": {}, |
| 695 | + "outputs": [], |
| 696 | + "source": [ |
| 697 | + "model = WideDeep(wide=wide)" |
| 698 | + ] |
| 699 | + }, |
| 700 | + { |
| 701 | + "cell_type": "code", |
| 702 | + "execution_count": 16, |
| 703 | + "metadata": {}, |
| 704 | + "outputs": [], |
| 705 | + "source": [ |
| 706 | + "model.compile(method='binary', metrics=[Accuracy, Precision])" |
| 707 | + ] |
| 708 | + }, |
| 709 | + { |
| 710 | + "cell_type": "code", |
| 711 | + "execution_count": 17, |
| 712 | + "metadata": {}, |
| 713 | + "outputs": [ |
| 714 | + { |
| 715 | + "name": "stderr", |
| 716 | + "output_type": "stream", |
| 717 | + "text": [ |
| 718 | + "\r", |
| 719 | + " 0%| | 0/611 [00:00<?, ?it/s]" |
| 720 | + ] |
| 721 | + }, |
| 722 | + { |
| 723 | + "name": "stdout", |
| 724 | + "output_type": "stream", |
| 725 | + "text": [ |
| 726 | + "Training\n" |
| 727 | + ] |
| 728 | + }, |
| 729 | + { |
| 730 | + "name": "stderr", |
| 731 | + "output_type": "stream", |
| 732 | + "text": [ |
| 733 | + "epoch 1: 100%|██████████| 611/611 [00:03<00:00, 188.59it/s, loss=0.482, metrics={'acc': 0.771, 'prec': 0.5633}] \n", |
| 734 | + "valid: 100%|██████████| 153/153 [00:00<00:00, 236.13it/s, loss=0.423, metrics={'acc': 0.7747, 'prec': 0.5819}]\n", |
| 735 | + "epoch 2: 100%|██████████| 611/611 [00:03<00:00, 190.62it/s, loss=0.399, metrics={'acc': 0.8131, 'prec': 0.686}] \n", |
| 736 | + "valid: 100%|██████████| 153/153 [00:00<00:00, 221.47it/s, loss=0.387, metrics={'acc': 0.8138, 'prec': 0.6879}]\n", |
| 737 | + "epoch 3: 100%|██████████| 611/611 [00:03<00:00, 190.28it/s, loss=0.378, metrics={'acc': 0.8267, 'prec': 0.7149}]\n", |
| 738 | + "valid: 100%|██████████| 153/153 [00:00<00:00, 241.12it/s, loss=0.374, metrics={'acc': 0.8255, 'prec': 0.7128}]\n", |
| 739 | + "epoch 4: 100%|██████████| 611/611 [00:03<00:00, 183.27it/s, loss=0.37, metrics={'acc': 0.8304, 'prec': 0.7073}] \n", |
| 740 | + "valid: 100%|██████████| 153/153 [00:00<00:00, 227.46it/s, loss=0.369, metrics={'acc': 0.8294, 'prec': 0.7061}]\n", |
| 741 | + "epoch 5: 100%|██████████| 611/611 [00:03<00:00, 184.28it/s, loss=0.366, metrics={'acc': 0.8315, 'prec': 0.7006}]\n", |
| 742 | + "valid: 100%|██████████| 153/153 [00:00<00:00, 239.87it/s, loss=0.366, metrics={'acc': 0.8303, 'prec': 0.6999}]\n" |
| 743 | + ] |
| 744 | + } |
| 745 | + ], |
| 746 | + "source": [ |
| 747 | + "model.fit(X_wide=X_wide, target=target, n_epochs=5, batch_size=64, val_split=0.2)" |
| 748 | + ] |
683 | 749 | } |
684 | 750 | ], |
685 | 751 | "metadata": { |
|
0 commit comments