
Commit f430864

authored
Merge pull request #33 from jrzaurin/tabtransformer
Tabtransformer
2 parents 2c53901 + 56bd75e commit f430864

File tree

110 files changed

+10830
-6219
lines changed

README.md

Lines changed: 100 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -13,44 +13,49 @@
1313

1414
# pytorch-widedeep
1515

16-
A flexible package to combine tabular data with text and images using wide and
17-
deep models.
16+
A flexible package to use Deep Learning with tabular data, text and images
17+
using wide and deep models.
1818

1919
**Documentation:** [https://pytorch-widedeep.readthedocs.io](https://pytorch-widedeep.readthedocs.io/en/latest/index.html)
2020

2121
**Companion posts:** [infinitoml](https://jrzaurin.github.io/infinitoml/)
2222

2323
### Introduction
2424

25-
`pytorch-widedeep` is based on Google's Wide and Deep Algorithm. Details of
26-
the original algorithm can be found
27-
[here](https://www.tensorflow.org/tutorials/wide_and_deep), and the nice
28-
research paper can be found [here](https://arxiv.org/abs/1606.07792).
25+
`pytorch-widedeep` is based on Google's Wide and Deep Algorithm, [Wide & Deep
26+
Learning for Recommender Systems](https://arxiv.org/abs/1606.07792).
2927

3028
In general terms, `pytorch-widedeep` is a package to use deep learning with
3129
tabular data. In particular, it is intended to facilitate the combination of text
3230
and images with corresponding tabular data using wide and deep models. With
33-
that in mind there are two architectures that can be implemented with just a
34-
few lines of code.
31+
that in mind there are a number of architectures that can be implemented with
32+
just a few lines of code. The main components of those architectures are shown
33+
in the Figure below:
3534

36-
### Architectures
37-
38-
**Architecture 1**:
3935

4036
<p align="center">
41-
<img width="750" src="docs/figures/architecture_1.png">
37+
<img width="750" src="docs/figures/widedeep_arch.png">
4238
</p>
4339

44-
Architecture 1 combines the `Wide`, Linear model with the outputs from the
45-
`DeepDense` or `DeepDenseResnet`, `DeepText` and `DeepImage` components
46-
connected to a final output neuron or neurons, depending on whether we are
47-
performing a binary classification or regression, or a multi-class
48-
classification. The components within the faded-pink rectangles are
49-
concatenated.
40+
The dashed boxes in the figure represent optional, overall components, and the
41+
dashed lines/arrows indicate the corresponding connections, depending on
42+
whether or not certain components are present. For example, the dashed,
43+
blue-lines indicate that the ``deeptabular``, ``deeptext`` and ``deepimage``
44+
components are connected directly to the output neuron or neurons (depending
45+
on whether we are performing a binary classification or regression, or a
46+
multi-class classification) if the optional ``deephead`` is not present.
47+
Finally, the components within the faded-pink rectangle are concatenated.
48+
49+
Note that it is not possible to illustrate the number of possible
50+
architectures and components available in ``pytorch-widedeep`` in one Figure.
51+
Therefore, for more details on possible architectures (and more) please see
52+
the
53+
[documentation](https://pytorch-widedeep.readthedocs.io/en/latest/index.html),
54+
or the Examples folders and the notebooks there.
5055

5156
In math terms, and following the notation in the
52-
[paper](https://arxiv.org/abs/1606.07792), Architecture 1 can be formulated
53-
as:
57+
[paper](https://arxiv.org/abs/1606.07792), the expression for the architecture
58+
without a ``deephead`` component can be formulated as:
5459

5560
<p align="center">
5661
<img width="500" src="docs/figures/architecture_1_math.png">
@@ -67,43 +72,47 @@ the constituent features (“gender=female” and “language=en”) are all 1,
6772
otherwise".*
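The quoted cross-product transformation can be illustrated with a minimal sketch in plain Python. The helper `cross_product` and the two example users are hypothetical and only mirror the quote from the paper; this is not how `pytorch-widedeep` builds crossed columns internally.

```python
def cross_product(example, crossed_cols):
    """Return 1 only if every constituent binary feature is 1, else 0."""
    return int(all(example.get(col, 0) == 1 for col in crossed_cols))


# Binary (one-hot style) features for two hypothetical users
user_a = {"gender=female": 1, "language=en": 1}
user_b = {"gender=female": 1, "language=en": 0}

# The crossed feature AND(gender=female, language=en)
cross = ("gender=female", "language=en")

result_a = cross_product(user_a, cross)  # both constituents are 1
result_b = cross_product(user_b, cross)  # one constituent is 0
print(result_a, result_b)
```

The wide component is a linear model over raw and crossed features like these, which is how it can memorize specific feature co-occurrences.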
6873

6974

70-
**Architecture 2**
71-
72-
<p align="center">
73-
<img width="750" src="docs/figures/architecture_2.png">
74-
</p>
75-
76-
Architecture 2 combines the `Wide`, Linear model with the Deep components of
77-
the model connected to the output neuron(s), after the different Deep
78-
components have been themselves combined through a FC-Head (that I refer as
79-
`deephead`).
80-
81-
In math terms, and following the notation in the
82-
[paper](https://arxiv.org/abs/1606.07792), Architecture 2 can be formulated
83-
as:
75+
If there is a ``deephead`` component, the previous expression instead turns
76+
into:
8477

8578
<p align="center">
8679
<img width="300" src="docs/figures/architecture_2_math.png">
8780
</p>
8881

89-
Note that each individual component, `wide`, `deepdense` (either `DeepDense`
90-
or `DeepDenseResnet`), `deeptext` and `deepimage`, can be used independently
91-
and in isolation. For example, one could use only `wide`, which is in simply a
92-
linear model.
93-
94-
On the other hand, while I recommend using the `Wide` and `DeepDense` (or
95-
`DeepDenseResnet`) classes in `pytorch-widedeep` to build the `wide` and
96-
`deepdense` component, it is very likely that users will want to use their own
97-
models in the case of the `deeptext` and `deepimage` components. That is
98-
perfectly possible as long as the the custom models have an attribute called
99-
`output_dim` with the size of the last layer of activations, so that
100-
`WideDeep` can be constructed
101-
102-
`pytorch-widedeep` includes standard text (stack of LSTMs) and image
82+
It is important to emphasize that **each individual component, `wide`,
83+
`deeptabular`, `deeptext` and `deepimage`, can be used independently** and in
84+
isolation. For example, one could use only `wide`, which is simply a linear
85+
model. In fact, one of the most interesting functionalities
86+
in ``pytorch-widedeep`` is the ``deeptabular`` component. Currently,
87+
``pytorch-widedeep`` offers 3 models for that component:
88+
89+
1. ``TabMlp``: this is almost identical to the [tabular
90+
model](https://docs.fast.ai/tutorial.tabular.html) in the fantastic
91+
[fastai](https://docs.fast.ai/) library, and consists simply of embeddings
92+
representing the categorical features, concatenated with the continuous
93+
features, and then passed through an MLP.
94+
95+
2. ``TabResnet``: This is similar to the previous model but the embeddings are
96+
passed through a series of ResNet blocks built with dense layers.
97+
98+
3. ``TabTransformer``: Details on the TabTransformer can be found in:
99+
[TabTransformer: Tabular Data Modeling Using Contextual
100+
Embeddings](https://arxiv.org/pdf/2012.06678.pdf)
101+
102+
103+
For details on these 3 models and their options please see the examples in the
104+
Examples folder and the documentation.
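
As a rough illustration of the idea behind ``TabMlp`` described above (embeddings for the categorical features, concatenated with the continuous features, then passed through an MLP), here is a minimal sketch in plain PyTorch. The class `TinyTabMlp` and all its arguments are hypothetical and are not the library's implementation or API.

```python
import torch
from torch import nn


class TinyTabMlp(nn.Module):
    """Illustrative only: categorical embeddings + continuous features -> MLP."""

    def __init__(self, cardinalities, embed_dim, n_continuous, hidden_dims):
        super().__init__()
        # One embedding table per categorical column
        self.embeddings = nn.ModuleList(
            [nn.Embedding(n_categories, embed_dim) for n_categories in cardinalities]
        )
        input_dim = embed_dim * len(cardinalities) + n_continuous
        layers = []
        for hidden_dim in hidden_dims:
            layers += [nn.Linear(input_dim, hidden_dim), nn.ReLU()]
            input_dim = hidden_dim
        self.mlp = nn.Sequential(*layers)
        # Size of the last layer of activations
        self.output_dim = hidden_dims[-1]

    def forward(self, x_cat, x_cont):
        # Look up each categorical column, then concatenate with continuous columns
        embedded = [emb(x_cat[:, i]) for i, emb in enumerate(self.embeddings)]
        x = torch.cat(embedded + [x_cont], dim=1)
        return self.mlp(x)


model = TinyTabMlp(cardinalities=[5, 3], embed_dim=4, n_continuous=2, hidden_dims=[64, 32])
x_cat = torch.randint(0, 3, (8, 2))  # 8 rows, 2 label-encoded categorical columns
x_cont = torch.randn(8, 2)           # 8 rows, 2 continuous columns
out = model(x_cat, x_cont)
```

Under the same assumptions, a ResNet-style variant would replace the plain linear layers with residual blocks built from dense layers.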
105+
106+
Finally, while I recommend using the ``wide`` and ``deeptabular`` models in
107+
``pytorch-widedeep`` it is very likely that users will want to use their own
108+
models for the ``deeptext`` and ``deepimage`` components. That is perfectly
109+
possible as long as the custom models have an attribute called
110+
``output_dim`` with the size of the last layer of activations, so that
111+
``WideDeep`` can be constructed. Again, examples on how to use custom
112+
components can be found in the Examples folder. Note also that
113+
``pytorch-widedeep`` includes standard text (stack of LSTMs) and image
103114
(pre-trained ResNets or stack of CNNs) models.
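
The ``output_dim`` requirement for custom components can be sketched as follows. `MyTextModel`, its GRU architecture and its arguments are hypothetical; the only thing taken from the text above is that the module exposes an ``output_dim`` attribute equal to the size of its last layer of activations.

```python
import torch
from torch import nn


class MyTextModel(nn.Module):
    """Hypothetical custom text component exposing ``output_dim``."""

    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.rnn = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        # Size of the last layer of activations, as required by the README
        self.output_dim = hidden_dim

    def forward(self, x):
        # h_n has shape (num_layers, batch, hidden_dim); keep the last layer
        _, h_n = self.rnn(self.embedding(x))
        return h_n[-1]


text_model = MyTextModel(vocab_size=100, embed_dim=16, hidden_dim=32)
tokens = torch.randint(1, 100, (4, 10))  # 4 padded sequences of 10 token ids
features = text_model(tokens)
```

A model like this would then be passed as the ``deeptext`` component when building ``WideDeep``; see the Examples folder for the supported way to do so.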
104115

105-
See the examples folder or the docs for more information.
106-
107116

108117
### Installation
109118

@@ -130,8 +139,8 @@ cd pytorch-widedeep
130139
pip install -e .
131140
```
132141

133-
**Important note for Mac users**: at the time of writing (Dec-2020) the latest
134-
`torch` release is `1.7`. This release has some
142+
**Important note for Mac users**: at the time of writing (Feb-2021) the latest
143+
`torch` release is `1.7.1`. This release has some
135144
[issues](https://stackoverflow.com/questions/64772335/pytorch-w-parallelnative-cpp206)
136145
when running on Mac and the data-loaders will not run in parallel. In
137146
addition, since `python 3.8`, [the `multiprocessing` library start method
@@ -158,17 +167,26 @@ Binary classification with the [adult
158167
dataset](https://www.kaggle.com/wenruliu/adult-income-dataset)
159168
using `Wide` and `DeepDense` and default settings.
160169

170+
171+
```python
172+
```
173+
174+
Building a wide (linear) and deep model with ``pytorch-widedeep``:
175+
161176
```python
177+
162178
import pandas as pd
163179
import numpy as np
164180
from sklearn.model_selection import train_test_split
165181

166-
from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor
167-
from pytorch_widedeep.models import Wide, DeepDense, WideDeep
182+
from pytorch_widedeep import Trainer
183+
from pytorch_widedeep.preprocessing import WidePreprocessor, TabPreprocessor
184+
from pytorch_widedeep.models import Wide, TabMlp, WideDeep
168185
from pytorch_widedeep.metrics import Accuracy
169186

170-
# these next 4 lines are not directly related to pytorch-widedeep. I assume
171-
# you have downloaded the dataset and place it in a dir called data/adult/
187+
# the following 4 lines are not directly related to ``pytorch-widedeep``. I
188+
# assume you have downloaded the dataset and placed it in a dir called
189+
# data/adult/
172190
df = pd.read_csv("data/adult/adult.csv.zip")
173191
df["income_label"] = (df["income"].apply(lambda x: ">50K" in x)).astype(int)
174192
df.drop("income", axis=1, inplace=True)
@@ -197,61 +215,46 @@ target_col = "income_label"
197215
target = df_train[target_col].values
198216

199217
# wide
200-
preprocess_wide = WidePreprocessor(wide_cols=wide_cols, crossed_cols=cross_cols)
201-
X_wide = preprocess_wide.fit_transform(df_train)
218+
wide_preprocessor = WidePreprocessor(wide_cols=wide_cols, crossed_cols=cross_cols)
219+
X_wide = wide_preprocessor.fit_transform(df_train)
202220
wide = Wide(wide_dim=np.unique(X_wide).shape[0], pred_dim=1)
203221

204-
# deepdense
205-
preprocess_deep = DensePreprocessor(embed_cols=embed_cols, continuous_cols=cont_cols)
206-
X_deep = preprocess_deep.fit_transform(df_train)
207-
deepdense = DeepDense(
208-
hidden_layers=[64, 32],
209-
deep_column_idx=preprocess_deep.deep_column_idx,
210-
embed_input=preprocess_deep.embeddings_input,
222+
# deeptabular
223+
tab_preprocessor = TabPreprocessor(embed_cols=embed_cols, continuous_cols=cont_cols)
224+
X_tab = tab_preprocessor.fit_transform(df_train)
225+
deeptabular = TabMlp(
226+
mlp_hidden_dims=[64, 32],
227+
column_idx=tab_preprocessor.column_idx,
228+
embed_input=tab_preprocessor.embeddings_input,
211229
continuous_cols=cont_cols,
212230
)
213-
# # To use DeepDenseResnet as the deepdense component simply:
214-
# from pytorch_widedeep.models import DeepDenseResnet:
215-
# deepdense = DeepDenseResnet(
216-
# blocks=[64, 32],
217-
# deep_column_idx=preprocess_deep.deep_column_idx,
218-
# embed_input=preprocess_deep.embeddings_input,
219-
# continuous_cols=cont_cols,
220-
# )
221-
222-
# build, compile and fit
223-
model = WideDeep(wide=wide, deepdense=deepdense)
224-
model.compile(method="binary", metrics=[Accuracy])
225-
model.fit(
231+
232+
# wide and deep
233+
model = WideDeep(wide=wide, deeptabular=deeptabular)
234+
235+
# train the model
236+
trainer = Trainer(model, objective="binary", metrics=[Accuracy])
237+
trainer.fit(
226238
X_wide=X_wide,
227-
X_deep=X_deep,
239+
X_tab=X_tab,
228240
target=target,
229241
n_epochs=5,
230242
batch_size=256,
231243
val_split=0.1,
232244
)
233245

234246
# predict
235-
X_wide_te = preprocess_wide.transform(df_test)
236-
X_deep_te = preprocess_deep.transform(df_test)
237-
preds = model.predict(X_wide=X_wide_te, X_deep=X_deep_te)
238-
239-
#  # save and load
240-
# torch.save(model, "model_weights/model.t")
241-
# model = torch.load("model_weights/model.t")
242-
243-
# # or via state dictionaries
244-
# torch.save(model.state_dict(), PATH)
245-
# model = WideDeep(*args)
246-
# model.load_state_dict(torch.load(PATH))
247+
X_wide_te = wide_preprocessor.transform(df_test)
248+
X_tab_te = tab_preprocessor.transform(df_test)
249+
preds = trainer.predict(X_wide=X_wide_te, X_tab=X_tab_te)
250+
251+
# save and load
252+
trainer.save_model("model_weights/model.t")
247253
```
248254

249-
Of course, one can do much more, such as using different initializations,
250-
optimizers or learning rate schedulers for each component of the overall
251-
model. Adding FC-Heads to the Text and Image components. Using the [Focal
252-
Loss](https://arxiv.org/abs/1708.02002), warming up individual components
253-
before joined training, etc. See the `examples` or the `docs` folders for a
254-
better understanding of the content of the package and its functionalities.
255+
Of course, one can do **much more**. See the Examples folder, the
256+
documentation or the companion posts for a better understanding of the content
257+
of the package and its functionalities.
255258

256259
### Testing
257260

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0.4.7
1+
0.4.8

docs/_static/custom.css

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,7 @@ div.ethical-rtd {
3939
.wy-nav-content {
4040
max-width: none !important;
4141
}
42+
43+
div.container a.header-logo {
44+
background-image: url("../figures/widedeep_logo.png");
45+
}
99.7 KB
Binary file not shown.

docs/callbacks.rst

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,20 @@
11
Callbacks
22
=========
33

4+
Here are the 4 callbacks available in ``pytorch-widedeep``: ``History``,
5+
``LRHistory``, ``ModelCheckpoint`` and ``EarlyStopping``.
6+
7+
.. note:: ``History`` runs by default, so it should not be passed
8+
to the ``Trainer``
9+
410
.. autoclass:: pytorch_widedeep.callbacks.History
511
:members:
6-
:undoc-members:
7-
:show-inheritance:
812

913
.. autoclass:: pytorch_widedeep.callbacks.LRHistory
1014
:members:
11-
:undoc-members:
12-
:show-inheritance:
1315

1416
.. autoclass:: pytorch_widedeep.callbacks.ModelCheckpoint
1517
:members:
16-
:undoc-members:
17-
:show-inheritance:
1818

1919
.. autoclass:: pytorch_widedeep.callbacks.EarlyStopping
2020
:members:
21-
:undoc-members:
22-
:show-inheritance:
