@@ -160,12 +160,21 @@ class Trainer(BaseTrainer):
     - **num_workers**: `int`<br/>
         number of workers to be used internally by the data loaders
 
+    - **use_multi_gpu**: `bool`<br/>
+        If True, the model will be trained on multiple GPUs. This is
+        only supported for the `deeptabular` component.
+
+        NOTE: multi-GPU training is an experimental feature and might not
+        work as expected in some cases, although in the particular case of
+        the `Trainer` class it has been extensively tested.
+
     - **lambda_sparse**: `float`<br/>
         lambda sparse parameter in case the `deeptabular` component is `TabNet`
 
     - **class_weight**: `List[float]`<br/>
         This is the `weight` or `pos_weight` parameter in
         `CrossEntropyLoss` and `BCEWithLogitsLoss`, depending on whether
+
     - **reducelronplateau_criterion**: `str`<br/>
         This sets the criterion that will be used by the lr scheduler to
         take a step: One of _'loss'_ or _'metric'_. The ReduceLROnPlateau
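
For context, this is roughly how the new flag is meant to be used. The toy data, column names, and hyperparameters below are illustrative assumptions; only `use_multi_gpu` itself comes from this diff:

```python
import numpy as np

from pytorch_widedeep import Trainer
from pytorch_widedeep.models import TabMlp, WideDeep

# toy tabular data with two continuous columns (illustrative only)
X_tab = np.random.rand(256, 2).astype("float32")
target = np.random.randint(0, 2, 256)

tab_mlp = TabMlp(column_idx={"a": 0, "b": 1}, continuous_cols=["a", "b"])
model = WideDeep(deeptabular=tab_mlp)

# use_multi_gpu is the experimental flag added in this diff; per the
# docstring above it only applies to the deeptabular component
trainer = Trainer(model, objective="binary", use_multi_gpu=True)
trainer.fit(X_tab=X_tab, target=target, n_epochs=2, batch_size=64)
```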
@@ -834,22 +843,58 @@ def _do_finetune(
         r"""
         Simple wrap-up to individually fine-tune model components
         """
-        if self.model.deephead is not None:
+
+        if isinstance(self.model, torch.nn.DataParallel):
+            wide_component = (
+                torch.nn.DataParallel(self.model.module.wide)
+                if self.model.module.wide
+                else None
+            )
+            deeptabular_component = (
+                torch.nn.DataParallel(self.model.module.deeptabular)
+                if self.model.module.deeptabular
+                else None
+            )
+            deeptext_component = (
+                torch.nn.DataParallel(self.model.module.deeptext)
+                if self.model.module.deeptext
+                else None
+            )
+            deepimage_component = (
+                torch.nn.DataParallel(self.model.module.deepimage)
+                if self.model.module.deepimage
+                else None
+            )
+            deephead_component = (
+                torch.nn.DataParallel(self.model.module.deephead)
+                if self.model.module.deephead
+                else None
+            )
+        else:
+            wide_component = self.model.wide if self.model.wide else None
+            deeptabular_component = (
+                self.model.deeptabular if self.model.deeptabular else None
+            )
+            deeptext_component = self.model.deeptext if self.model.deeptext else None
+            deepimage_component = self.model.deepimage if self.model.deepimage else None
+            deephead_component = self.model.deephead if self.model.deephead else None
+
+        if deephead_component is not None:
             raise ValueError(
                 "Currently warming up is only supported without a fully connected 'DeepHead'"
             )
 
         finetuner = FineTune(self.loss_fn, self.metric, self.method, self.verbose)  # type: ignore[arg-type]
-        if self.model.wide:
-            finetuner.finetune_all(self.model.wide, "wide", loader, n_epochs, max_lr)
+        if wide_component:
+            finetuner.finetune_all(wide_component, "wide", loader, n_epochs, max_lr)
 
-        if self.model.deeptabular:
+        if deeptabular_component:
             if deeptabular_gradual:
                 assert (
                     deeptabular_layers is not None
                 ), "deeptabular_layers must be passed if deeptabular_gradual=True"
                 finetuner.finetune_gradual(
-                    self.model.deeptabular,
+                    deeptabular_component,
                     "deeptabular",
                     loader,
                     deeptabular_max_lr,
@@ -858,16 +903,16 @@ def _do_finetune(
                 )
             else:
                 finetuner.finetune_all(
-                    self.model.deeptabular, "deeptabular", loader, n_epochs, max_lr
+                    deeptabular_component, "deeptabular", loader, n_epochs, max_lr
                 )
 
-        if self.model.deeptext:
+        if deeptext_component:
             if deeptext_gradual:
                 assert (
                     deeptext_layers is not None
                 ), "deeptext_layers must be passed if deeptext_gradual=True"
                 finetuner.finetune_gradual(
-                    self.model.deeptext,
+                    deeptext_component,
                     "deeptext",
                     loader,
                     deeptext_max_lr,
@@ -876,16 +921,16 @@ def _do_finetune(
                 )
             else:
                 finetuner.finetune_all(
-                    self.model.deeptext, "deeptext", loader, n_epochs, max_lr
+                    deeptext_component, "deeptext", loader, n_epochs, max_lr
                 )
 
-        if self.model.deepimage:
+        if deepimage_component:
             if deepimage_gradual:
                 assert (
                     deepimage_layers is not None
                 ), "deepimage_layers must be passed if deepimage_gradual=True"
                 finetuner.finetune_gradual(
-                    self.model.deepimage,
+                    deepimage_component,
                     "deepimage",
                     loader,
                     deepimage_max_lr,
@@ -894,7 +939,7 @@ def _do_finetune(
                 )
             else:
                 finetuner.finetune_all(
-                    self.model.deepimage, "deepimage", loader, n_epochs, max_lr
+                    deepimage_component, "deepimage", loader, n_epochs, max_lr
                 )
 
     def _train_epoch(
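
The branch added at the top of `_do_finetune` is the standard pattern for code that must run both with and without `torch.nn.DataParallel`: once wrapped, the original network only survives under `.module`, so every child component has to be fetched through it. A generic helper that captures the idea (not part of this diff, just a sketch):

```python
import torch


def unwrap(model: torch.nn.Module) -> torch.nn.Module:
    """Return the underlying network, whether or not it is wrapped."""
    # DataParallel keeps the original network in its .module attribute
    if isinstance(model, torch.nn.DataParallel):
        return model.module
    return model
```

Note that each sub-component is then re-wrapped in its own `DataParallel`, so fine-tuning a single component still runs on all available GPUs.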
@@ -944,7 +989,7 @@ def _train_step(
 
         y_pred = self.model(X)
 
-        if self.model.is_tabnet:
+        if self.is_model_tabnet:
            loss = self.loss_fn(y_pred[0], y) - self.lambda_sparse * y_pred[1]
            score = self._get_score(y_pred[0], y, is_train=True)
        else:
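
`self.model.is_tabnet` becomes `self.is_model_tabnet` presumably for the same reason: a custom attribute such as `is_tabnet` is not reachable on a `DataParallel`-wrapped model without going through `.module`. A plausible implementation of the new property, assuming it lives on the trainer (hypothetical; the diff does not show it):

```python
@property
def is_model_tabnet(self) -> bool:
    # hypothetical sketch: resolve the underlying model first, since
    # DataParallel hides custom attributes such as is_tabnet behind .module
    base = (
        self.model.module
        if isinstance(self.model, torch.nn.DataParallel)
        else self.model
    )
    return base.is_tabnet
```

The TabNet branch itself is unchanged: the forward pass returns `(predictions, sparse_term)` and the sparsity regularizer is folded into the loss with weight `lambda_sparse`.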
@@ -1008,7 +1053,7 @@ def _eval_step(
         y = to_device(y, self.device)
 
         y_pred = self.model(X)
-        if self.model.is_tabnet:
+        if self.is_model_tabnet:
            loss = self.loss_fn(y_pred[0], y) - self.lambda_sparse * y_pred[1]
            score = self._get_score(y_pred[0], y, is_train=False)
        else:
@@ -1119,7 +1164,7 @@ def _predict(  # type: ignore[override, return] # noqa: C901
                     X[k] = to_device(v, self.device)
                 preds = (
                     self.model(X)
-                    if not self.model.is_tabnet
+                    if not self.is_model_tabnet
                     else self.model(X)[0]
                 )
                 if self.method == "binary":
@@ -1170,6 +1215,7 @@ def _extract_kwargs(kwargs):
         "prefetch_factor",
         "persistent_workers",
         "oversample_mul",
+        "pin_memory",
     ]
     finetune_params = [
         "n_epochs",
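
With `pin_memory` added to the recognised data-loader kwargs, it can be forwarded through `fit` like the other loader options. A hedged usage example, reusing the `trainer` from the sketch further up (the forwarding path is inferred from `_extract_kwargs`, not shown explicitly in this diff):

```python
# pin_memory now reaches the internal DataLoader via the fit kwargs,
# alongside options such as num_workers (inferred usage)
trainer.fit(
    X_tab=X_tab,
    target=target,
    batch_size=64,
    num_workers=4,
    pin_memory=True,
)
```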