[![Maintenance](https://img.shields.io/badge/Maintained%3F-yes-green.svg)](https://github.com/jrzaurin/pytorch-widedeep/graphs/commit-activity)
[![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/jrzaurin/pytorch-widedeep/issues)
[![codecov](https://codecov.io/gh/jrzaurin/pytorch-widedeep/branch/master/graph/badge.svg)](https://codecov.io/gh/jrzaurin/pytorch-widedeep)
[![Python 3.6 3.7 3.8 3.9](https://img.shields.io/badge/python-3.6%20%7C%203.7%20%7C%203.8%20%7C%203.9-blue.svg)](https://www.python.org/)

# pytorch-widedeep

A flexible package to combine tabular data with text and images using wide
and deep models.

**Documentation:** [https://pytorch-widedeep.readthedocs.io](https://pytorch-widedeep.readthedocs.io/en/latest/index.html)

**Companion posts and tutorials:** [infinitoml](https://jrzaurin.github.io/infinitoml/)

**Experiments and comparison with `LightGBM`**: [TabularDL vs LightGBM](https://github.com/jrzaurin/tabulardl-benchmark)

### Introduction

`pytorch-widedeep` is based on Google's Wide and Deep Algorithm, described in
[Wide & Deep Learning for Recommender
Systems](https://arxiv.org/abs/1606.07792).

In general terms, `pytorch-widedeep` is a package to use deep learning with
tabular data. In particular, it is intended to facilitate the combination of
text and images with corresponding tabular data using wide and deep models.
It is important to emphasize that **each individual component, `wide`,
`deeptabular`, `deeptext` and `deepimage`, can be used independently** and in
isolation. For example, one could use only `wide`, which is simply a linear
model. In fact, one of the most interesting functionalities in
`pytorch-widedeep` is the `deeptabular` component. Currently,
`pytorch-widedeep` offers 4 models for that component:

1. `TabMlp`: this is almost identical to the [tabular
model](https://docs.fast.ai/tutorial.tabular.html) in the fantastic
[fastai](https://docs.fast.ai/) library: the embeddings for the categorical
features are concatenated with the continuous features, and then passed
through an MLP.

2. `TabResnet`: this is similar to the previous model, but the embeddings are
passed through a series of ResNet blocks built with dense layers.

3. `TabNet`: details on TabNet can be found in [TabNet: Attentive
Interpretable Tabular Learning](https://arxiv.org/abs/1908.07442).

4. `TabTransformer`: details on the TabTransformer can be found in
[TabTransformer: Tabular Data Modeling Using Contextual
Embeddings](https://arxiv.org/pdf/2012.06678.pdf).

For details on these 4 models and their options please see the examples in
the Examples folder and the documentation. A minimal sketch of how a
`deeptabular` component is used follows.

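For orientation, here is a minimal sketch of how one of these models plugs
into `WideDeep`. This is a sketch, not part of the original quick start: it
assumes a `TabPreprocessor` named `tab_preprocessor` has already been fit on
the training dataframe, the column names are illustrative, and the class and
parameter names follow the API at the time of writing (check the
documentation of your installed version):

```python
from pytorch_widedeep.models import TabMlp, WideDeep

# deeptabular component: categorical embeddings plus continuous features
# passed through an MLP. `tab_preprocessor` is assumed to be a fitted
# TabPreprocessor (see the quick start example below)
deeptabular = TabMlp(
    mlp_hidden_dims=[64, 32],
    column_idx=tab_preprocessor.column_idx,
    embed_input=tab_preprocessor.embeddings_input,
    continuous_cols=["age", "hours-per-week"],  # illustrative column names
)

# each component can be used in isolation: here, a WideDeep model built from
# the deeptabular component alone, without the wide component
model = WideDeep(deeptabular=deeptabular)
```
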
Finally, while I recommend using the `wide` and `deeptabular` models within
`pytorch-widedeep`, the `deeptext` and `deepimage` components can also be
built with custom, user-defined models.

### Installation

To install the package from source:

```bash
git clone https://github.com/jrzaurin/pytorch-widedeep
cd pytorch-widedeep
pip install -e .
```

**Important note for Mac users**: at the time of writing (June-2021) the
latest `torch` release is `1.9`. Some known
[issues](https://stackoverflow.com/questions/64772335/pytorch-w-parallelnative-cpp206)
when running on Mac, already present in previous versions, persist in this
release, and the data-loaders will not run in parallel. In addition, since
`python 3.8`, [the `multiprocessing` library start method changed from
`'fork'` to
`'spawn'`](https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods).
This also affects the data-loaders (for any `torch` version), which will not
run in parallel either. Therefore, for Mac users I recommend using `python
3.6` or `3.7` and `torch <= 1.6` (with the corresponding, consistent version
of `torchvision`, e.g. `0.7.0` for `torch 1.6`). I do not want to force this
versioning in the `setup.py` file since I expect these issues to be fixed in
the future. So, after installing `pytorch-widedeep` via pip or directly from
GitHub, downgrade `torch` and `torchvision` manually:

```bash
pip install pytorch-widedeep
# downgrade to a torch/torchvision pair consistent with the note above
pip install torch==1.6.0 torchvision==0.7.0
```
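
A quick way to confirm that the downgrade took effect:

```python
import torch
import torchvision

# expect e.g. 1.6.0 and 0.7.0 after the downgrade described above
print(torch.__version__, torchvision.__version__)
```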

### Quick start

Binary classification with the [adult
dataset](https://www.kaggle.com/wenruliu/adult-income-dataset) using `Wide`
and `TabMlp` and default settings.

Building a wide (linear) and deep model with `pytorch-widedeep`:

```python
import pandas as pd
import numpy as np
import torch
from sklearn.model_selection import train_test_split

from pytorch_widedeep import Trainer
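from pytorch_widedeep.preprocessing import WidePreprocessor, TabPreprocessor
from pytorch_widedeep.models import Wide, TabMlp, WideDeep
from pytorch_widedeep.metrics import Accuracy

# What follows, up to the predict step, is a minimal sketch of the usual
# workflow, assuming the adult dataset is saved as data/adult/adult.csv.zip.
# Column and embedding choices are illustrative, and parameter names follow
# the API at the time of writing (check your installed version).
df = pd.read_csv("data/adult/adult.csv.zip")
df["income_label"] = (df["income"].apply(lambda x: ">50K" in x)).astype(int)
df.drop("income", axis=1, inplace=True)
df_train, df_test = train_test_split(df, test_size=0.2, stratify=df.income_label)

# columns used by each component
wide_cols = ["education", "relationship", "workclass", "occupation", "native-country", "gender"]
cross_cols = [("education", "occupation"), ("native-country", "occupation")]
embed_cols = [("education", 16), ("workclass", 16), ("occupation", 16), ("native-country", 32)]
cont_cols = ["age", "hours-per-week"]
target = df_train["income_label"].values

# wide component: a linear model over one-hot encoded and crossed columns
wide_preprocessor = WidePreprocessor(wide_cols=wide_cols, crossed_cols=cross_cols)
X_wide = wide_preprocessor.fit_transform(df_train)
wide = Wide(wide_dim=np.unique(X_wide).shape[0], pred_dim=1)

# deeptabular component: categorical embeddings plus continuous features,
# passed through an MLP
tab_preprocessor = TabPreprocessor(embed_cols=embed_cols, continuous_cols=cont_cols)
X_tab = tab_preprocessor.fit_transform(df_train)
deeptabular = TabMlp(
    mlp_hidden_dims=[64, 32],
    column_idx=tab_preprocessor.column_idx,
    embed_input=tab_preprocessor.embeddings_input,
    continuous_cols=cont_cols,
)

# combine the components and train
model = WideDeep(wide=wide, deeptabular=deeptabular)
trainer = Trainer(model, objective="binary", metrics=[Accuracy])
trainer.fit(X_wide=X_wide, X_tab=X_tab, target=target, n_epochs=5, batch_size=256)

# predict on the test set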
X_wide_te = wide_preprocessor.transform(df_test)
X_tab_te = tab_preprocessor.transform(df_test)
preds = trainer.predict(X_wide=X_wide_te, X_tab=X_tab_te)

# Save and load

# Option 1: this will also save the training history and the lr history if
# the LRHistory callback is used
trainer.save(path="model_weights", save_state_dict=True)

# Option 2: save as any other torch model
torch.save(model.state_dict(), "model_weights/wd_model.pt")

# From here onwards, Options 1 and 2 are the same. I assume the user has
# prepared the data and defined the new model components:
# 1. Build the model
model_new = WideDeep(wide=wide, deeptabular=deeptabular)
model_new.load_state_dict(torch.load("model_weights/wd_model.pt"))

# 2. Instantiate the trainer
trainer_new = Trainer(model_new, objective="binary")

# 3. Either start the fit or directly predict
preds = trainer_new.predict(X_wide=X_wide, X_tab=X_tab)
```

Of course, one can do **much more**. See the Examples folder, the
documentation or the companion posts for a better understanding of the
content of the package and its functionalities.