
Commit 86217bb

Merge pull request #49 from jrzaurin/jrzaurin/perceiver
Jrzaurin/perceiver
2 parents 0c79deb + e865fb2 commit 86217bb


57 files changed: +5103 −2692 lines

README.md

Lines changed: 46 additions & 42 deletions
@@ -9,7 +9,7 @@
 [![Maintenance](https://img.shields.io/badge/Maintained%3F-yes-green.svg)](https://github.com/jrzaurin/pytorch-widedeep/graphs/commit-activity)
 [![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/jrzaurin/pytorch-widedeep/issues)
 [![codecov](https://codecov.io/gh/jrzaurin/pytorch-widedeep/branch/master/graph/badge.svg)](https://codecov.io/gh/jrzaurin/pytorch-widedeep)
-[![Python 3.6 3.7 3.8 3.9](https://img.shields.io/badge/python-3.6%20%7C%203.7%20%7C%203.8%20%7C%203.9-blue.svg)](https://www.python.org/)
+[![Python 3.7 3.8 3.9](https://img.shields.io/badge/python-3.7%20%7C%203.8%20%7C%203.9-blue.svg)](https://www.python.org/)
 
 # pytorch-widedeep
 
@@ -24,6 +24,13 @@ using wide and deep models.
 
 **Slack**: if you want to contribute or just want to chat with us, join [slack](https://join.slack.com/t/pytorch-widedeep/shared_invite/zt-soss7stf-iXpVuLeKZz8lGTnxxtHtTw)
 
+The content of this document is organized as follows:
+
+1. [introduction](#introduction)
+2. [The deeptabular component](#the-deeptabular-component)
+3. [installation](#installation)
+4. [quick start (tl;dr)](#quick-start)
+
 ### Introduction
 
 ``pytorch-widedeep`` is based on Google's [Wide and Deep Algorithm](https://arxiv.org/abs/1606.07792)
@@ -82,61 +89,58 @@ into:
 <img width="300" src="docs/figures/architecture_2_math.png">
 </p>
 
+I recommend using the ``wide`` and ``deeptabular`` models in
+``pytorch-widedeep``. However, it is very likely that users will want to use
+their own models for the ``deeptext`` and ``deepimage`` components. That is
+perfectly possible as long as the custom models have an attribute called
+``output_dim`` with the size of the last layer of activations, so that
+``WideDeep`` can be constructed. Again, examples on how to use custom
+components can be found in the Examples folder. Just in case,
+``pytorch-widedeep`` includes standard text (stack of LSTMs) and image
+(pre-trained ResNets or stack of CNNs) models.
+
+### The ``deeptabular`` component
+
 It is important to emphasize that **each individual component, `wide`,
 `deeptabular`, `deeptext` and `deepimage`, can be used independently** and in
 isolation. For example, one could use only `wide`, which is simply a
 linear model. In fact, one of the most interesting functionalities
-in``pytorch-widedeep`` is the ``deeptabular`` component. Currently,
-``pytorch-widedeep`` offers the following different models for that
-component:
+in ``pytorch-widedeep`` would be the use of the ``deeptabular`` component on
+its own, i.e. what one might normally refer to as Deep Learning for Tabular
+Data. Currently, ``pytorch-widedeep`` offers the following different models
+for that component:
 
-1. ``TabMlp``: this is almost identical to the [tabular
-model](https://docs.fast.ai/tutorial.tabular.html) in the fantastic
-[fastai](https://docs.fast.ai/) library, and consists simply in embeddings
-representing the categorical features, concatenated with the continuous
-features, and passed then through a MLP.
 
-2. ``TabRenset``: This is similar to the previous model but the embeddings are
+1. **TabMlp**: a simple MLP that receives embeddings representing the
+categorical features, concatenated with the continuous features.
+2. **TabResnet**: similar to the previous model but the embeddings are
 passed through a series of ResNet blocks built with dense layers.
-
-3. ``Tabnet``: Details on TabNet can be found in:
+3. **TabNet**: details on TabNet can be found in
 [TabNet: Attentive Interpretable Tabular Learning](https://arxiv.org/abs/1908.07442)
 
-4. ``TabTransformer``: Details on the TabTransformer can be found in:
-[TabTransformer: Tabular Data Modeling Using Contextual Embeddings](https://arxiv.org/pdf/2012.06678.pdf).
-Note that the TabTransformer implementation available at ``pytorch-widedeep``
-is an adaptation of the original implementation.
+And the ``Tabformer`` family, i.e. Transformers for Tabular data:
 
-5. ``FT-Transformer``: or Feature Tokenizer transformer. This is a relatively small
-variation of the ``TabTransformer``. The variation itself was first
-introduced in the ``SAINT`` paper, but the name "``FT-Transformer``" was first
-used in
+4. **TabTransformer**: details on the TabTransformer can be found in
+[TabTransformer: Tabular Data Modeling Using Contextual Embeddings](https://arxiv.org/pdf/2012.06678.pdf).
+5. **SAINT**: details on SAINT can be found in
+[SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training](https://arxiv.org/abs/2106.01342).
+6. **FT-Transformer**: details on the FT-Transformer can be found in
 [Revisiting Deep Learning Models for Tabular Data](https://arxiv.org/abs/2106.11959).
-When using the ``FT-Transformer`` each continuous feature is "embedded"
-(i.e. going through a 1-layer MLP with or without activation function) and
-then passed through the attention blocks along with the categorical features.
-This is available in ``pytorch-widedeep``'s ``TabTransformer`` by setting the
-parameter ``embed_continuous = True``.
-
+7. **TabFastFormer**: adaptation of the FastFormer for tabular data. Details
+on the FastFormer can be found in
+[FastFormers: Highly Efficient Transformer Models for Natural Language Understanding](https://arxiv.org/abs/2010.13382)
+8. **TabPerceiver**: adaptation of the Perceiver for tabular data. Details on
+the Perceiver can be found in
+[Perceiver: General Perception with Iterative Attention](https://arxiv.org/abs/2103.03206)
 
-6. ``SAINT``: Details on SAINT can be found in:
-[SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training](https://arxiv.org/abs/2106.01342).
+Note that while there are scientific publications for the TabTransformer,
+SAINT and FT-Transformer, the TabFastFormer and TabPerceiver are our own
+adaptations of those algorithms for tabular data.
 
 For details on these models and their options please see the examples in the
 Examples folder and the documentation.
 
-Finally, while I recommend using the ``wide`` and ``deeptabular`` models in
-``pytorch-widedeep`` it is very likely that users will want to use their own
-models for the ``deeptext`` and ``deepimage`` components. That is perfectly
-possible as long as the the custom models have an attribute called
-``output_dim`` with the size of the last layer of activations, so that
-``WideDeep`` can be constructed. Again, examples on how to use custom
-components can be found in the Examples folder. Just in case
-``pytorch-widedeep`` includes standard text (stack of LSTMs) and image
-(pre-trained ResNets or stack of CNNs) models.
-
-
-### Installation
+### Installation
 
 Install using pip:
 
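The ``output_dim`` contract described in the added README paragraph above can be sketched with a toy custom component. This is a hypothetical illustration, not the library's own code: the class, vocabulary size and layer sizes are made up.

```python
import torch
import torch.nn as nn


class MyDeepText(nn.Module):
    """Hypothetical custom ``deeptext`` component: embedding + LSTM.

    The only requirement described above is the ``output_dim``
    attribute, holding the size of the last layer of activations.
    """

    def __init__(self, vocab_size: int = 100, embed_dim: int = 16, hidden_dim: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.output_dim = hidden_dim  # <- the attribute ``WideDeep`` needs

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        out, _ = self.rnn(self.embed(X))
        return out[:, -1, :]  # last hidden states: (batch_size, output_dim)


text_model = MyDeepText()
activations = text_model(torch.randint(0, 100, (4, 10)))  # 4 sequences of 10 tokens
print(tuple(activations.shape))  # (4, 8)
```

A model like this could then be passed as the ``deeptext`` argument when building ``WideDeep``; the Examples folder contains the library's actual custom-component notebooks.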
@@ -167,8 +171,8 @@ when running on Mac, present in previous versions, persist on this release
 and the data-loaders will not run in parallel. In addition, since `python
 3.8`, [the `multiprocessing` library start method changed from `'fork'` to `'spawn'`](https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods).
 This also affects the data-loaders (for any `torch` version) and they will
-not run in parallel. Therefore, for Mac users I recommend using `python
-3.6` or `3.7` and `torch <= 1.6` (with the corresponding, consistent
+not run in parallel. Therefore, for Mac users I recommend using `python 3.7`
+and `torch <= 1.6` (with the corresponding, consistent
 version of `torchvision`, e.g. `0.7.0` for `torch 1.6`). I do not want to
 force this versioning in the `setup.py` file since I expect that all these
 issues are fixed in the future. Therefore, after installing
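The start-method change mentioned in the hunk above can be inspected directly with the standard library. The result is platform-dependent, so no fixed output is assumed:

```python
import multiprocessing as mp
import sys

# Since python 3.8 the default start method on macOS is 'spawn' rather
# than 'fork', which is why parallel data-loaders are affected there.
method = mp.get_start_method()
print(sys.platform, method)  # e.g. 'linux fork' or 'darwin spawn'
```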

VERSION

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-1.0.5
+1.0.9

docs/examples.rst

Lines changed: 1 addition & 0 deletions
@@ -16,3 +16,4 @@ them to address different problems
 * `Save and Load Model and Artifacts <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/08_save_and_load_model_and_artifacts.ipynb>`__
 * `Using Custom DataLoaders and Torchmetrics <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/09_Custom_DataLoader_Imbalanced_dataset.ipynb>`__
 * `The Transformer Family <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/10_The_Transformer_Family.ipynb>`__
+* `Extracting Embeddings <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/11_Extracting_Embeddings.ipynb>`__

docs/index.rst

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@ Documentation
    Dataloaders <dataloaders>
    Callbacks <callbacks>
    The Trainer <trainer>
+   Tab2Vec <tab2vec>
    Examples <examples>
 

docs/losses.rst

Lines changed: 2 additions & 2 deletions
@@ -17,8 +17,8 @@ on their own and can be imported as:
     from pytorch_widedeep.losses import FocalLoss
 
 .. note:: Losses in this module expect the predictions and ground truth to have the
-   same dimensions for regression and binary classification problems (i.e.
-   :math:`N_{samples}, 1)`. In the case of multiclass classification problems
+   same dimensions for regression and binary classification problems
+   :math:`(N_{samples}, 1)`. In the case of multiclass classification problems
    the ground truth is expected to be a 1D tensor with the corresponding
    classes. See Examples below
 
docs/metrics.rst

Lines changed: 3 additions & 4 deletions
@@ -2,10 +2,9 @@ Metrics
 =======
 
 .. note:: Metrics in this module expect the predictions and ground truth to have the
-   same dimensions for regression and binary classification problems (i.e.
-   :math:`N_{samples}, 1)`. In the case of multiclass classification problems the
-   ground truth is expected to be a 1D tensor with the corresponding classes.
-   See Examples below
+   same dimensions for regression and binary classification problems: :math:`(N_{samples}, 1)`.
+   In the case of multiclass classification problems the ground truth is expected to be
+   a 1D tensor with the corresponding classes. See Examples below
 
 We have added the possibility of using the metrics available at the
 `torchmetrics <https://torchmetrics.readthedocs.io/en/latest/>`_ library.
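The shape convention stated in the note above can be illustrated with plain ``torch`` tensors. This is a toy sketch of the expected shapes, not the library's code:

```python
import torch

n_samples = 4

# Regression / binary classification: predictions and ground truth
# share the same (N_samples, 1) shape.
binary_pred = torch.rand(n_samples, 1)
binary_true = torch.randint(0, 2, (n_samples, 1)).float()

# Multiclass: predictions are (N_samples, n_classes) while the ground
# truth is a 1D tensor of class indices.
n_classes = 3
multi_pred = torch.rand(n_samples, n_classes)
multi_true = torch.randint(0, n_classes, (n_samples,))

print(binary_pred.shape, binary_true.shape, multi_true.shape)
```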

docs/model_components.rst

Lines changed: 16 additions & 3 deletions
@@ -5,9 +5,10 @@ This module contains the four main components that will comprise a Wide and
 Deep model, and the ``WideDeep`` "constructor" class. These four components
 are: ``wide``, ``deeptabular``, ``deeptext``, ``deepimage``.
 
-.. note:: ``TabMlp``, ``TabResnet``, ``TabNet``, ``TabTransformer`` and ``SAINT`` can
-   all be used as the ``deeptabular`` component of the model and simply
-   represent different alternatives
+.. note:: ``TabMlp``, ``TabResnet``, ``TabNet``, ``TabTransformer``, ``SAINT``,
+   ``FTTransformer``, ``TabPerceiver`` and ``TabFastFormer`` can all be used
+   as the ``deeptabular`` component of the model and simply represent different
+   alternatives
 
 .. autoclass:: pytorch_widedeep.models.wide.Wide
    :exclude-members: forward
@@ -33,6 +34,18 @@ are: ``wide``, ``deeptabular``, ``deeptext``, ``deepimage``.
    :exclude-members: forward
    :members:
 
+.. autoclass:: pytorch_widedeep.models.transformers.ft_transformer.FTTransformer
+   :exclude-members: forward
+   :members:
+
+.. autoclass:: pytorch_widedeep.models.transformers.tab_perceiver.TabPerceiver
+   :exclude-members: forward
+   :members:
+
+.. autoclass:: pytorch_widedeep.models.transformers.tab_fastformer.TabFastFormer
+   :exclude-members: forward
+   :members:
+
 .. autoclass:: pytorch_widedeep.models.deep_text.DeepText
    :exclude-members: forward
    :members:

docs/tab2vec.rst

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+Tab2Vec
+=======
+
+.. autoclass:: pytorch_widedeep.tab2vec.Tab2Vec
+   :members:
+   :undoc-members:
+
