@@ -558,6 +558,7 @@ class MyModelFuser(BaseWDModelComponent):
558558 def output_dim (self ):
559559 return self .output_units
560560
561+
561562deephead = MyModelFuser(
562563 tab_incoming_dim = tab_mlp.output_dim,
563564 text_incoming_dim = models_fuser.output_dim,
@@ -586,6 +587,56 @@ trainer.fit(
586587)
587588```
588589
590+ **7. Tabular with a multi-target loss**
591+
592+ This one is a bonus example intended to illustrate the use of multi-target
593+ losses, rather than a genuinely different architecture.
594+
595+ <p align="center">
596+ <img width="200" src="docs/figures/arch_7.png">
597+ </p>
598+
599+
600+ ```python
601+ from pytorch_widedeep.preprocessing import TabPreprocessor, TextPreprocessor, ImagePreprocessor
602+ from pytorch_widedeep.models import TabMlp, BasicRNN, WideDeep, ModelFuser, Vision
603+ from pytorch_widedeep.losses_multitarget import MultiTargetClassificationLoss
604+ from pytorch_widedeep.models._base_wd_model_component import BaseWDModelComponent
605+ from pytorch_widedeep import Trainer
606+
607+ # let's add a second target to the dataframe
608+ df[" target2" ] = [random.choice([0 , 1 ]) for _ in range (100 )]
609+
610+ # Tabular
611+ tab_preprocessor = TabPreprocessor(
612+ embed_cols = [" city" , " name" ], continuous_cols = [" age" , " height" ]
613+ )
614+ X_tab = tab_preprocessor.fit_transform(df)
615+ tab_mlp = TabMlp(
616+ column_idx = tab_preprocessor.column_idx,
617+ cat_embed_input = tab_preprocessor.cat_embed_input,
618+ continuous_cols = tab_preprocessor.continuous_cols,
619+ mlp_hidden_dims = [64 , 32 ],
620+ )
621+
622+ # 'pred_dim=2' because we have two binary targets. For other types of targets,
623+ # please, see the documentation
624+ model = WideDeep(deeptabular = tab_mlp, pred_dim = 2 )
625+
626+ loss = MultiTargetClassificationLoss(binary_config = [0 , 1 ], reduction = " mean" )
627+
628+ # When a multi-target loss is used, 'custom_loss_function' must not be None.
629+ # See the docs
630+ trainer = Trainer(model, objective = " multitarget" , custom_loss_function = loss)
631+
632+ trainer.fit(
633+ X_tab = X_tab,
634+ target = df[[" target" , " target2" ]].values,
635+ n_epochs = 1 ,
636+ batch_size = 32 ,
637+ )
638+ ```
639+
589640### The `` deeptabular `` component
590641
591642It is important to emphasize again that ** each individual component, ` wide ` ,
0 commit comments