
Commit 0c79deb

Merge pull request #47 from jrzaurin/saint

Saint

2 parents 8e110a9 + 8cd4944


50 files changed (+4549 -1671 lines)

README.md

Lines changed: 38 additions & 21 deletions
@@ -22,6 +22,8 @@ using wide and deep models.
 
 **Experiments and comparison with `LightGBM`**: [TabularDL vs LightGBM](https://github.com/jrzaurin/tabulardl-benchmark)
 
+**Slack**: if you want to contribute or just want to chat with us, join [slack](https://join.slack.com/t/pytorch-widedeep/shared_invite/zt-soss7stf-iXpVuLeKZz8lGTnxxtHtTw)
+
 ### Introduction
 
 ``pytorch-widedeep`` is based on Google's [Wide and Deep Algorithm](https://arxiv.org/abs/1606.07792)
@@ -82,10 +84,11 @@ into:
 
 It is important to emphasize that **each individual component, `wide`,
 `deeptabular`, `deeptext` and `deepimage`, can be used independently** and in
-isolation. For example, one could use only `wide`, which is in simply a linear
-model. In fact, one of the most interesting functionalities
+isolation. For example, one could use only `wide`, which is simply a
+linear model. In fact, one of the most interesting functionalities
 in ``pytorch-widedeep`` is the ``deeptabular`` component. Currently,
-``pytorch-widedeep`` offers 4 models for that component:
+``pytorch-widedeep`` offers the following different models for that
+component:
 
 1. ``TabMlp``: this is almost identical to the [tabular
 model](https://docs.fast.ai/tutorial.tabular.html) in the fantastic
@@ -100,11 +103,26 @@ passed through a series of ResNet blocks built with dense layers.
 [TabNet: Attentive Interpretable Tabular Learning](https://arxiv.org/abs/1908.07442)
 
 4. ``TabTransformer``: Details on the TabTransformer can be found in:
-[TabTransformer: Tabular Data Modeling Using Contextual
-Embeddings](https://arxiv.org/pdf/2012.06678.pdf)
-
-
-For details on these 4 models and their options please see the examples in the
+[TabTransformer: Tabular Data Modeling Using Contextual Embeddings](https://arxiv.org/pdf/2012.06678.pdf).
+Note that the TabTransformer implementation available in ``pytorch-widedeep``
+is an adaptation of the original implementation.
+
+5. ``FT-Transformer``: or Feature Tokenizer transformer. This is a relatively small
+variation of the ``TabTransformer``. The variation itself was first
+introduced in the ``SAINT`` paper, but the name "``FT-Transformer``" was first
+used in
+[Revisiting Deep Learning Models for Tabular Data](https://arxiv.org/abs/2106.11959).
+When using the ``FT-Transformer`` each continuous feature is "embedded"
+(i.e. passed through a 1-layer MLP with or without activation function) and
+then passed through the attention blocks along with the categorical features.
+This is available in ``pytorch-widedeep``'s ``TabTransformer`` by setting the
+parameter ``embed_continuous = True``.
+
+
+6. ``SAINT``: Details on SAINT can be found in:
+[SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training](https://arxiv.org/abs/2106.01342).
+
+For details on these models and their options please see the examples in the
 Examples folder and the documentation.
 
 Finally, while I recommend using the ``wide`` and ``deeptabular`` models in
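
A minimal sketch of the ``embed_continuous`` option this PR describes, assuming the v1.0.x API: the toy dataframe and its column names are invented for illustration, and some argument names (``for_transformer``, ``embeddings_input``) may differ slightly between releases.

```python
import pandas as pd

from pytorch_widedeep import Trainer
from pytorch_widedeep.models import TabTransformer, WideDeep
from pytorch_widedeep.preprocessing import TabPreprocessor

# Toy dataframe; columns and values are made up for illustration
df = pd.DataFrame(
    {
        "city": ["london", "paris", "london", "madrid"],
        "job": ["it", "law", "law", "it"],
        "age": [25, 41, 33, 52],
        "income": [40.0, 70.0, 55.0, 62.0],
        "target": [0, 1, 0, 1],
    }
)

# Label-encode the categorical columns for an attention-based model
# (the exact flag name, e.g. for_transformer, is an assumption here)
tab_preprocessor = TabPreprocessor(
    embed_cols=["city", "job"],
    continuous_cols=["age", "income"],
    for_transformer=True,
)
X_tab = tab_preprocessor.fit_transform(df)

# embed_continuous=True gives the FT-Transformer-style variant described
# above: continuous features are also "embedded" and sent through the
# attention blocks together with the categorical features
tab_transformer = TabTransformer(
    column_idx=tab_preprocessor.column_idx,
    embed_input=tab_preprocessor.embeddings_input,
    continuous_cols=["age", "income"],
    embed_continuous=True,
)
model = WideDeep(deeptabular=tab_transformer)

trainer = Trainer(model, objective="binary")
trainer.fit(X_tab=X_tab, target=df["target"].values, n_epochs=1, batch_size=2)
```

With ``embed_continuous=False`` (the default) the same model behaves as the original TabTransformer, where continuous features bypass the attention blocks.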
@@ -143,20 +161,19 @@ cd pytorch-widedeep
 pip install -e .
 ```
 
-**Important note for Mac users**: at the time of writing (June-2021) the
-latest `torch` release is `1.9`. Some past
-[issues](https://stackoverflow.com/questions/64772335/pytorch-w-parallelnative-cpp206)
-when running on Mac, present in previous versions, persist on this release and
-the data-loaders will not run in parallel. In addition, since `python 3.8`,
-[the `multiprocessing` library start method changed from `'fork'` to
-`'spawn'`](https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods).
+**Important note for Mac users**: at the time of writing the latest `torch`
+release is `1.9`. Some past [issues](https://stackoverflow.com/questions/64772335/pytorch-w-parallelnative-cpp206)
+when running on Mac, present in previous versions, persist on this release
+and the data-loaders will not run in parallel. In addition, since `python
+3.8`, [the `multiprocessing` library start method changed from `'fork'` to `'spawn'`](https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods).
 This also affects the data-loaders (for any `torch` version) and they will
-not run in parallel. Therefore, for Mac users I recommend using `python 3.6`
-or `3.7` and `torch <= 1.6` (with the corresponding, consistent version of
-`torchvision`, e.g. `0.7.0` for `torch 1.6`). I do not want to force this
-versioning in the `setup.py` file since I expect that all these issues are
-fixed in the future. Therefore, after installing `pytorch-widedeep` via pip
-or directly from github, downgrade `torch` and `torchvision` manually:
+not run in parallel. Therefore, for Mac users I recommend using `python
+3.6` or `3.7` and `torch <= 1.6` (with the corresponding, consistent
+version of `torchvision`, e.g. `0.7.0` for `torch 1.6`). I do not want to
+force this versioning in the `setup.py` file since I expect that all these
+issues will be fixed in the future. Therefore, after installing
+`pytorch-widedeep` via pip or directly from GitHub, downgrade `torch` and
+`torchvision` manually:
 
 ```bash
 pip install pytorch-widedeep
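
Going by the versions the note names (`torch 1.6` paired with `torchvision 0.7.0`), the manual downgrade would look something like the following sketch; the exact pins are an assumption based on the prose above, not the verbatim block from the README.

```bash
# assumed pins, taken from the torch 1.6 / torchvision 0.7.0 pairing in the note
pip install torch==1.6.0 torchvision==0.7.0
```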

VERSION

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-1.0.0
+1.0.5

docs/examples.rst

Lines changed: 1 addition & 0 deletions

@@ -15,3 +15,4 @@ them to address different problems
 * `Custom Components <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/07_Custom_Components.ipynb>`__
 * `Save and Load Model and Artifacts <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/08_save_and_load_model_and_artifacts.ipynb>`__
 * `Using Custom DataLoaders and Torchmetrics <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/09_Custom_DataLoader_Imbalanced_dataset.ipynb>`__
+* `The Transformer Family <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/10_The_Transformer_Family.ipynb>`__
Binary figure files:

… (63.9 KB)
docs/figures/saint_arch.png (70.6 KB)
docs/figures/tabmlp_arch.png (1.75 KB)
docs/figures/tabnet_arch_1.png (64.1 KB)
docs/figures/tabnet_arch_2.png (72.7 KB)
docs/figures/tabresnet_arch.png (2.5 KB)
… (596 Bytes)
