# pytorch-widedeep

A flexible package for multimodal deep learning, combining tabular data with
text and images using Wide and Deep models in PyTorch.

**Documentation:** [https://pytorch-widedeep.readthedocs.io](https://pytorch-widedeep.readthedocs.io/en/latest/index.html)

**Companion posts and tutorials:** [infinitoml](https://jrzaurin.github.io/infinitoml/)

**Experiments and comparison with `LightGBM`**: [TabularDL vs LightGBM](https://github.com/jrzaurin/tabulardl-benchmark)

The content of this document is organized as follows:

### Introduction

``pytorch-widedeep`` is based on Google's [Wide and Deep Algorithm](https://arxiv.org/abs/1606.07792),
adjusted for multi-modal datasets.

In general terms, `pytorch-widedeep` is a package to use deep learning with
tabular data. In particular, it is intended to facilitate the combination of
text and images with corresponding tabular data using wide and deep models.

<p align="center">
  <img width="300" src="docs/figures/architecture_2_math.png">
</p>

It is perfectly possible to use custom models (and not necessarily those in
the library) as long as the custom models have an attribute called
``output_dim`` with the size of the last layer of activations, so that
``WideDeep`` can be constructed. Examples of how to use custom components can
be found in the Examples folder.
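
For illustration, here is a minimal sketch of a custom ``deeptext`` component
(the class name, layer sizes and the tokenized-input format are invented for
this example): any standard ``nn.Module`` that returns the last layer of
activations and exposes an ``output_dim`` attribute qualifies.

```python
import torch
from torch import nn


class MyDeepText(nn.Module):
    """A toy text encoder usable as the `deeptext` component."""

    def __init__(self, vocab_size: int, embed_dim: int = 32, hidden_dim: int = 64):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.rnn = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        # WideDeep reads this attribute to size the prediction head
        self.output_dim = hidden_dim

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # X is a (batch_size, seq_len) tensor of token indices
        _, (h, _) = self.rnn(self.embed(X))
        return h[-1]
```

Such a module would then be passed when building the model, e.g.
``WideDeep(wide=wide, deeptabular=deeptabular, deeptext=MyDeepText(vocab_size=2000))``,
with the vocabulary size again being illustrative.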

### The ``deeptabular`` component

The ``deeptabular`` component can be used on its own, i.e. what one might
normally refer to as Deep Learning for Tabular Data. Currently,
``pytorch-widedeep`` offers the following models for that component:

0. **Wide**: a simple linear model where the nonlinearities are captured via
cross-product transformations, as explained before.
1. **TabMlp**: a simple MLP that receives embeddings representing the
categorical features, concatenated with the continuous features, which can
also be embedded.
2. **TabResnet**: similar to the previous model, but the embeddings are
passed through a series of ResNet blocks built with dense layers.
3. **TabNet**: details on TabNet can be found in
[TabNet: Attentive Interpretable Tabular Learning](https://arxiv.org/abs/1908.07442)

The ``Tabformer`` family, i.e. Transformers for Tabular data:

4. **TabTransformer**: details on the TabTransformer can be found in
[TabTransformer: Tabular Data Modeling Using Contextual Embeddings](https://arxiv.org/pdf/2012.06678.pdf).
5. **SAINT**: details on SAINT can be found in
[SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training](https://arxiv.org/abs/2106.01342)
6. **FT-Transformer**: details on the FT-Transformer can be found in
[Revisiting Deep Learning Models for Tabular Data](https://arxiv.org/abs/2106.11959)
7. **TabFastFormer**: details on the FastFormer can be found in
[Fastformer: Additive Attention Can Be All You Need](https://arxiv.org/abs/2108.09084)
8. **TabPerceiver**: details on
the Perceiver can be found in
[Perceiver: General Perception with Iterative Attention](https://arxiv.org/abs/2103.03206)

And probabilistic DL models for tabular data based on
[Weight Uncertainty in Neural Networks](https://arxiv.org/abs/1505.05424):

9. **BayesianWide**: probabilistic adaptation of the `Wide` model.
10. **BayesianTabMlp**: probabilistic adaptation of the `TabMlp` model.

Note that while there are scientific publications for the TabTransformer,
SAINT and FT-Transformer, the TabFastFormer and TabPerceiver are our own
adaptations of those algorithms for tabular data.

For details on these models (and all the other models in the library for the
different data modes) and their corresponding options, please see the examples
in the Examples folder and the documentation.
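
Since all of these ``deeptabular`` models consume the same preprocessed
input, swapping one for another is a small change. A minimal sketch, reusing
the ``tab_preprocessor`` and ``continuous_cols`` defined in the Quick start
below and assuming ``TabResnet`` with its default hyperparameters (the
constructor arguments mirror the ``TabMlp`` call shown there):

```python
from pytorch_widedeep.models import TabResnet, WideDeep

# same constructor pattern as TabMlp in the Quick start below
tab_resnet = TabResnet(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=continuous_cols,
)
# the deeptabular component can also be used without the wide component
model = WideDeep(deeptabular=tab_resnet)
```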

### Installation

Install using pip (`pip install pytorch-widedeep`), or directly from source:

```bash
git clone https://github.com/jrzaurin/pytorch-widedeep
cd pytorch-widedeep
pip install -e .
```

### Quick start

Binary classification with the adult census dataset using `Wide` and `TabMlp`
with default settings.

Building a wide (linear) and deep model with ``pytorch-widedeep``:

```python
import pandas as pd
import numpy as np
import torch
from sklearn.model_selection import train_test_split

from pytorch_widedeep import Trainer
from pytorch_widedeep.preprocessing import WidePreprocessor, TabPreprocessor
from pytorch_widedeep.models import Wide, TabMlp, WideDeep
from pytorch_widedeep.metrics import Accuracy
from pytorch_widedeep.datasets import load_adult


df = load_adult(as_frame=True)
df["income_label"] = (df["income"].apply(lambda x: ">50K" in x)).astype(int)
df.drop("income", axis=1, inplace=True)
df_train, df_test = train_test_split(df, test_size=0.2, stratify=df.income_label)

# define the 'column set up'
wide_cols = [
    "education",
    "relationship",
    "workclass",
    "occupation",
    "native-country",
    "gender",
]
crossed_cols = [("education", "occupation"), ("native-country", "occupation")]

cat_embed_cols = [
    "workclass",
    "education",
    "marital-status",
    "occupation",
    "relationship",
    "race",
    "gender",
    "capital-gain",
    "capital-loss",
    "native-country",
]
continuous_cols = ["age", "hours-per-week"]
target_col = "income_label"
target = df_train[target_col].values

# prepare the data
wide_preprocessor = WidePreprocessor(wide_cols=wide_cols, crossed_cols=crossed_cols)
X_wide = wide_preprocessor.fit_transform(df_train)

tab_preprocessor = TabPreprocessor(
    cat_embed_cols=cat_embed_cols, continuous_cols=continuous_cols
)
X_tab = tab_preprocessor.fit_transform(df_train)

# build the model
wide = Wide(input_dim=np.unique(X_wide).shape[0], pred_dim=1)
tab_mlp = TabMlp(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=continuous_cols,
)
model = WideDeep(wide=wide, deeptabular=tab_mlp)

# train and validate
trainer = Trainer(model, objective="binary", metrics=[Accuracy])
trainer.fit(
    X_wide=X_wide,
    X_tab=X_tab,
    target=target,
    n_epochs=5,
    batch_size=256,
)

# predict on test
X_wide_te = wide_preprocessor.transform(df_test)
X_tab_te = tab_preprocessor.transform(df_test)
preds = trainer.predict(X_wide=X_wide_te, X_tab=X_tab_te)

# Save and load

# Option 1: this will also save the training history if the corresponding
# callbacks are used
trainer.save(path="model_weights", save_state_dict=True)

# Option 2: save as any other torch model
torch.save(model.state_dict(), "model_weights/wd_model.pt")

# From here onwards, Options 1 and 2 are the same. I assume the user has
# prepared the data and defined the new model components:
# 1. Build the model
model_new = WideDeep(wide=wide, deeptabular=tab_mlp)
model_new.load_state_dict(torch.load("model_weights/wd_model.pt"))

# 2. Instantiate the trainer
trainer_new = Trainer(model_new, objective="binary")

# 3. Either start the fit or directly predict
preds = trainer_new.predict(X_wide=X_wide, X_tab=X_tab)
```
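
To sanity-check the hold-out predictions, any standard tooling works. A
minimal sketch, with ``df_test`` and ``preds`` as defined above
(``accuracy_score`` comes from scikit-learn, not from the library):

```python
from sklearn.metrics import accuracy_score

# `preds` holds the hard class predictions returned by trainer.predict
y_test = df_test["income_label"].values
print(f"test accuracy: {accuracy_score(y_test, preds):.4f}")
```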