1313
1414# pytorch-widedeep
1515
16- A flexible package to combine tabular data with text and images using wide and
17- deep models.
16+ A flexible package to use Deep Learning with tabular data, text and images
17+ using wide and deep models.
1818
1919** Documentation:** [ https://pytorch-widedeep.readthedocs.io ] ( https://pytorch-widedeep.readthedocs.io/en/latest/index.html )
2020
2121** Companion posts:** [ infinitoml] ( https://jrzaurin.github.io/infinitoml/ )
2222
2323### Introduction
2424
25- ` pytorch-widedeep ` is based on Google's Wide and Deep Algorithm. Details of
26- the original algorithm can be found
27- [ here] ( https://www.tensorflow.org/tutorials/wide_and_deep ) , and the nice
28- research paper can be found [ here] ( https://arxiv.org/abs/1606.07792 ) .
25+ ` pytorch-widedeep ` is based on Google's Wide and Deep Algorithm, [ Wide & Deep
26+ Learning for Recommender Systems] ( https://arxiv.org/abs/1606.07792 ) .
2927
3028In general terms, ` pytorch-widedeep ` is a package to use deep learning with
3129tabular data. In particular, is intended to facilitate the combination of text
3230and images with corresponding tabular data using wide and deep models. With
33- that in mind there are two architectures that can be implemented with just a
34- few lines of code.
31+ that in mind there are a number of architectures that can be implemented with
32+ just a few lines of code. The main components of those architectures are shown
33+ in the Figure below:
3534
36- ### Architectures
37-
38- ** Architecture 1** :
3935
4036<p align =" center " >
41- <img width =" 750 " src =" docs/figures/architecture_1 .png " >
37+ <img width =" 750 " src =" docs/figures/widedeep_arch .png " >
4238</p >
4339
44- Architecture 1 combines the ` Wide ` , Linear model with the outputs from the
45- ` DeepDense ` or ` DeepDenseResnet ` , ` DeepText ` and ` DeepImage ` components
46- connected to a final output neuron or neurons, depending on whether we are
47- performing a binary classification or regression, or a multi-class
48- classification. The components within the faded-pink rectangles are
49- concatenated.
40+ The dashed boxes in the figure represent optional, overall components, and the
41+ dashed lines/arrows indicate the corresponding connections, depending on
42+ whether or not certain components are present. For example, the dashed,
43+ blue-lines indicate that the `` deeptabular `` , `` deeptext `` and `` deepimage ``
44+ components are connected directly to the output neuron or neurons (depending
45+ on whether we are performing a binary classification or regression, or a
46+ multi-class classification) if the optional `` deephead `` is not present.
47+ Finally, the components within the faded-pink rectangle are concatenated.
48+
49+ Note that it is not possible to illustrate the number of possible
50+ architectures and components available in `` pytorch-widedeep `` in one Figure.
51+ Therefore, for more details on possible architectures (and more) please, see
52+ the
53+ [ documentation] ( (https://pytorch-widedeep.readthedocs.io/en/latest/index.html) ) ,
54+ or the Examples folders and the notebooks there.
5055
5156In math terms, and following the notation in the
52- [ paper] ( https://arxiv.org/abs/1606.07792 ) , Architecture 1 can be formulated
53- as:
57+ [ paper] ( https://arxiv.org/abs/1606.07792 ) , the expression for the architecture
58+ without a `` deephead `` component can be formulated as:
5459
5560<p align =" center " >
5661 <img width =" 500 " src =" docs/figures/architecture_1_math.png " >
@@ -67,43 +72,47 @@ the constituent features (“gender=female” and “language=en”) are all 1,
6772otherwise".*
6873
6974
70- ** Architecture 2**
71-
72- <p align =" center " >
73- <img width =" 750 " src =" docs/figures/architecture_2.png " >
74- </p >
75-
76- Architecture 2 combines the ` Wide ` , Linear model with the Deep components of
77- the model connected to the output neuron(s), after the different Deep
78- components have been themselves combined through a FC-Head (that I refer as
79- ` deephead ` ).
80-
81- In math terms, and following the notation in the
82- [ paper] ( https://arxiv.org/abs/1606.07792 ) , Architecture 2 can be formulated
83- as:
75+ While if there is a `` deephead `` component, the previous expression turns
76+ into:
8477
8578<p align =" center " >
8679 <img width =" 300 " src =" docs/figures/architecture_2_math.png " >
8780</p >
8881
89- Note that each individual component, ` wide ` , ` deepdense ` (either ` DeepDense `
90- or ` DeepDenseResnet ` ), ` deeptext ` and ` deepimage ` , can be used independently
91- and in isolation. For example, one could use only ` wide ` , which is in simply a
92- linear model.
93-
94- On the other hand, while I recommend using the ` Wide ` and ` DeepDense ` (or
95- ` DeepDenseResnet ` ) classes in ` pytorch-widedeep ` to build the ` wide ` and
96- ` deepdense ` component, it is very likely that users will want to use their own
97- models in the case of the ` deeptext ` and ` deepimage ` components. That is
98- perfectly possible as long as the the custom models have an attribute called
99- ` output_dim ` with the size of the last layer of activations, so that
100- ` WideDeep ` can be constructed
101-
102- ` pytorch-widedeep ` includes standard text (stack of LSTMs) and image
82+ It is important to emphasize that ** each individual component, ` wide ` ,
83+ ` deeptabular ` , ` deeptext ` and ` deepimage ` , can be used independently** and in
84+ isolation. For example, one could use only ` wide ` , which is in simply a linear
85+ model. In fact, one of the most interesting functionalities
86+ in`` pytorch-widedeep `` is the `` deeptabular `` component. Currently,
87+ `` pytorch-widedeep `` offers 3 models for that component:
88+
89+ 1 . `` TabMlp `` : this is almost identical to the [ tabular
90+ model] ( https://docs.fast.ai/tutorial.tabular.html ) in the fantastic
91+ [ fastai] ( https://docs.fast.ai/ ) library, and consists simply in embeddings
92+ representing the categorical features, concatenated with the continuous
93+ features, and passed then through a MLP.
94+
95+ 2 . `` TabRenset `` : This is similar to the previous model but the embeddings are
96+ passed through a series of ResNet blocks built with dense layers.
97+
98+ 3 . `` TabTransformer `` : Details on the TabTransformer can be found in:
99+ [ TabTransformer: Tabular Data Modeling Using Contextual
100+ Embeddings] ( https://arxiv.org/pdf/2012.06678.pdf )
101+
102+
103+ For details on these 3 models and their options please see the examples in the
104+ Examples folder and the documentation.
105+
106+ Finally, while I recommend using the `` wide `` and `` deeptabular `` models in
107+ `` pytorch-widedeep `` it is very likely that users will want to use their own
108+ models for the `` deeptext `` and `` deepimage `` components. That is perfectly
109+ possible as long as the the custom models have an attribute called
110+ `` output_dim `` with the size of the last layer of activations, so that
111+ `` WideDeep `` can be constructed. Again, examples on how to use custom
112+ components can be found in the Examples folder. Just in case
113+ `` pytorch-widedeep `` includes standard text (stack of LSTMs) and image
103114(pre-trained ResNets or stack of CNNs) models.
104115
105- See the examples folder or the docs for more information.
106-
107116
108117### Installation
109118
@@ -130,8 +139,8 @@ cd pytorch-widedeep
130139pip install -e .
131140```
132141
133- ** Important note for Mac users** : at the time of writing (Dec -2020) the latest
134- ` torch ` release is ` 1.7 ` . This release has some
142+ ** Important note for Mac users** : at the time of writing (Feb -2020) the latest
143+ ` torch ` release is ` 1.7.1 ` . This release has some
135144[ issues] ( https://stackoverflow.com/questions/64772335/pytorch-w-parallelnative-cpp206 )
136145when running on Mac and the data-loaders will not run in parallel. In
137146addition, since ` python 3.8 ` , [ the ` multiprocessing ` library start method
@@ -158,17 +167,26 @@ Binary classification with the [adult
158167dataset] ( [adult](https://www.kaggle.com/wenruliu/adult-income-dataset) )
159168using ` Wide ` and ` DeepDense ` and defaults settings.
160169
170+
171+ ``` python
172+ ```
173+
174+ Building a wide (linear) and deep model with `` pytorch-widedeep `` :
175+
161176``` python
177+
162178import pandas as pd
163179import numpy as np
164180from sklearn.model_selection import train_test_split
165181
166- from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor
167- from pytorch_widedeep.models import Wide, DeepDense, WideDeep
182+ from pytorch_widedeep import Trainer
183+ from pytorch_widedeep.preprocessing import WidePreprocessor, TabPreprocessor
184+ from pytorch_widedeep.models import Wide, TabMlp, WideDeep
168185from pytorch_widedeep.metrics import Accuracy
169186
170- # these next 4 lines are not directly related to pytorch-widedeep. I assume
171- # you have downloaded the dataset and place it in a dir called data/adult/
187+ # the following 4 lines are not directly related to ``pytorch-widedeep``. I
188+ # assume you have downloaded the dataset and place it in a dir called
189+ # data/adult/
172190df = pd.read_csv(" data/adult/adult.csv.zip" )
173191df[" income_label" ] = (df[" income" ].apply(lambda x : " >50K" in x)).astype(int )
174192df.drop(" income" , axis = 1 , inplace = True )
@@ -197,61 +215,46 @@ target_col = "income_label"
197215target = df_train[target_col].values
198216
199217# wide
200- preprocess_wide = WidePreprocessor(wide_cols = wide_cols, crossed_cols = cross_cols)
201- X_wide = preprocess_wide .fit_transform(df_train)
218+ wide_preprocessor = WidePreprocessor(wide_cols = wide_cols, crossed_cols = cross_cols)
219+ X_wide = wide_preprocessor .fit_transform(df_train)
202220wide = Wide(wide_dim = np.unique(X_wide).shape[0 ], pred_dim = 1 )
203221
204- # deepdense
205- preprocess_deep = DensePreprocessor (embed_cols = embed_cols, continuous_cols = cont_cols)
206- X_deep = preprocess_deep .fit_transform(df_train)
207- deepdense = DeepDense (
208- hidden_layers = [64 , 32 ],
209- deep_column_idx = preprocess_deep.deep_column_idx ,
210- embed_input = preprocess_deep .embeddings_input,
222+ # deeptabular
223+ tab_preprocessor = TabPreprocessor (embed_cols = embed_cols, continuous_cols = cont_cols)
224+ X_tab = tab_preprocessor .fit_transform(df_train)
225+ deeptabular = TabMlp (
226+ mlp_hidden_dims = [64 , 32 ],
227+ column_idx = tab_preprocessor.column_idx ,
228+ embed_input = tab_preprocessor .embeddings_input,
211229 continuous_cols = cont_cols,
212230)
213- # # To use DeepDenseResnet as the deepdense component simply:
214- # from pytorch_widedeep.models import DeepDenseResnet:
215- # deepdense = DeepDenseResnet(
216- # blocks=[64, 32],
217- # deep_column_idx=preprocess_deep.deep_column_idx,
218- # embed_input=preprocess_deep.embeddings_input,
219- # continuous_cols=cont_cols,
220- # )
221-
222- # build, compile and fit
223- model = WideDeep(wide = wide, deepdense = deepdense)
224- model.compile(method = " binary" , metrics = [Accuracy])
225- model.fit(
231+
232+ # wide and deep
233+ model = WideDeep(wide = wide, deeptabular = deeptabular)
234+
235+ # train the model
236+ trainer = Trainer(model, objective = " binary" , metrics = [Accuracy])
237+ trainer.fit(
226238 X_wide = X_wide,
227- X_deep = X_deep ,
239+ X_tab = X_tab ,
228240 target = target,
229241 n_epochs = 5 ,
230242 batch_size = 256 ,
231243 val_split = 0.1 ,
232244)
233245
234246# predict
235- X_wide_te = preprocess_wide.transform(df_test)
236- X_deep_te = preprocess_deep.transform(df_test)
237- preds = model.predict(X_wide = X_wide_te, X_deep = X_deep_te)
238-
239- # # save and load
240- # torch.save(model, "model_weights/model.t")
241- # model = torch.load("model_weights/model.t")
242-
243- # # or via state dictionaries
244- # torch.save(model.state_dict(), PATH)
245- # model = WideDeep(*args)
246- # model.load_state_dict(torch.load(PATH))
247+ X_wide_te = wide_preprocessor.transform(df_test)
248+ X_tab_te = tab_preprocessor.transform(df_test)
249+ preds = trainer.predict(X_wide = X_wide_te, X_tab = X_tab_te)
250+
251+ # save and load
252+ trainer.save_model(" model_weights/model.t" )
247253```
248254
249- Of course, one can do much more, such as using different initializations,
250- optimizers or learning rate schedulers for each component of the overall
251- model. Adding FC-Heads to the Text and Image components. Using the [ Focal
252- Loss] ( https://arxiv.org/abs/1708.02002 ) , warming up individual components
253- before joined training, etc. See the ` examples ` or the ` docs ` folders for a
254- better understanding of the content of the package and its functionalities.
255+ Of course, one can do ** much more** . See the Examples folder, the
256+ documentation or the companion posts for a better understanding of the content
257+ of the package and its functionalities.
255258
256259### Testing
257260
0 commit comments