
Commit 108cebc

updated README and docs
1 parent 23b9331 commit 108cebc


4 files changed: +159 -177 lines changed


README.md

Lines changed: 59 additions & 61 deletions
@@ -15,14 +15,14 @@

# pytorch-widedeep

-A flexible package to use Deep Learning with tabular data, text and images
-using wide and deep models.
+A flexible package for multimodal deep learning to combine tabular data with
+text and images using Wide and Deep models in PyTorch.

**Documentation:** [https://pytorch-widedeep.readthedocs.io](https://pytorch-widedeep.readthedocs.io/en/latest/index.html)

**Companion posts and tutorials:** [infinitoml](https://jrzaurin.github.io/infinitoml/)

-**Experiments and comparisson with `LightGBM`**: [TabularDL vs LightGBM](https://github.com/jrzaurin/tabulardl-benchmark)
+**Experiments and comparison with `LightGBM`**: [TabularDL vs LightGBM](https://github.com/jrzaurin/tabulardl-benchmark)

The content of this document is organized as follows:

@@ -33,7 +33,8 @@ The content of this document is organized as follows:

### Introduction

-``pytorch-widedeep`` is based on Google's [Wide and Deep Algorithm](https://arxiv.org/abs/1606.07792)
+``pytorch-widedeep`` is based on Google's [Wide and Deep Algorithm](https://arxiv.org/abs/1606.07792),
+adjusted for multi-modal datasets.

In general terms, `pytorch-widedeep` is a package to use deep learning with
tabular data. In particular, it is intended to facilitate the combination of text
@@ -89,15 +90,11 @@ into:
<img width="300" src="docs/figures/architecture_2_math.png">
</p>

-I recommend using the ``wide`` and ``deeptabular`` models in
-``pytorch-widedeep``. However it is very likely that users will want to use
-their own models for the ``deeptext`` and ``deepimage`` components. That is
-perfectly possible as long as the the custom models have an attribute called
+It is perfectly possible to use custom models (and not necessarily those in
+the library) as long as the custom models have an attribute called
``output_dim`` with the size of the last layer of activations, so that
-``WideDeep`` can be constructed. Again, examples on how to use custom
-components can be found in the Examples folder. Just in case
-``pytorch-widedeep`` includes standard text (stack of LSTMs) and image
-(pre-trained ResNets or stack of CNNs) models.
+``WideDeep`` can be constructed. Examples on how to use custom components can
+be found in the Examples folder.
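For reference, a minimal sketch of such a custom component (the class and everything inside it is illustrative, not part of the library; the only requirement described above is the ``output_dim`` attribute):

```python
import torch
from torch import nn


class MyDeepText(nn.Module):
    # Hypothetical custom ``deeptext`` component: any nn.Module works, as long
    # as it exposes ``output_dim`` with the size of its last layer of activations
    def __init__(self, vocab_size: int, embed_dim: int = 64, hidden_dim: int = 32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.rnn = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.output_dim = hidden_dim  # read by WideDeep when the model is built

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        _, (h, _) = self.rnn(self.embed(X))
        return h[-1]


# it could then be passed as, e.g.,
# model = WideDeep(wide=wide, deeptabular=tab_mlp, deeptext=MyDeepText(vocab_size=2000))
```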

### The ``deeptabular`` component

@@ -110,15 +107,17 @@ its own, i.e. what one might normally refer as Deep Learning for Tabular
Data. Currently, ``pytorch-widedeep`` offers the following different models
for that component:

-
+0. **Wide**: a simple linear model where the nonlinearities are captured via
+cross-product transformations, as explained before.
1. **TabMlp**: a simple MLP that receives embeddings representing the
-categorical features, concatenated with the continuous features.
+categorical features, concatenated with the continuous features, which can
+also be embedded.
2. **TabResnet**: similar to the previous model but the embeddings are
passed through a series of ResNet blocks built with dense layers.
3. **TabNet**: details on TabNet can be found in
[TabNet: Attentive Interpretable Tabular Learning](https://arxiv.org/abs/1908.07442)

-And the ``Tabformer`` family, i.e. Transformers for Tabular data:
+The ``Tabformer`` family, i.e. Transformers for Tabular data:

4. **TabTransformer**: details on the TabTransformer can be found in
[TabTransformer: Tabular Data Modeling Using Contextual Embeddings](https://arxiv.org/pdf/2012.06678.pdf).
@@ -133,12 +132,19 @@ on the Fastformer can be found in
the Perceiver can be found in
[Perceiver: General Perception with Iterative Attention](https://arxiv.org/abs/2103.03206)

+And probabilistic DL models for tabular data based on
+[Weight Uncertainty in Neural Networks](https://arxiv.org/abs/1505.05424):
+
+9. **BayesianWide**: Probabilistic adaptation of the `Wide` model.
+10. **BayesianTabMlp**: Probabilistic adaptation of the `TabMlp` model.
+
Note that while there are scientific publications for the TabTransformer,
SAINT and FT-Transformer, the TabFastFormer and TabPerceiver are our own
adaptations of those algorithms for tabular data.

-For details on these models and their options please see the examples in the
-Examples folder and the documentation.
+For details on these models (and all the other models in the library for the
+different data modes) and their corresponding options please see the examples
+in the Examples folder and the documentation.
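Below is an illustrative sketch of swapping in one of the models above as the ``deeptabular`` component. The constructor arguments simply mirror the ``TabMlp`` call in the Quick start further down; treating them as valid for ``TabResnet`` is an assumption, and its model-specific options are left at their defaults:

```python
from pytorch_widedeep.preprocessing import TabPreprocessor
from pytorch_widedeep.models import TabResnet, WideDeep

# df_train, cat_embed_cols and continuous_cols as defined in the Quick start below
tab_preprocessor = TabPreprocessor(cat_embed_cols=cat_embed_cols, continuous_cols=continuous_cols)
X_tab = tab_preprocessor.fit_transform(df_train)

# assumed to accept the same column arguments as the TabMlp call in the Quick start
tab_resnet = TabResnet(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=continuous_cols,
)

# WideDeep can be built around the deeptabular component alone
model = WideDeep(deeptabular=tab_resnet)
```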

### Installation

@@ -165,13 +171,6 @@ cd pytorch-widedeep
pip install -e .
```

-**Important note for Mac users**: Since `python
-3.8`, [the `multiprocessing` library start method changed from `'fork'` to`'spawn'`](https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods) which affects the data-loaders.
-For the time being, `pytorch-widedeep` sets the `num_workers` to 0 when using
-Mac and python version 3.8+.
-
-Note that this issue does not affect Linux users.
-
### Quick start

Binary classification with the [adult
@@ -181,7 +180,6 @@ using `Wide` and `DeepDense` and defaults settings.
Building a wide (linear) and deep model with ``pytorch-widedeep``:

```python
-
import pandas as pd
import numpy as np
import torch
@@ -191,16 +189,15 @@ from pytorch_widedeep import Trainer
from pytorch_widedeep.preprocessing import WidePreprocessor, TabPreprocessor
from pytorch_widedeep.models import Wide, TabMlp, WideDeep
from pytorch_widedeep.metrics import Accuracy
+from pytorch_widedeep.datasets import load_adult

-# the following 4 lines are not directly related to ``pytorch-widedeep``. I
-# assume you have downloaded the dataset and place it in a dir called
-# data/adult/
-df = pd.read_csv("data/adult/adult.csv.zip")
+df = load_adult(as_frame=True)
df["income_label"] = (df["income"].apply(lambda x: ">50K" in x)).astype(int)
df.drop("income", axis=1, inplace=True)
df_train, df_test = train_test_split(df, test_size=0.2, stratify=df.income_label)

-# prepare wide, crossed, embedding and continuous columns
+# Define the column setup
wide_cols = [
    "education",
    "relationship",
@@ -209,49 +206,53 @@ wide_cols = [
    "native-country",
    "gender",
]
-cross_cols = [("education", "occupation"), ("native-country", "occupation")]
-embed_cols = [
-    ("education", 16),
-    ("workclass", 16),
-    ("occupation", 16),
-    ("native-country", 32),
-]
-cont_cols = ["age", "hours-per-week"]
-target_col = "income_label"
+crossed_cols = [("education", "occupation"), ("native-country", "occupation")]

-# target
-target = df_train[target_col].values
+cat_embed_cols = [
+    "workclass",
+    "education",
+    "marital-status",
+    "occupation",
+    "relationship",
+    "race",
+    "gender",
+    "capital-gain",
+    "capital-loss",
+    "native-country",
+]
+continuous_cols = ["age", "hours-per-week"]
+target = "income_label"
+target = df_train[target].values

-# wide
-wide_preprocessor = WidePreprocessor(wide_cols=wide_cols, crossed_cols=cross_cols)
+# prepare the data
+wide_preprocessor = WidePreprocessor(wide_cols=wide_cols, crossed_cols=crossed_cols)
X_wide = wide_preprocessor.fit_transform(df_train)
-wide = Wide(wide_dim=np.unique(X_wide).shape[0], pred_dim=1)

-# deeptabular
-tab_preprocessor = TabPreprocessor(cat_embed_cols=embed_cols, continuous_cols=cont_cols)
+tab_preprocessor = TabPreprocessor(
+    cat_embed_cols=cat_embed_cols, continuous_cols=continuous_cols  # type: ignore[arg-type]
+)
X_tab = tab_preprocessor.fit_transform(df_train)
-deeptabular = TabMlp(
-    mlp_hidden_dims=[64, 32],
+
+# build the model
+wide = Wide(input_dim=np.unique(X_wide).shape[0], pred_dim=1)
+tab_mlp = TabMlp(
    column_idx=tab_preprocessor.column_idx,
-    embed_input=tab_preprocessor.cat_embed_input,
-    continuous_cols=cont_cols,
+    cat_embed_input=tab_preprocessor.cat_embed_input,
+    continuous_cols=continuous_cols,
)
+model = WideDeep(wide=wide, deeptabular=tab_mlp)

-# wide and deep
-model = WideDeep(wide=wide, deeptabular=deeptabular)
-
-# train the model
+# train and validate
trainer = Trainer(model, objective="binary", metrics=[Accuracy])
trainer.fit(
    X_wide=X_wide,
    X_tab=X_tab,
    target=target,
    n_epochs=5,
    batch_size=256,
-    val_split=0.1,
)

-# predict
+# predict on test
X_wide_te = wide_preprocessor.transform(df_test)
X_tab_te = tab_preprocessor.transform(df_test)
preds = trainer.predict(X_wide=X_wide_te, X_tab=X_tab_te)
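As a small follow-up sketch (not part of this commit): scoring the test predictions from the snippet above, assuming ``preds`` holds the predicted class labels for the binary objective:

```python
from sklearn.metrics import accuracy_score

# df_test and preds come from the quick-start snippet above
print(accuracy_score(df_test["income_label"], preds))
```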
@@ -268,14 +269,11 @@ torch.save(model.state_dict(), "model_weights/wd_model.pt")
# From here onwards, Option 1 or 2 are the same. I assume the user has
# prepared the data and defined the new model components:
# 1. Build the model
-model_new = WideDeep(wide=wide, deeptabular=deeptabular)
+model_new = WideDeep(wide=wide, deeptabular=tab_mlp)
model_new.load_state_dict(torch.load("model_weights/wd_model.pt"))

# 2. Instantiate the trainer
-trainer_new = Trainer(
-    model_new,
-    objective="binary",
-)
+trainer_new = Trainer(model_new, objective="binary")

# 3. Either start the fit or directly predict
preds = trainer_new.predict(X_wide=X_wide, X_tab=X_tab)

docs/index.rst

Lines changed: 23 additions & 17 deletions
@@ -31,7 +31,8 @@ Documentation
Introduction
------------
``pytorch-widedeep`` is based on Google's `Wide and Deep Algorithm
-<https://arxiv.org/abs/1606.07792>`_.
+<https://arxiv.org/abs/1606.07792>`_, adjusted for multi-modal datasets.
+

In general terms, ``pytorch-widedeep`` is a package to use deep learning with
tabular and multimodal data. In particular, it is intended to facilitate the
@@ -97,17 +98,20 @@ own, i.e. what one might normally refer as Deep Learning for Tabular Data.
Currently, ``pytorch-widedeep`` offers the following different models for
that component:

+0. **Wide**: a simple linear model where the nonlinearities are captured via
+cross-product transformations, as explained before.

1. **TabMlp**: a simple MLP that receives embeddings representing the
-categorical features, concatenated with the continuous features.
+categorical features, concatenated with the continuous features, which can
+also be embedded.

2. **TabResnet**: similar to the previous model but the embeddings are
passed through a series of ResNet blocks built with dense layers.

3. **TabNet**: details on TabNet can be found in `TabNet: Attentive
Interpretable Tabular Learning <https://arxiv.org/abs/1908.07442>`_

-And the ``Tabformer`` family, i.e. Transformers for Tabular data:
+The ``Tabformer`` family, i.e. Transformers for Tabular data:

4. **TabTransformer**: details on the TabTransformer can be found in
`TabTransformer: Tabular Data Modeling Using Contextual Embeddings
@@ -130,22 +134,24 @@ Models for Natural Language Understanding
the Perceiver can be found in `Perceiver: General Perception with Iterative
Attention <https://arxiv.org/abs/2103.03206>`_

+And probabilistic DL models for tabular data based on
+`Weight Uncertainty in Neural Networks <https://arxiv.org/abs/1505.05424>`_:
+
+9. **BayesianWide**: Probabilistic adaptation of the `Wide` model.
+
+10. **BayesianTabMlp**: Probabilistic adaptation of the `TabMlp` model.
+
Note that while there are scientific publications for the TabTransformer,
SAINT and FT-Transformer, the TabFastFormer and TabPerceiver are our own
-adaptation of those algorithms for tabular data.
-
-For details on these models and their options please see the examples in the
-Examples folder and the documentation.
-
-Finally, while I recommend using the ``wide`` and ``deeptabular`` models in
-``pytorch-widedeep`` it is very likely that users will want to use their own
-models for the ``deeptext`` and ``deepimage`` components. That is perfectly
-possible as long as the the custom models have an attribute called
-``output_dim`` with the size of the last layer of activations, so that
-``WideDeep`` can be constructed. Again, examples on how to use custom
-components can be found in the Examples folder. Just in case
-``pytorch-widedeep`` includes standard text (stack of LSTMs or GRUs) and
-image(pre-trained ResNets or stack of CNNs) models.
+adaptations of those algorithms for tabular data. For details on these models
+and their options please see the examples in the Examples folder and the
+documentation.
+
+Finally, it is perfectly possible to use custom models as long as the
+custom models have an attribute called ``output_dim`` with the size of the
+last layer of activations, so that ``WideDeep`` can be constructed. Again,
+examples on how to use custom components can be found in the Examples
+folder.

Indices and tables
==================
