
Commit ae3b84f

Adjusted docs to be consistent with the README file
1 parent 65ed58f commit ae3b84f

File tree

25 files changed, +1477 / -250 lines


README.md

Lines changed: 51 additions & 0 deletions
@@ -558,6 +558,7 @@ class MyModelFuser(BaseWDModelComponent):
     def output_dim(self):
         return self.output_units
 
+
 deephead = MyModelFuser(
     tab_incoming_dim=tab_mlp.output_dim,
     text_incoming_dim=models_fuser.output_dim,

@@ -586,6 +587,56 @@ trainer.fit(
 )
 ```
 
+**7. Tabular with a multi-target loss**
+
+This one is a bonus to illustrate the use of multi-target losses, rather
+than a genuinely different architecture.
+
+<p align="center">
+  <img width="200" src="docs/figures/arch_7.png">
+</p>
+
+
+```python
+from pytorch_widedeep.preprocessing import TabPreprocessor, TextPreprocessor, ImagePreprocessor
+from pytorch_widedeep.models import TabMlp, BasicRNN, WideDeep, ModelFuser, Vision
+from pytorch_widedeep.losses_multitarget import MultiTargetClassificationLoss
+from pytorch_widedeep.models._base_wd_model_component import BaseWDModelComponent
+from pytorch_widedeep import Trainer
+
+# let's add a second target to the dataframe
+df["target2"] = [random.choice([0, 1]) for _ in range(100)]
+
+# Tabular
+tab_preprocessor = TabPreprocessor(
+    embed_cols=["city", "name"], continuous_cols=["age", "height"]
+)
+X_tab = tab_preprocessor.fit_transform(df)
+tab_mlp = TabMlp(
+    column_idx=tab_preprocessor.column_idx,
+    cat_embed_input=tab_preprocessor.cat_embed_input,
+    continuous_cols=tab_preprocessor.continuous_cols,
+    mlp_hidden_dims=[64, 32],
+)
+
+# 'pred_dim=2' because we have two binary targets. For other types of targets,
+# please see the documentation
+model = WideDeep(deeptabular=tab_mlp, pred_dim=2)
+
+loss = MultiTargetClassificationLoss(binary_config=[0, 1], reduction="mean")
+
+# When a multi-target loss is used, 'custom_loss_function' must not be None.
+# See the docs
+trainer = Trainer(model, objective="multitarget", custom_loss_function=loss)
+
+trainer.fit(
+    X_tab=X_tab,
+    target=df[["target", "target2"]].values,
+    n_epochs=1,
+    batch_size=32,
+)
+```
+
 ### The ``deeptabular`` component
 
 It is important to emphasize again that **each individual component, `wide`,
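A note on the loss configured in the added snippet: the comments only say there are two binary targets, so `binary_config=[0, 1]` presumably marks output columns 0 and 1 as binary, and `reduction="mean"` averages the per-target losses. Purely as an illustration of that idea, a minimal hand-rolled PyTorch sketch (not pytorch_widedeep's actual `MultiTargetClassificationLoss` implementation) could look like this:

```python
# Illustrative sketch only -- NOT pytorch_widedeep's implementation.
# Assumes two binary targets sit in output columns 0 and 1 and that their
# BCE losses are averaged, mirroring reduction="mean" across targets.
import torch
import torch.nn as nn


class TwoBinaryTargetLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # preds: (batch, 2) raw logits, one column per binary target
        # targets: (batch, 2) with values in {0, 1}
        loss_t1 = self.bce(preds[:, 0], targets[:, 0].float())
        loss_t2 = self.bce(preds[:, 1], targets[:, 1].float())
        return 0.5 * (loss_t1 + loss_t2)
```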

docs/figures/arch_7.png

32.6 KB

examples/scripts/readme_snippets.py

Lines changed: 33 additions & 0 deletions
@@ -26,6 +26,7 @@
     WidePreprocessor,
     ImagePreprocessor,
 )
+from pytorch_widedeep.losses_multitarget import MultiTargetClassificationLoss
 from pytorch_widedeep.models._base_wd_model_component import (
     BaseWDModelComponent,
 )

@@ -404,3 +405,35 @@ def output_dim(self):
     n_epochs=1,
     batch_size=32,
 )
+
+
+# 7. Simply Tabular with a multi-target loss
+
+# let's add a second target to the dataframe
+df["target2"] = [random.choice([0, 1]) for _ in range(100)]
+
+# Tabular
+tab_preprocessor = TabPreprocessor(
+    embed_cols=["city", "name"], continuous_cols=["age", "height"]
+)
+X_tab = tab_preprocessor.fit_transform(df)
+tab_mlp = TabMlp(
+    column_idx=tab_preprocessor.column_idx,
+    cat_embed_input=tab_preprocessor.cat_embed_input,
+    continuous_cols=tab_preprocessor.continuous_cols,
+    mlp_hidden_dims=[64, 32],
+)
+
+# 2 binary targets. For other types of targets, please see the documentation
+model = WideDeep(deeptabular=tab_mlp, pred_dim=2)
+
+loss = MultiTargetClassificationLoss(binary_config=[0, 1], reduction="mean")
+
+trainer = Trainer(model, objective="multitarget", custom_loss_function=loss)
+
+trainer.fit(
+    X_tab=X_tab,
+    target=df[["target", "target2"]].values,
+    n_epochs=1,
+    batch_size=32,
+)
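The added snippet relies on a `df` that readme_snippets.py builds earlier in the script. Purely for illustration, a toy dataframe with the columns the snippet expects (`city`, `name`, `age`, `height`, `target`; the values below are made up) would be enough to run it on its own:

```python
# Hypothetical toy data, only to make the snippet above runnable standalone.
# The real readme_snippets.py constructs its own df earlier in the script.
import random

import pandas as pd

random.seed(0)
df = pd.DataFrame(
    {
        "city": [random.choice(["london", "paris", "nyc"]) for _ in range(100)],
        "name": [random.choice(["ann", "bob", "eve"]) for _ in range(100)],
        "age": [random.randint(18, 90) for _ in range(100)],
        "height": [random.uniform(150, 200) for _ in range(100)],
        "target": [random.choice([0, 1]) for _ in range(100)],
    }
)
```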

mkdocs/site/index.html

Lines changed: 550 additions & 35 deletions
Large diffs are not rendered by default.
