Skip to content

Commit c7c3104

Browse files
authored
Merge pull request #250 from jrzaurin/multi-gpu
Added Support for multiple GPUs
2 parents 268dad9 + df9d6b3 commit c7c3104

File tree

6 files changed

+293
-22
lines changed

6 files changed

+293
-22
lines changed

pytorch_widedeep/models/wide_deep.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -297,9 +297,13 @@ def _forward_wide(self, X: Dict[str, Union[Tensor, List[Tensor]]]) -> Tensor:
297297
first_model_mode = list(X.keys())[0]
298298
if isinstance(X[first_model_mode], list):
299299
batch_size = X[first_model_mode][0].size(0)
300+
# Get device from input tensor
301+
device = X[first_model_mode][0].device
300302
else:
301303
batch_size = X[first_model_mode].size(0) # type: ignore[union-attr]
302-
out = torch.zeros(batch_size, self.pred_dim).to(self.wd_device)
304+
# Get device from input tensor
305+
device = X[first_model_mode].device # type: ignore[union-attr]
306+
out = torch.zeros(batch_size, self.pred_dim, device=device)
303307

304308
return out
305309

@@ -331,7 +335,9 @@ def _forward_deep(
331335
def _forward_deephead(
332336
self, X: Dict[str, Union[Tensor, List[Tensor]]], wide_out: Tensor
333337
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
334-
deepside = torch.FloatTensor().to(self.wd_device)
338+
# Get device from wide_out
339+
device = wide_out.device
340+
deepside = torch.FloatTensor().to(device)
335341

336342
if self.deeptabular is not None:
337343
if self.is_tabnet:

pytorch_widedeep/training/_base_trainer.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,22 @@ def __init__(
7878
self.seed = seed
7979

8080
self.model = to_device_model(model, self.device)
81-
if self.model.is_tabnet:
81+
82+
self.is_model_tabnet = model.is_tabnet
83+
if self.is_model_tabnet:
8284
self.lambda_sparse = kwargs.get("lambda_sparse", 1e-3)
85+
86+
# We simply need this attribute
8387
self.model.wd_device = self.device
8488

89+
use_multi_gpu = kwargs.get("use_multi_gpu", False) and self.device.startswith(
90+
"cuda"
91+
)
92+
if use_multi_gpu and torch.cuda.device_count() > 1:
93+
if self.verbose:
94+
print(f"Using {torch.cuda.device_count()} GPUs for training")
95+
self.model = torch.nn.DataParallel(self.model)
96+
8597
self.objective = objective
8698
self.method: str = _ObjectiveToMethod.get(objective) # type: ignore
8799

@@ -444,6 +456,14 @@ def _set_device_and_num_workers(**kwargs) -> Tuple[str, int]:
444456
num_workers = kwargs.get("num_workers", default_num_workers)
445457
default_device = setup_device()
446458
device = kwargs.get("device", default_device)
459+
460+
# Check for multi-GPU setup
461+
use_cuda = device.startswith("cuda")
462+
use_multi_gpu = use_cuda and kwargs.get("use_multi_gpu", False)
463+
464+
if use_multi_gpu and torch.cuda.device_count() > 1:
465+
device = f"cuda:{torch.cuda.current_device()}"
466+
447467
return device, num_workers
448468

449469
def __repr__(self) -> str: # noqa: C901

pytorch_widedeep/training/trainer.py

Lines changed: 61 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -160,12 +160,21 @@ class Trainer(BaseTrainer):
160160
- **num_workers**: `int`<br/>
161161
number of workers to be used internally by the data loaders
162162
163+
- **use_multi_gpu**: `bool`<br/>
164+
If True, the model will be trained on multiple GPUs. This is
165+
only supported for the `deeptabular` component.
166+
167+
NOTE: this is an experimental feature and might not work as expected
168+
in some cases. In the particular case of the `Trainer` class, it has
169+
been extensively tested.
170+
163171
- **lambda_sparse**: `float`<br/>
164172
lambda sparse parameter in case the `deeptabular` component is `TabNet`
165173
166174
- **class_weight**: `List[float]`<br/>
167175
This is the `weight` or `pos_weight` parameter in
168176
`CrossEntropyLoss` and `BCEWithLogitsLoss`, depending on whether
177+
169178
- **reducelronplateau_criterion**: `str`
170179
This sets the criterion that will be used by the lr scheduler to
171180
take a step: One of _'loss'_ or _'metric'_. The ReduceLROnPlateau
@@ -834,22 +843,58 @@ def _do_finetune(
834843
r"""
835844
Simple wrap-up to individually fine-tune model components
836845
"""
837-
if self.model.deephead is not None:
846+
847+
if isinstance(self.model, torch.nn.DataParallel):
848+
wide_component = (
849+
torch.nn.DataParallel(self.model.module.wide)
850+
if self.model.module.wide
851+
else None
852+
)
853+
deeptabular_component = (
854+
torch.nn.DataParallel(self.model.module.deeptabular)
855+
if self.model.module.deeptabular
856+
else None
857+
)
858+
deeptext_component = (
859+
torch.nn.DataParallel(self.model.module.deeptext)
860+
if self.model.module.deeptext
861+
else None
862+
)
863+
deepimage_component = (
864+
torch.nn.DataParallel(self.model.module.deepimage)
865+
if self.model.module.deepimage
866+
else None
867+
)
868+
deephead_component = (
869+
torch.nn.DataParallel(self.model.module.deephead)
870+
if self.model.module.deephead
871+
else None
872+
)
873+
else:
874+
wide_component = self.model.wide if self.model.wide else None
875+
deeptabular_component = (
876+
self.model.deeptabular if self.model.deeptabular else None
877+
)
878+
deeptext_component = self.model.deeptext if self.model.deeptext else None
879+
deepimage_component = self.model.deepimage if self.model.deepimage else None
880+
deephead_component = self.model.deephead if self.model.deephead else None
881+
882+
if deephead_component is not None:
838883
raise ValueError(
839884
"Currently warming up is only supported without a fully connected 'DeepHead'"
840885
)
841886

842887
finetuner = FineTune(self.loss_fn, self.metric, self.method, self.verbose) # type: ignore[arg-type]
843-
if self.model.wide:
844-
finetuner.finetune_all(self.model.wide, "wide", loader, n_epochs, max_lr)
888+
if wide_component:
889+
finetuner.finetune_all(wide_component, "wide", loader, n_epochs, max_lr)
845890

846-
if self.model.deeptabular:
891+
if deeptabular_component:
847892
if deeptabular_gradual:
848893
assert (
849894
deeptabular_layers is not None
850895
), "deeptabular_layers must be passed if deeptabular_gradual=True"
851896
finetuner.finetune_gradual(
852-
self.model.deeptabular,
897+
deeptabular_component,
853898
"deeptabular",
854899
loader,
855900
deeptabular_max_lr,
@@ -858,16 +903,16 @@ def _do_finetune(
858903
)
859904
else:
860905
finetuner.finetune_all(
861-
self.model.deeptabular, "deeptabular", loader, n_epochs, max_lr
906+
deeptabular_component, "deeptabular", loader, n_epochs, max_lr
862907
)
863908

864-
if self.model.deeptext:
909+
if deeptext_component:
865910
if deeptext_gradual:
866911
assert (
867912
deeptext_layers is not None
868913
), "deeptext_layers must be passed if deeptext_gradual=True"
869914
finetuner.finetune_gradual(
870-
self.model.deeptext,
915+
deeptext_component,
871916
"deeptext",
872917
loader,
873918
deeptext_max_lr,
@@ -876,16 +921,16 @@ def _do_finetune(
876921
)
877922
else:
878923
finetuner.finetune_all(
879-
self.model.deeptext, "deeptext", loader, n_epochs, max_lr
924+
deeptext_component, "deeptext", loader, n_epochs, max_lr
880925
)
881926

882-
if self.model.deepimage:
927+
if deepimage_component:
883928
if deepimage_gradual:
884929
assert (
885930
deepimage_layers is not None
886931
), "deepimage_layers must be passed if deepimage_gradual=True"
887932
finetuner.finetune_gradual(
888-
self.model.deepimage,
933+
deepimage_component,
889934
"deepimage",
890935
loader,
891936
deepimage_max_lr,
@@ -894,7 +939,7 @@ def _do_finetune(
894939
)
895940
else:
896941
finetuner.finetune_all(
897-
self.model.deepimage, "deepimage", loader, n_epochs, max_lr
942+
deepimage_component, "deepimage", loader, n_epochs, max_lr
898943
)
899944

900945
def _train_epoch(
@@ -944,7 +989,7 @@ def _train_step(
944989

945990
y_pred = self.model(X)
946991

947-
if self.model.is_tabnet:
992+
if self.is_model_tabnet:
948993
loss = self.loss_fn(y_pred[0], y) - self.lambda_sparse * y_pred[1]
949994
score = self._get_score(y_pred[0], y, is_train=True)
950995
else:
@@ -1008,7 +1053,7 @@ def _eval_step(
10081053
y = to_device(y, self.device)
10091054

10101055
y_pred = self.model(X)
1011-
if self.model.is_tabnet:
1056+
if self.is_model_tabnet:
10121057
loss = self.loss_fn(y_pred[0], y) - self.lambda_sparse * y_pred[1]
10131058
score = self._get_score(y_pred[0], y, is_train=False)
10141059
else:
@@ -1119,7 +1164,7 @@ def _predict( # type: ignore[override, return] # noqa: C901
11191164
X[k] = to_device(v, self.device)
11201165
preds = (
11211166
self.model(X)
1122-
if not self.model.is_tabnet
1167+
if not self.is_model_tabnet
11231168
else self.model(X)[0]
11241169
)
11251170
if self.method == "binary":
@@ -1170,6 +1215,7 @@ def _extract_kwargs(kwargs):
11701215
"prefetch_factor",
11711216
"persistent_workers",
11721217
"oversample_mul",
1218+
"pin_memory",
11731219
]
11741220
finetune_params = [
11751221
"n_epochs",

pytorch_widedeep/training/trainer_from_folder.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,17 @@ class TrainerFromFolder(Trainer):
144144
- **num_workers**: `int`<br/>
145145
number of workers to be used internally by the data loaders
146146
147-
- **lambda_sparse**: `float`<br/>
147+
- **use_multi_gpu**: `bool`<br/>
148+
If True, the model will be trained on multiple GPUs. This is
149+
only supported for the `deeptabular` component.
150+
151+
NOTE: this is an experimental feature and might not work as expected
152+
in some cases. While for the `Trainer` class, it has been extensively
153+
tested, for the `TrainerFromFolder` class, it has not been tested
154+
that thoroughly (in principle the `TrainerFromFolder` inherits from
155+
the `Trainer` class, so it should work).
156+
157+
- **lambda_sparse**: `float`<br/>
148158
lambda sparse parameter in case the `deeptabular` component is `TabNet`
149159
150160
- **class_weight**: `List[float]`<br/>

pytorch_widedeep/utils/general_utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
def setup_device() -> str:
88
if torch.cuda.is_available():
9-
return "cuda"
9+
return f"cuda:{torch.cuda.current_device()}"
1010
elif torch.backends.mps.is_available():
1111
return "mps"
1212
else:
@@ -24,11 +24,10 @@ def to_device_model(model, device: str): # noqa: C901
2424
# insistent transformation since in some cases overall approaches such as
2525
# model.to('mps') do not work
2626

27-
if device in ["cpu", "cuda"]:
27+
if device == "cpu" or (device.startswith("cuda") and torch.cuda.is_available()):
2828
return model.to(device)
2929

3030
if device == "mps":
31-
3231
try:
3332
return model.to(device)
3433
except (RuntimeError, TypeError):

0 commit comments

Comments
 (0)