Skip to content

Commit 65465a4

Browse files
authored
Merge pull request #17 from jrzaurin/precision_recall
Precision recall
2 parents 31c2d8e + 393ea43 commit 65465a4

21 files changed

+573
-308
lines changed

README.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@ final output neuron or neurons, depending on whether we are performing a
4040
binary classification or regression, or a multi-class classification. The
4141
components within the faded-pink rectangles are concatenated.
4242

43-
In math terms, and following the notation in the [paper](https://arxiv.org/abs/1606.07792), Architecture 1 can be formulated as:
43+
In math terms, and following the notation in the
44+
[paper](https://arxiv.org/abs/1606.07792), Architecture 1 can be formulated
45+
as:
4446

4547
<p align="center">
4648
<img width="500" src="docs/figures/architecture_1_math.png">
@@ -130,7 +132,7 @@ from sklearn.model_selection import train_test_split
130132

131133
from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor
132134
from pytorch_widedeep.models import Wide, DeepDense, WideDeep
133-
from pytorch_widedeep.metrics import BinaryAccuracy
135+
from pytorch_widedeep.metrics import Accuracy
134136

135137
# these next 4 lines are not directly related to pytorch-widedeep. I assume
136138
# you have downloaded the dataset and place it in a dir called data/adult/
@@ -178,7 +180,7 @@ deepdense = DeepDense(
178180

179181
# build, compile and fit
180182
model = WideDeep(wide=wide, deepdense=deepdense)
181-
model.compile(method="binary", metrics=[BinaryAccuracy])
183+
model.compile(method="binary", metrics=[Accuracy])
182184
model.fit(
183185
X_wide=X_wide,
184186
X_deep=X_deep,

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0.4.1
1+
0.4.2

code_style.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# sort imports
2-
isort --recursive . pytorch_widedeep tests examples setup.py
2+
isort . pytorch_widedeep tests examples setup.py
33
# Black code style
44
black . pytorch_widedeep tests examples setup.py
55
# flake8 standards

docs/examples.rst

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
pytorch-widedeep Examples
2+
*****************************
3+
4+
This section provides links to example notebooks that may be helpful to better
5+
understand the functionalities withing ``pytorch-widedeep`` and how to use
6+
them to address different problems
7+
8+
* `Preprocessors and Utils <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/01_Preprocessors_and_utils.ipynb>`__
9+
* `Model Components <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/02_Model_Components.ipynb>`__
10+
* `Binary Classification with default parameters <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/03_Binary_Classification_with_Defaults.ipynb>`__
11+
* `Binary Classification with varying parameters <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/04_Binary_Classification_Varying_Parameters.ipynb>`__
12+
* `Regression with Images and Text <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/05_Regression_with_Images_and_Text.ipynb>`__
13+
* `Warm up routines <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/06_WarmUp_Model_Components.ipynb>`__

docs/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ Documentation
1919
Preprocessing <preprocessing>
2020
Model Components <model_components>
2121
Wide and Deep Models <wide_deep/index>
22+
Examples <examples>
2223

2324

2425
Introduction

docs/quick_start.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ Prepare the wide and deep columns
3030
3131
from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor
3232
from pytorch_widedeep.models import Wide, DeepDense, WideDeep
33-
from pytorch_widedeep.metrics import BinaryAccuracy
33+
from pytorch_widedeep.metrics import Accuracy
3434
3535
# prepare wide, crossed, embedding and continuous columns
3636
wide_cols = [
@@ -83,7 +83,7 @@ Build, compile, fit and predict
8383
8484
# build, compile and fit
8585
model = WideDeep(wide=wide, deepdense=deepdense)
86-
model.compile(method="binary", metrics=[BinaryAccuracy])
86+
model.compile(method="binary", metrics=[Accuracy])
8787
model.fit(
8888
X_wide=X_wide,
8989
X_deep=X_deep,

docs/wide_deep/metrics.rst

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,27 @@
11
Metrics
22
=======
33

4-
.. autoclass:: pytorch_widedeep.metrics.BinaryAccuracy
4+
.. autoclass:: pytorch_widedeep.metrics.Accuracy
55
:members:
66
:undoc-members:
77
:show-inheritance:
88

9-
.. autoclass:: pytorch_widedeep.metrics.CategoricalAccuracy
9+
.. autoclass:: pytorch_widedeep.metrics.Precision
10+
:members:
11+
:undoc-members:
12+
:show-inheritance:
13+
14+
.. autoclass:: pytorch_widedeep.metrics.Recall
15+
:members:
16+
:undoc-members:
17+
:show-inheritance:
18+
19+
.. autoclass:: pytorch_widedeep.metrics.FBetaScore
20+
:members:
21+
:undoc-members:
22+
:show-inheritance:
23+
24+
.. autoclass:: pytorch_widedeep.metrics.F1Score
1025
:members:
1126
:undoc-members:
1227
:show-inheritance:

examples/01_Preprocessors_and_utils.ipynb

Lines changed: 40 additions & 40 deletions
Large diffs are not rendered by default.

examples/02_Model_Components.ipynb

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -170,11 +170,11 @@
170170
{
171171
"data": {
172172
"text/plain": [
173-
"tensor([[-0.0000, -1.0061, -0.0000, -0.9828, -0.0000, -0.0000, -0.9944, -1.0133],\n",
174-
" [-0.0000, -0.9996, 0.0000, -1.0374, 0.0000, -0.0000, -1.0313, -0.0000],\n",
175-
" [-0.8576, -1.0017, -0.0000, -0.9881, -0.0000, 0.0000, -0.0000, -0.0000],\n",
176-
" [ 3.9816, 0.0000, 0.0000, 0.0000, 3.7309, 1.1728, 0.0000, -1.1160],\n",
177-
" [-1.1339, -0.9925, -0.0000, -0.0000, -0.0000, 0.0000, -0.9638, 0.0000]],\n",
173+
"tensor([[-0.0000, -0.9949, 3.8273, 0.0000, -1.3889, -2.9641, 0.0000, -0.0000],\n",
174+
" [ 3.9123, -0.0000, -0.0000, 1.9555, -1.3561, 1.7069, -0.0000, 0.9275],\n",
175+
" [-0.0000, -0.0000, 0.0000, -0.0000, 0.0000, -1.6489, -0.0000, -1.4985],\n",
176+
" [-1.2736, 0.0000, -1.2819, 2.1232, 0.0000, 2.2767, -0.0000, 3.5354],\n",
177+
" [-0.1726, -0.0000, -1.3275, -0.0000, -1.3703, 0.0000, -0.0000, -1.4637]],\n",
178178
" grad_fn=<MulBackward0>)"
179179
]
180180
},
@@ -484,10 +484,10 @@
484484
{
485485
"data": {
486486
"text/plain": [
487-
"tensor([[-1.4630e-04, -6.1540e-04, -2.4541e-04, 2.7543e-01, 1.2993e-01,\n",
488-
" -1.6553e-03, 6.7002e-02, 2.3974e-01],\n",
489-
" [-9.9619e-04, -1.9412e-03, 1.2113e-01, 1.0122e-01, 2.9080e-01,\n",
490-
" -2.0852e-03, -1.8016e-04, 2.7996e-02]], grad_fn=<LeakyReluBackward1>)"
487+
"tensor([[-2.2825e-03, -8.3100e-04, -8.8423e-04, -1.1084e-04, 8.8529e-02,\n",
488+
" -5.1577e-04, 2.8343e-01, -1.7071e-03],\n",
489+
" [-1.8486e-03, -8.5602e-04, -1.8552e-03, 3.6481e-01, 9.0812e-02,\n",
490+
" -9.6603e-04, 3.9017e-01, -2.6355e-03]], grad_fn=<LeakyReluBackward1>)"
491491
]
492492
},
493493
"execution_count": 18,

examples/03_Binary_Classification_with_Defaults.ipynb

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
},
1212
{
1313
"cell_type": "code",
14-
"execution_count": 5,
14+
"execution_count": 1,
1515
"metadata": {},
1616
"outputs": [],
1717
"source": [
@@ -21,12 +21,12 @@
2121
"\n",
2222
"from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor\n",
2323
"from pytorch_widedeep.models import Wide, DeepDense, WideDeep\n",
24-
"from pytorch_widedeep.metrics import BinaryAccuracy"
24+
"from pytorch_widedeep.metrics import Accuracy, Precision"
2525
]
2626
},
2727
{
2828
"cell_type": "code",
29-
"execution_count": 6,
29+
"execution_count": 2,
3030
"metadata": {},
3131
"outputs": [
3232
{
@@ -185,7 +185,7 @@
185185
"4 30 United-States <=50K "
186186
]
187187
},
188-
"execution_count": 6,
188+
"execution_count": 2,
189189
"metadata": {},
190190
"output_type": "execute_result"
191191
}
@@ -197,7 +197,7 @@
197197
},
198198
{
199199
"cell_type": "code",
200-
"execution_count": 7,
200+
"execution_count": 3,
201201
"metadata": {},
202202
"outputs": [
203203
{
@@ -356,7 +356,7 @@
356356
"4 30 United-States 0 "
357357
]
358358
},
359-
"execution_count": 7,
359+
"execution_count": 3,
360360
"metadata": {},
361361
"output_type": "execute_result"
362362
}
@@ -381,7 +381,7 @@
381381
},
382382
{
383383
"cell_type": "code",
384-
"execution_count": 8,
384+
"execution_count": 4,
385385
"metadata": {},
386386
"outputs": [],
387387
"source": [
@@ -394,7 +394,7 @@
394394
},
395395
{
396396
"cell_type": "code",
397-
"execution_count": 9,
397+
"execution_count": 5,
398398
"metadata": {},
399399
"outputs": [],
400400
"source": [
@@ -412,7 +412,7 @@
412412
},
413413
{
414414
"cell_type": "code",
415-
"execution_count": 10,
415+
"execution_count": 6,
416416
"metadata": {},
417417
"outputs": [
418418
{
@@ -437,7 +437,7 @@
437437
},
438438
{
439439
"cell_type": "code",
440-
"execution_count": 11,
440+
"execution_count": 7,
441441
"metadata": {},
442442
"outputs": [
443443
{
@@ -475,7 +475,7 @@
475475
},
476476
{
477477
"cell_type": "code",
478-
"execution_count": 14,
478+
"execution_count": 8,
479479
"metadata": {},
480480
"outputs": [],
481481
"source": [
@@ -489,7 +489,7 @@
489489
},
490490
{
491491
"cell_type": "code",
492-
"execution_count": 15,
492+
"execution_count": 9,
493493
"metadata": {},
494494
"outputs": [
495495
{
@@ -527,7 +527,7 @@
527527
")"
528528
]
529529
},
530-
"execution_count": 15,
530+
"execution_count": 9,
531531
"metadata": {},
532532
"output_type": "execute_result"
533533
}
@@ -560,16 +560,16 @@
560560
},
561561
{
562562
"cell_type": "code",
563-
"execution_count": 16,
563+
"execution_count": 10,
564564
"metadata": {},
565565
"outputs": [],
566566
"source": [
567-
"model.compile(method='binary', metrics=[BinaryAccuracy])"
567+
"model.compile(method='binary', metrics=[Accuracy, Precision])"
568568
]
569569
},
570570
{
571571
"cell_type": "code",
572-
"execution_count": 17,
572+
"execution_count": 11,
573573
"metadata": {},
574574
"outputs": [
575575
{
@@ -591,16 +591,16 @@
591591
"name": "stderr",
592592
"output_type": "stream",
593593
"text": [
594-
"epoch 1: 100%|██████████| 153/153 [00:02<00:00, 64.79it/s, loss=0.435, metrics={'acc': 0.7901}]\n",
595-
"valid: 100%|██████████| 39/39 [00:00<00:00, 124.97it/s, loss=0.358, metrics={'acc': 0.799}]\n",
596-
"epoch 2: 100%|██████████| 153/153 [00:02<00:00, 71.36it/s, loss=0.352, metrics={'acc': 0.8352}]\n",
597-
"valid: 100%|██████████| 39/39 [00:00<00:00, 124.33it/s, loss=0.349, metrics={'acc': 0.8358}]\n",
598-
"epoch 3: 100%|██████████| 153/153 [00:02<00:00, 72.24it/s, loss=0.345, metrics={'acc': 0.8383}]\n",
599-
"valid: 100%|██████████| 39/39 [00:00<00:00, 121.07it/s, loss=0.345, metrics={'acc': 0.8389}]\n",
600-
"epoch 4: 100%|██████████| 153/153 [00:02<00:00, 70.39it/s, loss=0.341, metrics={'acc': 0.8404}]\n",
601-
"valid: 100%|██████████| 39/39 [00:00<00:00, 123.29it/s, loss=0.343, metrics={'acc': 0.8406}]\n",
602-
"epoch 5: 100%|██████████| 153/153 [00:02<00:00, 71.14it/s, loss=0.339, metrics={'acc': 0.8423}]\n",
603-
"valid: 100%|██████████| 39/39 [00:00<00:00, 121.12it/s, loss=0.342, metrics={'acc': 0.8426}]\n"
594+
"epoch 1: 100%|██████████| 153/153 [00:01<00:00, 102.41it/s, loss=0.585, metrics={'acc': 0.7512, 'prec': 0.1818}]\n",
595+
"valid: 100%|██████████| 39/39 [00:00<00:00, 98.78it/s, loss=0.513, metrics={'acc': 0.754, 'prec': 0.2429}] \n",
596+
"epoch 2: 100%|██████████| 153/153 [00:01<00:00, 117.30it/s, loss=0.481, metrics={'acc': 0.782, 'prec': 0.8287}] \n",
597+
"valid: 100%|██████████| 39/39 [00:00<00:00, 106.49it/s, loss=0.454, metrics={'acc': 0.7866, 'prec': 0.8245}]\n",
598+
"epoch 3: 100%|██████████| 153/153 [00:01<00:00, 124.78it/s, loss=0.44, metrics={'acc': 0.8055, 'prec': 0.781}] \n",
599+
"valid: 100%|██████████| 39/39 [00:00<00:00, 115.36it/s, loss=0.425, metrics={'acc': 0.8077, 'prec': 0.7818}]\n",
600+
"epoch 4: 100%|██████████| 153/153 [00:01<00:00, 125.01it/s, loss=0.418, metrics={'acc': 0.814, 'prec': 0.7661}] \n",
601+
"valid: 100%|██████████| 39/39 [00:00<00:00, 114.92it/s, loss=0.408, metrics={'acc': 0.8149, 'prec': 0.7671}]\n",
602+
"epoch 5: 100%|██████████| 153/153 [00:01<00:00, 116.57it/s, loss=0.404, metrics={'acc': 0.819, 'prec': 0.7527}]\n",
603+
"valid: 100%|██████████| 39/39 [00:00<00:00, 108.89it/s, loss=0.397, metrics={'acc': 0.8203, 'prec': 0.7547}]\n"
604604
]
605605
}
606606
],

0 commit comments

Comments
 (0)