
Commit a71699d: Merge pull request #39 from jrzaurin/tabnet (Tabnet)
2 parents b487b06 + bc873a0

File tree: 71 files changed, 6338 additions, 2864 deletions


.gitignore

Lines changed: 1 addition & 0 deletions

@@ -13,6 +13,7 @@ Untitled*.ipynb
 # data related dirs
 data/
 model_weights/
+tmp_dir/
 weights/

 # Unit Tests/Coverage

.travis.yml

Lines changed: 2 additions & 2 deletions

@@ -1,9 +1,9 @@
 dist: xenial
 language: python
 python:
-  - "3.6"
-  - "3.7"
+  - "3.7.9"
   - "3.8"
+  - "3.9"
 matrix:
   fast_finish: true
   include:

README.md

Lines changed: 45 additions & 23 deletions

@@ -9,7 +9,7 @@
 [![Maintenance](https://img.shields.io/badge/Maintained%3F-yes-green.svg)](https://github.com/jrzaurin/pytorch-widedeep/graphs/commit-activity)
 [![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/jrzaurin/pytorch-widedeep/issues)
 [![codecov](https://codecov.io/gh/jrzaurin/pytorch-widedeep/branch/master/graph/badge.svg)](https://codecov.io/gh/jrzaurin/pytorch-widedeep)
-[![Python 3.6 3.7 3.8](https://img.shields.io/badge/python-3.6%20%7C%203.7%20%7C%203.8-blue.svg)](https://www.python.org/)
+[![Python 3.6 3.7 3.8 3.9](https://img.shields.io/badge/python-3.6%20%7C%203.7%20%7C%203.8%20%7C%203.9-blue.svg)](https://www.python.org/)

 # pytorch-widedeep
@@ -18,12 +18,13 @@ using wide and deep models.

 **Documentation:** [https://pytorch-widedeep.readthedocs.io](https://pytorch-widedeep.readthedocs.io/en/latest/index.html)

-**Companion posts:** [infinitoml](https://jrzaurin.github.io/infinitoml/)
+**Companion posts and tutorials:** [infinitoml](https://jrzaurin.github.io/infinitoml/)
+
+**Experiments and comparison with `LightGBM`**: [TabularDL vs LightGBM](https://github.com/jrzaurin/tabulardl-benchmark)

 ### Introduction

-`pytorch-widedeep` is based on Google's Wide and Deep Algorithm, [Wide & Deep
-Learning for Recommender Systems](https://arxiv.org/abs/1606.07792).
+``pytorch-widedeep`` is based on Google's [Wide and Deep Algorithm](https://arxiv.org/abs/1606.07792)

 In general terms, `pytorch-widedeep` is a package to use deep learning with
 tabular data. In particular, it is intended to facilitate the combination of text
@@ -84,7 +85,7 @@ It is important to emphasize that **each individual component, `wide`,
 isolation. For example, one could use only `wide`, which is simply a linear
 model. In fact, one of the most interesting functionalities
 in ``pytorch-widedeep`` is the ``deeptabular`` component. Currently,
-``pytorch-widedeep`` offers 3 models for that component:
+``pytorch-widedeep`` offers 4 models for that component:

 1. ``TabMlp``: this is almost identical to the [tabular
 model](https://docs.fast.ai/tutorial.tabular.html) in the fantastic
@@ -95,12 +96,15 @@ features, and then passed through an MLP.
 2. ``TabResnet``: This is similar to the previous model but the embeddings are
 passed through a series of ResNet blocks built with dense layers.

-3. ``TabTransformer``: Details on the TabTransformer can be found in:
+3. ``TabNet``: Details on TabNet can be found in:
+[TabNet: Attentive Interpretable Tabular Learning](https://arxiv.org/abs/1908.07442)
+
+4. ``TabTransformer``: Details on the TabTransformer can be found in:
 [TabTransformer: Tabular Data Modeling Using Contextual
 Embeddings](https://arxiv.org/pdf/2012.06678.pdf)


-For details on these 3 models and their options please see the examples in the
+For details on these 4 models and their options please see the examples in the
 Examples folder and the documentation.

 Finally, while I recommend using the ``wide`` and ``deeptabular`` models in
@@ -139,20 +143,20 @@ cd pytorch-widedeep
 pip install -e .
 ```

-**Important note for Mac users**: at the time of writing (Feb-2020) the latest
-`torch` release is `1.7.1`. This release has some
+**Important note for Mac users**: at the time of writing (June-2021) the
+latest `torch` release is `1.9`. Some past
 [issues](https://stackoverflow.com/questions/64772335/pytorch-w-parallelnative-cpp206)
-when running on Mac and the data-loaders will not run in parallel. In
-addition, since `python 3.8`, [the `multiprocessing` library start method
-changed from `'fork'` to
+when running on Mac, present in previous versions, persist on this release and
+the data-loaders will not run in parallel. In addition, since `python 3.8`,
+[the `multiprocessing` library start method changed from `'fork'` to
 `'spawn'`](https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods).
-This also affects the data-loaders (for any `torch` version) and they will not
-run in parallel. Therefore, for Mac users I recommend using `python 3.6` or
-`3.7` and `torch <= 1.6` (with the corresponding, consistent version of
+This also affects the data-loaders (for any `torch` version) and they will
+not run in parallel. Therefore, for Mac users I recommend using `python 3.6`
+or `3.7` and `torch <= 1.6` (with the corresponding, consistent version of
 `torchvision`, e.g. `0.7.0` for `torch 1.6`). I do not want to force this
 versioning in the `setup.py` file since I expect that all these issues are
-fixed in the future. Therefore, after installing `pytorch-widedeep` via pip or
-directly from github, downgrade `torch` and `torchvision` manually:
+fixed in the future. Therefore, after installing `pytorch-widedeep` via pip
+or directly from github, downgrade `torch` and `torchvision` manually:

 ```bash
 pip install pytorch-widedeep
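The `'fork'` vs `'spawn'` behavior that the Mac note above refers to can be inspected directly with the standard library. A minimal, `torch`-free sketch:

```python
import multiprocessing as mp

# The default start method depends on the platform: "fork" on Linux,
# "spawn" on Windows and on macOS since Python 3.8.
method = mp.get_start_method()
print(method)

# To experiment with the pre-3.8 macOS behavior one could force "fork"
# (at your own risk, given the issues described in the note above):
# mp.set_start_method("fork", force=True)
```

Data-loader workers are spawned through this same machinery, which is why the start method affects whether they can run in parallel on macOS.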
@@ -167,16 +171,13 @@ Binary classification with the [adult
 dataset](https://www.kaggle.com/wenruliu/adult-income-dataset)
 using `Wide` and `DeepDense` and default settings.

-
-```python
-```
-
 Building a wide (linear) and deep model with ``pytorch-widedeep``:

 ```python

 import pandas as pd
 import numpy as np
+import torch
 from sklearn.model_selection import train_test_split

 from pytorch_widedeep import Trainer
@@ -248,8 +249,29 @@ X_wide_te = wide_preprocessor.transform(df_test)
 X_tab_te = tab_preprocessor.transform(df_test)
 preds = trainer.predict(X_wide=X_wide_te, X_tab=X_tab_te)

-# save and load
-trainer.save_model("model_weights/model.t")
+# Save and load
+
+# Option 1: this will also save training history and lr history if the
+# LRHistory callback is used
+trainer.save(path="model_weights", save_state_dict=True)
+
+# Option 2: save as any other torch model
+torch.save(model.state_dict(), "model_weights/wd_model.pt")
+
+# From here on, Option 1 and 2 are the same. I assume the user has
+# prepared the data and defined the new model components:
+# 1. Build the model
+model_new = WideDeep(wide=wide, deeptabular=deeptabular)
+model_new.load_state_dict(torch.load("model_weights/wd_model.pt"))
+
+# 2. Instantiate the trainer
+trainer_new = Trainer(
+    model_new,
+    objective="binary",
+)
+
+# 3. Either start the fit or directly predict
+preds = trainer_new.predict(X_wide=X_wide, X_tab=X_tab)
 ```

 Of course, one can do **much more**. See the Examples folder, the
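The save-and-reload flow added above follows the usual `state_dict` round trip: persist only the parameters, rebuild the model from its components, then load the parameters back. A framework-free sketch of that same pattern, using a plain dict and `pickle` as illustrative stand-ins for a real `state_dict` and `torch.save`/`torch.load` (the `TinyModel` class is hypothetical):

```python
import os
import pickle
import tempfile

class TinyModel:
    """Hypothetical stand-in for WideDeep: just holds named parameters."""
    def __init__(self):
        self.params = {"wide.weight": [0.0, 0.0], "deeptabular.weight": [0.0]}

    def state_dict(self):
        # Return a copy of the parameter state, like nn.Module.state_dict
        return dict(self.params)

    def load_state_dict(self, state):
        self.params.update(state)

model = TinyModel()
model.params["wide.weight"] = [0.1, 0.2]  # pretend training happened

# Save: persist only the parameter state
save_dir = tempfile.mkdtemp()
path = os.path.join(save_dir, "wd_model.pkl")
with open(path, "wb") as f:
    pickle.dump(model.state_dict(), f)

# Load: rebuild the model, then restore the parameters into it
model_new = TinyModel()
with open(path, "rb") as f:
    model_new.load_state_dict(pickle.load(f))

assert model_new.state_dict() == model.state_dict()
```

The point of the pattern is that only parameters are serialized; the architecture is reconstructed in code, which is why the README rebuilds `WideDeep(wide=wide, deeptabular=deeptabular)` before calling `load_state_dict`.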

VERSION

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-0.4.8
+1.0.0

docs/callbacks.rst

Lines changed: 3 additions & 0 deletions

@@ -10,6 +10,9 @@ Here are the 4 callbacks available in ``pytorch-widedeep``: ``History``,
 .. autoclass:: pytorch_widedeep.callbacks.History
    :members:

+.. autoclass:: pytorch_widedeep.callbacks.LRShedulerCallback
+   :members:
+
 .. autoclass:: pytorch_widedeep.callbacks.LRHistory
    :members:

docs/conf.py

Lines changed: 1 addition & 1 deletion

@@ -111,7 +111,7 @@
 # Remove the prompt when copying examples
 copybutton_prompt_text = ">>> "

-autoclass_content = "init"  # 'both'
+# autoclass_content = "init"  # 'both'
 autodoc_member_order = "bysource"
 # autodoc_default_flags = ["show-inheritance"]

docs/examples.rst

Lines changed: 1 addition & 0 deletions

@@ -13,3 +13,4 @@ them to address different problems
 * `Regression with Images and Text <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/05_Regression_with_Images_and_Text.ipynb>`__
 * `FineTune routines <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/06_FineTune_and_WarmUp_Model_Components.ipynb>`__
 * `Custom Components <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/07_Custom_Components.ipynb>`__
+* `Save and Load Model and Artifacts <https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples/08_save_and_load_model_and_artifacts.ipynb>`__

docs/index.rst

Lines changed: 7 additions & 5 deletions

@@ -90,7 +90,7 @@ deeptabular, deeptext and deepimage, can be used independently** and in
 isolation. For example, one could use only ``wide``, which is simply a
 linear model. In fact, one of the most interesting offerings of
 ``pytorch-widedeep`` is the ``deeptabular`` component. Currently,
-``pytorch-widedeep`` offers 3 models for that component:
+``pytorch-widedeep`` offers 4 models for that component:

 1. ``TabMlp``: this is almost identical to the `tabular
 model <https://docs.fast.ai/tutorial.tabular.html>`_ in the fantastic

@@ -101,12 +101,14 @@ features, and then passed through an MLP.
 2. ``TabResnet``: This is similar to the previous model but the embeddings are
 passed through a series of ResNet blocks built with dense layers.

-3. ``TabTransformer``: Details on the TabTransformer can be found in:
-`TabTransformer: Tabular Data Modeling Using Contextual
-Embeddings <https://arxiv.org/pdf/2012.06678.pdf>`_.
+3. ``TabNet``: Details on TabNet can be found in: `TabNet: Attentive
+Interpretable Tabular Learning <https://arxiv.org/abs/1908.07442>`_.

+4. ``TabTransformer``: Details on the TabTransformer can be found in:
+`TabTransformer: Tabular Data Modeling Using Contextual Embeddings
+<https://arxiv.org/pdf/2012.06678.pdf>`_.

-For details on these 3 models and their options please see the examples in the
+For details on these 4 models and their options please see the examples in the
 Examples folder and the documentation.

 Finally, while I recommend using the ``wide`` and ``deeptabular`` models in

docs/losses.rst

Lines changed: 4 additions & 3 deletions

@@ -2,9 +2,10 @@ Losses
 ======

 ``pytorch-widedeep`` accepts a number of losses and objectives that can be
-passed to the ``Trainer`` class via the ``str`` parameter ``objective`` (see
-``pytorch-widedeep.training.Trainer``). For most cases the loss function that
-``pytorch-widedeep`` will use internally is already implemented in Pytorch.
+passed to the ``Trainer`` class via the parameter ``objective``
+(see ``pytorch-widedeep.training.Trainer``). For most cases the loss function
+that ``pytorch-widedeep`` will use internally is already implemented in
+Pytorch.

 In addition, ``pytorch-widedeep`` implements four "custom" loss functions.
 These are described below for completion since, as I mentioned before, they

docs/model_components.rst

Lines changed: 7 additions & 3 deletions

@@ -5,9 +5,9 @@ This module contains the four main components that will comprise a Wide and
 Deep model, and the ``WideDeep`` "constructor" class. These four components
 are: ``wide``, ``deeptabular``, ``deeptext``, ``deepimage``.

-.. note:: ``TabMlp``, ``TabResnet`` and ``TabTransformer`` can all be used
-   as the ``deeptabular`` component of the model and simply represent
-   different alternatives
+.. note:: ``TabMlp``, ``TabResnet``, ``TabNet`` and ``TabTransformer`` can all
+   be used as the ``deeptabular`` component of the model and simply
+   represent different alternatives

 .. autoclass:: pytorch_widedeep.models.wide.Wide
    :exclude-members: forward

@@ -21,6 +21,10 @@ are: ``wide``, ``deeptabular``, ``deeptext``, ``deepimage``.
    :exclude-members: forward
    :members:

+.. autoclass:: pytorch_widedeep.models.tabnet.tab_net.TabNet
+   :exclude-members: forward
+   :members:
+
 .. autoclass:: pytorch_widedeep.models.tab_transformer.TabTransformer
    :exclude-members: forward
    :members:
