Commit 2fe4b49

Merge pull request #30 from jrzaurin/fix_image_format
Fix image format
2 parents 9f61051 + 3913afe

15 files changed: +704 −237 lines

README.md

Lines changed: 40 additions & 14 deletions
@@ -1,6 +1,6 @@
 
 <p align="center">
-  <img width="450" src="docs/figures/widedeep_logo.png">
+  <img width="300" src="docs/figures/widedeep_logo.png">
 </p>
 
 [![Build Status](https://travis-ci.org/jrzaurin/pytorch-widedeep.svg?branch=master)](https://travis-ci.org/jrzaurin/pytorch-widedeep)
@@ -9,11 +9,7 @@
 [![Maintenance](https://img.shields.io/badge/Maintained%3F-yes-green.svg)](https://github.com/jrzaurin/pytorch-widedeep/graphs/commit-activity)
 [![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/jrzaurin/pytorch-widedeep/issues)
 [![codecov](https://codecov.io/gh/jrzaurin/pytorch-widedeep/branch/master/graph/badge.svg)](https://codecov.io/gh/jrzaurin/pytorch-widedeep)
-
-Platform | Version Support
----------|:---------------
-OSX | [![Python 3.6 3.7](https://img.shields.io/badge/python-3.6%20%7C%203.7-blue.svg)](https://www.python.org/)
-Linux | [![Python 3.6 3.7 3.8](https://img.shields.io/badge/python-3.6%20%7C%203.7%20%7C%203.8-blue.svg)](https://www.python.org/)
+[![Python 3.6 3.7 3.8](https://img.shields.io/badge/python-3.6%20%7C%203.7%20%7C%203.8-blue.svg)](https://www.python.org/)
 
 # pytorch-widedeep
 
@@ -88,15 +84,23 @@ as:
 <img width="300" src="docs/figures/architecture_2_math.png">
 </p>
 
-When using `pytorch-widedeep`, the assumption is that the so called `Wide` and
-`deep dense` (this can be either `DeepDense` or `DeepDenseResnet`. See the
-documentation and examples folder for more details) components in the figures
-are **always** present, while `DeepText text` and `DeepImage` are optional.
+Note that each individual component, `wide`, `deepdense` (either `DeepDense`
+or `DeepDenseResnet`), `deeptext` and `deepimage`, can be used independently
+and in isolation. For example, one could use only `wide`, which is simply a
+linear model.
+
+On the other hand, while I recommend using the `Wide` and `DeepDense` (or
+`DeepDenseResnet`) classes in `pytorch-widedeep` to build the `wide` and
+`deepdense` components, it is very likely that users will want to use their
+own models for the `deeptext` and `deepimage` components. That is perfectly
+possible as long as the custom models have an attribute called `output_dim`
+with the size of the last layer of activations, so that `WideDeep` can be
+constructed.
+
 `pytorch-widedeep` includes standard text (stack of LSTMs) and image
-(pre-trained ResNets or stack of CNNs) models. However, the user can use any
-custom model as long as it has an attribute called `output_dim` with the size
-of the last layer of activations, so that `WideDeep` can be constructed. See
-the examples folder or the docs for more information.
+(pre-trained ResNets or stack of CNNs) models.
+
+See the examples folder or the docs for more information.
 
 
 ### Installation
@@ -124,6 +128,28 @@ cd pytorch-widedeep
 pip install -e .
 ```
 
+**Important note for Mac users**: at the time of writing (Dec-2020) the latest
+`torch` release is `1.7`. This release has some
+[issues](https://stackoverflow.com/questions/64772335/pytorch-w-parallelnative-cpp206)
+when running on Mac, and the data-loaders will not run in parallel. In
+addition, since `python 3.8`, [the `multiprocessing` library start method
+changed from `'fork'` to
+`'spawn'`](https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods).
+This also affects the data-loaders (for any `torch` version) and they will not
+run in parallel. Therefore, for Mac users I recommend using `python 3.6` or
+`3.7` and `torch <= 1.6` (with the corresponding, consistent version of
+`torchvision`, e.g. `0.7.0` for `torch 1.6`). I do not want to force this
+versioning in the `setup.py` file since I expect that all these issues will be
+fixed in the future. Therefore, after installing `pytorch-widedeep` via pip or
+directly from github, downgrade `torch` and `torchvision` manually:
+
+```bash
+pip install pytorch-widedeep
+pip install torch==1.6.0 torchvision==0.7.0
+```
+
+None of these issues affect Linux users.
+
 ### Quick start
 
 Binary classification with the [adult
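The `output_dim` contract the new README text describes is easy to picture in code. Below is a minimal sketch, not taken from the package, of a custom text model that could serve as the `deeptext` component; the class name and all sizes are hypothetical, and the only thing the README actually requires is the `output_dim` attribute:

```python
import torch
from torch import nn

class MyDeepText(nn.Module):
    """Hypothetical custom deeptext component.

    The only requirement stated in the README is an `output_dim`
    attribute holding the size of the last layer of activations.
    """

    def __init__(self, vocab_size: int, embed_dim: int = 100, hidden_dim: int = 64):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.rnn = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.output_dim = hidden_dim  # <- the attribute WideDeep looks for

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        embedded = self.embed(X.long())
        _, (h, _) = self.rnn(embedded)
        return h[-1]  # last hidden state, shape (batch_size, output_dim)
```

A model like this would then be passed as the `deeptext` argument when constructing `WideDeep`, alongside the `wide` and `deepdense` components.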

VERSION

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-0.4.6
+0.4.7

docs/figures/widedeep_logo.png

-34.8 KB

docs/figures/widedeep_logo_old.png

72.8 KB

examples/02_Model_Components.ipynb

Lines changed: 5 additions & 3 deletions
@@ -130,7 +130,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"if we simply numerically encode (label encode or `le`) the values, starting from 1 (we will save 0 for padding, i.e. unseen values)"
+"if we simply numerically encode (label encode or `le`) the values:"
 ]
 },
 {
@@ -146,7 +146,9 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"now, let's see if the two implementations are equivalent"
+"Note that in the functioning implementation of the package we start from 1, saving 0 for padding, i.e. unseen values. \n",
+"\n",
+"Now, let's see if the two implementations are equivalent"
 ]
 },
 {
@@ -261,7 +263,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"Note that even though the input dim is 10, the Embedding layer has 11 weights. This is because we save 0 for padding, which is used for unseen values during the encoding process"
+"Note that even though the input dim is 10, the Embedding layer has 11 weights. Again, this is because we save 0 for padding, which is used for unseen values during the encoding process"
 ]
 },
 {
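The padding convention these edited cells describe is easy to demonstrate. A minimal sketch with illustrative values (not the notebook's actual encoder): categories are label-encoded starting from 1, index 0 is reserved for padding/unseen values, and an `Embedding` over 10 categories therefore needs 11 weight rows:

```python
from torch import nn

# label encode starting from 1, reserving 0 for padding / unseen values
values = ["blue", "red", "blue", "green"]
encoding = {v: i + 1 for i, v in enumerate(dict.fromkeys(values))}

def encode(v):
    return encoding.get(v, 0)  # unseen values map to the padding index 0

print([encode(v) for v in values])  # [1, 2, 1, 3]
print(encode("yellow"))             # 0 (never seen during encoding)

# with 10 distinct categories, indices run 1..10 plus row 0 for padding,
# hence the 11 weights mentioned in the notebook
emb = nn.Embedding(num_embeddings=11, embedding_dim=16, padding_idx=0)
print(emb.weight.shape)             # torch.Size([11, 16])
```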

examples/03_Binary_Classification_with_Defaults.ipynb

Lines changed: 86 additions & 20 deletions
@@ -591,16 +591,16 @@
 "name": "stderr",
 "output_type": "stream",
 "text": [
-"epoch 1: 100%|██████████| 611/611 [00:05<00:00, 115.33it/s, loss=0.743, metrics={'acc': 0.6205, 'prec': 0.2817}]\n",
-"valid: 100%|██████████| 153/153 [00:00<00:00, 168.06it/s, loss=0.545, metrics={'acc': 0.6452, 'prec': 0.3014}]\n",
-"epoch 2: 100%|██████████| 611/611 [00:04<00:00, 122.57it/s, loss=0.486, metrics={'acc': 0.7765, 'prec': 0.5517}]\n",
-"valid: 100%|██████████| 153/153 [00:00<00:00, 158.84it/s, loss=0.44, metrics={'acc': 0.783, 'prec': 0.573}] \n",
-"epoch 3: 100%|██████████| 611/611 [00:04<00:00, 124.89it/s, loss=0.419, metrics={'acc': 0.8129, 'prec': 0.6753}]\n",
-"valid: 100%|██████████| 153/153 [00:00<00:00, 158.10it/s, loss=0.402, metrics={'acc': 0.815, 'prec': 0.6816}] \n",
-"epoch 4: 100%|██████████| 611/611 [00:04<00:00, 126.35it/s, loss=0.393, metrics={'acc': 0.8228, 'prec': 0.7047}]\n",
-"valid: 100%|██████████| 153/153 [00:00<00:00, 160.72it/s, loss=0.385, metrics={'acc': 0.8233, 'prec': 0.7024}]\n",
-"epoch 5: 100%|██████████| 611/611 [00:04<00:00, 124.33it/s, loss=0.38, metrics={'acc': 0.826, 'prec': 0.702}] \n",
-"valid: 100%|██████████| 153/153 [00:00<00:00, 163.43it/s, loss=0.376, metrics={'acc': 0.8264, 'prec': 0.7}] \n"
+"epoch 1: 100%|██████████| 611/611 [00:06<00:00, 101.71it/s, loss=0.448, metrics={'acc': 0.792, 'prec': 0.5728}] \n",
+"valid: 100%|██████████| 153/153 [00:00<00:00, 171.00it/s, loss=0.366, metrics={'acc': 0.7991, 'prec': 0.5907}]\n",
+"epoch 2: 100%|██████████| 611/611 [00:06<00:00, 101.69it/s, loss=0.361, metrics={'acc': 0.8324, 'prec': 0.6817}]\n",
+"valid: 100%|██████████| 153/153 [00:00<00:00, 169.36it/s, loss=0.357, metrics={'acc': 0.8328, 'prec': 0.6807}]\n",
+"epoch 3: 100%|██████████| 611/611 [00:05<00:00, 102.65it/s, loss=0.352, metrics={'acc': 0.8366, 'prec': 0.691}] \n",
+"valid: 100%|██████████| 153/153 [00:00<00:00, 171.49it/s, loss=0.352, metrics={'acc': 0.8361, 'prec': 0.6867}]\n",
+"epoch 4: 100%|██████████| 611/611 [00:06<00:00, 101.52it/s, loss=0.347, metrics={'acc': 0.8389, 'prec': 0.6956}]\n",
+"valid: 100%|██████████| 153/153 [00:00<00:00, 163.49it/s, loss=0.349, metrics={'acc': 0.8383, 'prec': 0.6906}]\n",
+"epoch 5: 100%|██████████| 611/611 [00:07<00:00, 84.91it/s, loss=0.343, metrics={'acc': 0.8405, 'prec': 0.6987}] \n",
+"valid: 100%|██████████| 153/153 [00:01<00:00, 142.83it/s, loss=0.347, metrics={'acc': 0.8399, 'prec': 0.6946}]\n"
 ]
 }
 ],
@@ -664,22 +664,88 @@
 "name": "stderr",
 "output_type": "stream",
 "text": [
-"epoch 1: 100%|██████████| 611/611 [00:05<00:00, 108.62it/s, loss=0.894, metrics={'acc': 0.5182, 'prec': 0.2037}]\n",
-"valid: 100%|██████████| 153/153 [00:00<00:00, 154.44it/s, loss=0.604, metrics={'acc': 0.5542, 'prec': 0.2135}]\n",
-"epoch 2: 100%|██████████| 611/611 [00:05<00:00, 106.49it/s, loss=0.51, metrics={'acc': 0.751, 'prec': 0.4614}] \n",
-"valid: 100%|██████████| 153/153 [00:00<00:00, 157.79it/s, loss=0.452, metrics={'acc': 0.7581, 'prec': 0.4898}]\n",
-"epoch 3: 100%|██████████| 611/611 [00:05<00:00, 106.66it/s, loss=0.425, metrics={'acc': 0.8031, 'prec': 0.6618}]\n",
-"valid: 100%|██████████| 153/153 [00:00<00:00, 160.73it/s, loss=0.405, metrics={'acc': 0.806, 'prec': 0.6686}] \n",
-"epoch 4: 100%|██████████| 611/611 [00:05<00:00, 106.58it/s, loss=0.394, metrics={'acc': 0.8185, 'prec': 0.6966}]\n",
-"valid: 100%|██████████| 153/153 [00:00<00:00, 155.55it/s, loss=0.385, metrics={'acc': 0.8196, 'prec': 0.6994}]\n",
-"epoch 5: 100%|██████████| 611/611 [00:05<00:00, 107.28it/s, loss=0.38, metrics={'acc': 0.8236, 'prec': 0.7004}] \n",
-"valid: 100%|██████████| 153/153 [00:00<00:00, 155.37it/s, loss=0.375, metrics={'acc': 0.8244, 'prec': 0.7017}]\n"
+"epoch 1: 100%|██████████| 611/611 [00:07<00:00, 77.46it/s, loss=0.387, metrics={'acc': 0.8192, 'prec': 0.6576}]\n",
+"valid: 100%|██████████| 153/153 [00:01<00:00, 147.78it/s, loss=0.36, metrics={'acc': 0.8216, 'prec': 0.6617}] \n",
+"epoch 2: 100%|██████████| 611/611 [00:08<00:00, 74.99it/s, loss=0.358, metrics={'acc': 0.8313, 'prec': 0.6836}]\n",
+"valid: 100%|██████████| 153/153 [00:00<00:00, 158.26it/s, loss=0.355, metrics={'acc': 0.8321, 'prec': 0.6848}]\n",
+"epoch 3: 100%|██████████| 611/611 [00:08<00:00, 76.28it/s, loss=0.351, metrics={'acc': 0.8345, 'prec': 0.6889}]\n",
+"valid: 100%|██████████| 153/153 [00:00<00:00, 154.84it/s, loss=0.354, metrics={'acc': 0.8347, 'prec': 0.6887}]\n",
+"epoch 4: 100%|██████████| 611/611 [00:07<00:00, 76.71it/s, loss=0.346, metrics={'acc': 0.8374, 'prec': 0.6946}]\n",
+"valid: 100%|██████████| 153/153 [00:00<00:00, 157.80it/s, loss=0.353, metrics={'acc': 0.8369, 'prec': 0.6935}]\n",
+"epoch 5: 100%|██████████| 611/611 [00:08<00:00, 73.25it/s, loss=0.343, metrics={'acc': 0.8386, 'prec': 0.6966}]\n",
+"valid: 100%|██████████| 153/153 [00:00<00:00, 157.05it/s, loss=0.352, metrics={'acc': 0.8382, 'prec': 0.6961}]\n"
 ]
 }
 ],
 "source": [
 "model.fit(X_wide=X_wide, X_deep=X_deep, target=target, n_epochs=5, batch_size=64, val_split=0.2)"
 ]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"It is also worth mentioning that one could build a model with the individual components independently. For example, a model comprised only of the `wide` component would simply be a linear model. This could be attained by just:"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 15,
+"metadata": {},
+"outputs": [],
+"source": [
+"model = WideDeep(wide=wide)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 16,
+"metadata": {},
+"outputs": [],
+"source": [
+"model.compile(method='binary', metrics=[Accuracy, Precision])"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 17,
+"metadata": {},
+"outputs": [
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"\r",
+" 0%|          | 0/611 [00:00<?, ?it/s]"
+]
+},
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"Training\n"
+]
+},
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"epoch 1: 100%|██████████| 611/611 [00:03<00:00, 188.59it/s, loss=0.482, metrics={'acc': 0.771, 'prec': 0.5633}] \n",
+"valid: 100%|██████████| 153/153 [00:00<00:00, 236.13it/s, loss=0.423, metrics={'acc': 0.7747, 'prec': 0.5819}]\n",
+"epoch 2: 100%|██████████| 611/611 [00:03<00:00, 190.62it/s, loss=0.399, metrics={'acc': 0.8131, 'prec': 0.686}] \n",
+"valid: 100%|██████████| 153/153 [00:00<00:00, 221.47it/s, loss=0.387, metrics={'acc': 0.8138, 'prec': 0.6879}]\n",
+"epoch 3: 100%|██████████| 611/611 [00:03<00:00, 190.28it/s, loss=0.378, metrics={'acc': 0.8267, 'prec': 0.7149}]\n",
+"valid: 100%|██████████| 153/153 [00:00<00:00, 241.12it/s, loss=0.374, metrics={'acc': 0.8255, 'prec': 0.7128}]\n",
+"epoch 4: 100%|██████████| 611/611 [00:03<00:00, 183.27it/s, loss=0.37, metrics={'acc': 0.8304, 'prec': 0.7073}] \n",
+"valid: 100%|██████████| 153/153 [00:00<00:00, 227.46it/s, loss=0.369, metrics={'acc': 0.8294, 'prec': 0.7061}]\n",
+"epoch 5: 100%|██████████| 611/611 [00:03<00:00, 184.28it/s, loss=0.366, metrics={'acc': 0.8315, 'prec': 0.7006}]\n",
+"valid: 100%|██████████| 153/153 [00:00<00:00, 239.87it/s, loss=0.366, metrics={'acc': 0.8303, 'prec': 0.6999}]\n"
+]
+}
+],
+"source": [
+"model.fit(X_wide=X_wide, target=target, n_epochs=5, batch_size=64, val_split=0.2)"
+]
 }
 ],
 "metadata": {

pypi_README.md

Lines changed: 23 additions & 5 deletions
@@ -4,11 +4,7 @@
 [![Maintenance](https://img.shields.io/badge/Maintained%3F-yes-green.svg)](https://github.com/jrzaurin/pytorch-widedeep/graphs/commit-activity)
 [![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/jrzaurin/pytorch-widedeep/issues)
 [![codecov](https://codecov.io/gh/jrzaurin/pytorch-widedeep/branch/master/graph/badge.svg)](https://codecov.io/gh/jrzaurin/pytorch-widedeep)
-
-Platform | Version Support
----------|:---------------
-OSX | [![Python 3.6 3.7](https://img.shields.io/badge/python-3.6%20%7C%203.7-blue.svg)](https://www.python.org/)
-Linux | [![Python 3.6 3.7 3.8](https://img.shields.io/badge/python-3.6%20%7C%203.7%20%7C%203.8-blue.svg)](https://www.python.org/)
+[![Python 3.6 3.7 3.8](https://img.shields.io/badge/python-3.6%20%7C%203.7%20%7C%203.8-blue.svg)](https://www.python.org/)
 
 # pytorch-widedeep
 
@@ -57,6 +53,28 @@ cd pytorch-widedeep
 pip install -e .
 ```
 
+**Important note for Mac users**: at the time of writing (Dec-2020) the latest
+`torch` release is `1.7`. This release has some
+[issues](https://stackoverflow.com/questions/64772335/pytorch-w-parallelnative-cpp206)
+when running on Mac, and the data-loaders will not run in parallel. In
+addition, since `python 3.8`, [the `multiprocessing` library start method
+changed from `'fork'` to
+`'spawn'`](https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods).
+This also affects the data-loaders (for any `torch` version) and they will not
+run in parallel. Therefore, for Mac users I recommend using `python 3.6` or
+`3.7` and `torch <= 1.6` (with the corresponding, consistent version of
+`torchvision`, e.g. `0.7.0` for `torch 1.6`). I do not want to force this
+versioning in the `setup.py` file since I expect that all these issues will be
+fixed in the future. Therefore, after installing `pytorch-widedeep` via pip or
+directly from github, downgrade `torch` and `torchvision` manually:
+
+```bash
+pip install pytorch-widedeep
+pip install torch==1.6.0 torchvision==0.7.0
+```
+
+None of these issues affect Linux users.
+
 ### Quick start
 
 Binary classification with the [adult
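For readers who want to verify the `multiprocessing` behaviour the note refers to, the start method can be inspected, and if needed overridden, with the standard library API (plain Python, not a `pytorch-widedeep` feature, and forcing `'fork'` on macOS has its own caveats):

```python
import multiprocessing as mp

if __name__ == "__main__":
    print(mp.get_start_method())  # 'spawn' on macOS since python 3.8
    # optionally revert to the pre-3.8 behaviour; use with care, since
    # 'spawn' became the macOS default because 'fork' can be unsafe there
    mp.set_start_method("fork", force=True)
```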

pytorch_widedeep/models/_wd_dataset.py

Lines changed: 19 additions & 8 deletions
@@ -27,11 +27,11 @@ class WideDeepDataset(Dataset):
 
     def __init__(
         self,
-        X_wide: np.ndarray,
-        X_deep: np.ndarray,
-        target: Optional[np.ndarray] = None,
+        X_wide: Optional[np.ndarray] = None,
+        X_deep: Optional[np.ndarray] = None,
         X_text: Optional[np.ndarray] = None,
         X_img: Optional[np.ndarray] = None,
+        target: Optional[np.ndarray] = None,
         transforms: Optional[Any] = None,
     ):
 
@@ -48,10 +48,12 @@ def __init__(
         self.transforms_names = []
         self.Y = target
 
-    def __getitem__(self, idx: int):
-        # X_wide and X_deep are assumed to be *always* present
-        X = Bunch(wide=self.X_wide[idx])
-        X.deepdense = self.X_deep[idx]
+    def __getitem__(self, idx: int):  # noqa: C901
+        X = Bunch()
+        if self.X_wide is not None:
+            X.wide = self.X_wide[idx]
+        if self.X_deep is not None:
+            X.deepdense = self.X_deep[idx]
         if self.X_text is not None:
             X.deeptext = self.X_text[idx]
         if self.X_img is not None:
@@ -68,6 +70,8 @@ def __getitem__(self, idx: int):
             # then we need to replicate what Tensor() does -> transpose axis
             # and normalize if necessary
             if not self.transforms or "ToTensor" not in self.transforms_names:
+                if xdi.ndim == 2:
+                    xdi = xdi[:, :, None]
                 xdi = xdi.transpose(2, 0, 1)
             if "int" in str(xdi.dtype):
                 xdi = (xdi / xdi.max()).astype("float32")
@@ -87,4 +91,11 @@ def __getitem__(self, idx: int):
         return X
 
     def __len__(self):
-        return len(self.X_deep)
+        if self.X_wide is not None:
+            return len(self.X_wide)
+        if self.X_deep is not None:
+            return len(self.X_deep)
+        if self.X_text is not None:
+            return len(self.X_text)
+        if self.X_img is not None:
+            return len(self.X_img)
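Two behavioural changes land in this file: no single component is mandatory any more (`__getitem__` builds the `Bunch` from whichever of `X_wide`, `X_deep`, `X_text` and `X_img` are present, and `__len__` falls back through them in the same order), and the new `ndim == 2` check fixes the image format issue the PR is named after. A minimal numpy sketch of that second fix, with hypothetical data:

```python
import numpy as np

# a greyscale image stored as (H, W); before this commit the
# transpose(2, 0, 1) below raised, because there was no axis 2
xdi = np.random.randint(0, 256, (224, 224))

if xdi.ndim == 2:
    xdi = xdi[:, :, None]     # (224, 224) -> (224, 224, 1)
xdi = xdi.transpose(2, 0, 1)  # (224, 224, 1) -> (1, 224, 224), as ToTensor would

# same normalization step as in __getitem__ for integer dtypes
if "int" in str(xdi.dtype):
    xdi = (xdi / xdi.max()).astype("float32")

print(xdi.shape, xdi.dtype)   # (1, 224, 224) float32
```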
