
Commit deb4f2e

Merge pull request #220 from jrzaurin/save_opt
Option to save Optimizer in the `save` method
2 parents 8057360 + 9d73a88 commit deb4f2e
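
For context, the new `save_optimizer` flag can be exercised along these lines. This is a minimal, hypothetical sketch: the toy data and `TabMlp` configuration are illustrative and not part of this PR; only the `save` call reflects the new API.

    import numpy as np
    import pandas as pd

    from pytorch_widedeep.preprocessing import TabPreprocessor
    from pytorch_widedeep.models import TabMlp
    from pytorch_widedeep.self_supervised_training import EncoderDecoderTrainer

    # toy tabular data, purely for illustration
    df = pd.DataFrame(
        {
            "color": np.random.choice(["r", "g", "b"], 64),
            "size": np.random.choice(["s", "m", "l"], 64),
            "age": np.random.rand(64),
        }
    )
    tab_preprocessor = TabPreprocessor(
        cat_embed_cols=["color", "size"], continuous_cols=["age"]
    )
    X_tab = tab_preprocessor.fit_transform(df)

    encoder = TabMlp(
        column_idx=tab_preprocessor.column_idx,
        cat_embed_input=tab_preprocessor.cat_embed_input,
        continuous_cols=["age"],
    )
    trainer = EncoderDecoderTrainer(encoder=encoder)
    trainer.pretrain(X_tab, n_epochs=1, batch_size=32)

    # new in this PR: optionally persist the optimizer next to the model
    trainer.save(
        "pretrained_weights",
        save_state_dict=True,
        save_optimizer=True,
        model_filename="ed_model.pt",
    )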

File tree

12 files changed: +541, -179 lines


.github/workflows/build.yml

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ jobs:
     strategy:
       fail-fast: true
       matrix:
-        python-version: ["3.8", "3.9", "3.10", "3.11"]
+        python-version: ["3.9", "3.10", "3.11"]
     steps:
       - uses: actions/checkout@v4
       - name: Set up Python ${{ matrix.python-version }}

VERSION

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-1.6.1
+1.6.2

pytorch_widedeep/self_supervised_training/_base_contrastive_denoising_trainer.py

Lines changed: 102 additions & 27 deletions
@@ -1,7 +1,9 @@
 import os
 import sys
+import json
 import warnings
 from abc import ABC, abstractmethod
+from pathlib import Path
 
 import numpy as np
 import torch
@@ -31,6 +33,11 @@
 from pytorch_widedeep.preprocessing.tab_preprocessor import TabPreprocessor
 
 
+# There is quite a lot of code repetition between the
+# BaseContrastiveDenoisingTrainer and the BaseEncoderDecoderTrainer. Given
+# how differently they are instantiated I am happy to tolerate this
+# repetition. However, if the code base grows, it might be worth refactoring
+# this code
 class BaseContrastiveDenoisingTrainer(ABC):
     def __init__(
         self,
@@ -96,45 +103,82 @@ def pretrain(
     ):
         raise NotImplementedError("Trainer.pretrain method not implemented")
 
-    @abstractmethod
     def save(
         self,
         path: str,
         save_state_dict: bool,
+        save_optimizer: bool,
         model_filename: str,
     ):
-        raise NotImplementedError("Trainer.save method not implemented")
+        r"""Saves the model, training and evaluation history (if any) to disk
 
-    def _set_loss_fn(self, **kwargs):
-        if self.loss_type in ["contrastive", "both"]:
-            temperature = kwargs.get("temperature", 0.1)
-            reduction = kwargs.get("reduction", "mean")
-            self.contrastive_loss = InfoNCELoss(temperature, reduction)
+        Parameters
+        ----------
+        path: str
+            path to the directory where the model and the feature importance
+            attribute will be saved.
+        save_state_dict: bool, default = False
+            Boolean indicating whether to save directly the model or the
+            model's state dictionary
+        save_optimizer: bool, default = False
+            Boolean indicating whether to save the optimizer or not
+        model_filename: str, Optional, default = "cd_model.pt"
+            filename where the model weights will be stored
+        """
 
-        if self.loss_type in ["denoising", "both"]:
-            lambda_cat = kwargs.get("lambda_cat", 1.0)
-            lambda_cont = kwargs.get("lambda_cont", 1.0)
-            reduction = kwargs.get("reduction", "mean")
-            self.denoising_loss = DenoisingLoss(lambda_cat, lambda_cont, reduction)
+        self._save_history(path)
 
-    def _compute_loss(
-        self,
-        g_projs: Optional[Tuple[Tensor, Tensor]],
-        x_cat_and_cat_: Optional[Tuple[Tensor, Tensor]],
-        x_cont_and_cont_: Optional[Tuple[Tensor, Tensor]],
-    ) -> Tensor:
-        contrastive_loss = (
-            self.contrastive_loss(g_projs)
-            if self.loss_type in ["contrastive", "both"]
-            else torch.tensor(0.0)
+        self._save_model_and_optimizer(
+            path, save_state_dict, save_optimizer, model_filename
         )
-        denoising_loss = (
-            self.denoising_loss(x_cat_and_cat_, x_cont_and_cont_)
-            if self.loss_type in ["denoising", "both"]
-            else torch.tensor(0.0)
+
+    def _save_history(self, path: str):
+        # 'history' here refers to both, the training/evaluation history and
+        # the lr history
+        save_dir = Path(path)
+        history_dir = save_dir / "history"
+        history_dir.mkdir(exist_ok=True, parents=True)
+
+        # the trainer is run with the History Callback by default
+        with open(history_dir / "train_eval_history.json", "w") as teh:
+            json.dump(self.history, teh)  # type: ignore[attr-defined]
+
+        has_lr_history = any(
+            [clbk.__class__.__name__ == "LRHistory" for clbk in self.callbacks]
         )
+        if self.lr_scheduler is not None and has_lr_history:
+            with open(history_dir / "lr_history.json", "w") as lrh:
+                json.dump(self.lr_history, lrh)  # type: ignore[attr-defined]
 
-        return contrastive_loss + denoising_loss
+    def _save_model_and_optimizer(
+        self,
+        path: str,
+        save_state_dict: bool,
+        save_optimizer: bool,
+        model_filename: str,
+    ):
+
+        model_path = Path(path) / model_filename
+        if save_state_dict and save_optimizer:
+            torch.save(
+                {
+                    "model_state_dict": self.cd_model.state_dict(),
+                    "optimizer_state_dict": self.optimizer.state_dict(),
+                },
+                model_path,
+            )
+        elif save_state_dict and not save_optimizer:
+            torch.save(self.cd_model.state_dict(), model_path)
+        elif not save_state_dict and save_optimizer:
+            torch.save(
+                {
+                    "model": self.cd_model,
+                    "optimizer": self.optimizer,  # this can be a MultipleOptimizer
+                },
+                model_path,
+            )
+        else:
+            torch.save(self.cd_model, model_path)
 
     def _set_reduce_on_plateau_criterion(
         self, lr_scheduler, reducelronplateau_criterion
@@ -233,6 +277,37 @@ def _set_device_and_num_workers(**kwargs):
         num_workers = kwargs.get("num_workers", default_num_workers)
         return device, num_workers
 
+    def _set_loss_fn(self, **kwargs):
+        if self.loss_type in ["contrastive", "both"]:
+            temperature = kwargs.get("temperature", 0.1)
+            reduction = kwargs.get("reduction", "mean")
+            self.contrastive_loss = InfoNCELoss(temperature, reduction)
+
+        if self.loss_type in ["denoising", "both"]:
+            lambda_cat = kwargs.get("lambda_cat", 1.0)
+            lambda_cont = kwargs.get("lambda_cont", 1.0)
+            reduction = kwargs.get("reduction", "mean")
+            self.denoising_loss = DenoisingLoss(lambda_cat, lambda_cont, reduction)
+
+    def _compute_loss(
+        self,
+        g_projs: Optional[Tuple[Tensor, Tensor]],
+        x_cat_and_cat_: Optional[Tuple[Tensor, Tensor]],
+        x_cont_and_cont_: Optional[Tuple[Tensor, Tensor]],
+    ) -> Tensor:
+        contrastive_loss = (
+            self.contrastive_loss(g_projs)
+            if self.loss_type in ["contrastive", "both"]
+            else torch.tensor(0.0)
+        )
+        denoising_loss = (
+            self.denoising_loss(x_cat_and_cat_, x_cont_and_cont_)
+            if self.loss_type in ["denoising", "both"]
+            else torch.tensor(0.0)
+        )
+
+        return contrastive_loss + denoising_loss
+
     @staticmethod
     def _check_model_is_supported(model: ModelWithAttention):
         if model.__class__.__name__ == "TabPerceiver":
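
For reference, a checkpoint written by the new `_save_model_and_optimizer` with `save_state_dict=True` and `save_optimizer=True` could be restored along these lines. This is a sketch that assumes `cd_model` and `optimizer` have been re-instantiated exactly as they were configured for pretraining; the `pretrained_weights` path is illustrative.

    import torch

    # the checkpoint is a dict holding the two state dicts saved above
    checkpoint = torch.load("pretrained_weights/cd_model.pt")
    cd_model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    # if saved with save_state_dict=False and save_optimizer=True, the
    # pickled objects come back directly instead:
    # checkpoint = torch.load("pretrained_weights/cd_model.pt")
    # cd_model, optimizer = checkpoint["model"], checkpoint["optimizer"]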

pytorch_widedeep/self_supervised_training/_base_encoder_decoder_trainer.py

Lines changed: 73 additions & 3 deletions
@@ -1,7 +1,9 @@
 import os
 import sys
+import json
 import warnings
 from abc import ABC, abstractmethod
+from pathlib import Path
 
 import numpy as np
 import torch
@@ -66,22 +68,90 @@ def __init__(
     def pretrain(
         self,
         X_tab: np.ndarray,
-        X_val: Optional[np.ndarray],
+        X_tab_val: Optional[np.ndarray],
         val_split: Optional[float],
         validation_freq: int,
         n_epochs: int,
         batch_size: int,
     ):
         raise NotImplementedError("Trainer.pretrain method not implemented")
 
-    @abstractmethod
     def save(
         self,
         path: str,
         save_state_dict: bool,
+        save_optimizer: bool,
         model_filename: str,
     ):
-        raise NotImplementedError("Trainer.save method not implemented")
+        r"""Saves the model, training and evaluation history (if any) to disk
+
+        Parameters
+        ----------
+        path: str
+            path to the directory where the model and the feature importance
+            attribute will be saved.
+        save_state_dict: bool, default = False
+            Boolean indicating whether to save directly the model or the
+            model's state dictionary
+        save_optimizer: bool, default = False
+            Boolean indicating whether to save the optimizer or not
+        model_filename: str, Optional, default = "ed_model.pt"
+            filename where the model weights will be stored
+        """
+
+        self._save_history(path)
+
+        self._save_model_and_optimizer(
+            path, save_state_dict, save_optimizer, model_filename
+        )
+
+    def _save_history(self, path: str):
+        # 'history' here refers to both, the training/evaluation history and
+        # the lr history
+        save_dir = Path(path)
+        history_dir = save_dir / "history"
+        history_dir.mkdir(exist_ok=True, parents=True)
+
+        # the trainer is run with the History Callback by default
+        with open(history_dir / "train_eval_history.json", "w") as teh:
+            json.dump(self.history, teh)  # type: ignore[attr-defined]
+
+        has_lr_history = any(
+            [clbk.__class__.__name__ == "LRHistory" for clbk in self.callbacks]
+        )
+        if self.lr_scheduler is not None and has_lr_history:
+            with open(history_dir / "lr_history.json", "w") as lrh:
+                json.dump(self.lr_history, lrh)  # type: ignore[attr-defined]
+
+    def _save_model_and_optimizer(
+        self,
+        path: str,
+        save_state_dict: bool,
+        save_optimizer: bool,
+        model_filename: str,
+    ):
+
+        model_path = Path(path) / model_filename
+        if save_state_dict and save_optimizer:
+            torch.save(
+                {
+                    "model_state_dict": self.ed_model.state_dict(),
+                    "optimizer_state_dict": self.optimizer.state_dict(),
+                },
+                model_path,
+            )
+        elif save_state_dict and not save_optimizer:
+            torch.save(self.ed_model.state_dict(), model_path)
+        elif not save_state_dict and save_optimizer:
+            torch.save(
+                {
+                    "model": self.ed_model,
+                    "optimizer": self.optimizer,  # this can be a MultipleOptimizer
+                },
+                model_path,
+            )
+        else:
+            torch.save(self.ed_model, model_path)
 
     def _set_reduce_on_plateau_criterion(
         self, lr_scheduler, reducelronplateau_criterion
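
Since `_save_history` always writes `history/train_eval_history.json` under the save path (and `history/lr_history.json` when a scheduler plus the `LRHistory` callback are present), the persisted histories can be read back with plain `json`. A small sketch, assuming the illustrative `pretrained_weights` directory used above:

    import json
    from pathlib import Path

    history_dir = Path("pretrained_weights") / "history"

    with open(history_dir / "train_eval_history.json") as f:
        train_eval_history = json.load(f)  # e.g. {"train_loss": [...], ...}

    # only written when an LRHistory callback and a scheduler were used
    lr_file = history_dir / "lr_history.json"
    if lr_file.exists():
        with open(lr_file) as f:
            lr_history = json.load(f)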

pytorch_widedeep/self_supervised_training/contrastive_denoising_trainer.py

Lines changed: 1 addition & 44 deletions
@@ -1,6 +1,3 @@
-import json
-from pathlib import Path
-
 import numpy as np
 import torch
 from tqdm import trange
@@ -259,46 +256,6 @@ def fit(
             X_tab, X_tab_val, val_split, validation_freq, n_epochs, batch_size
         )
 
-    def save(
-        self,
-        path: str,
-        save_state_dict: bool = False,
-        model_filename: str = "cd_model.pt",
-    ):
-        r"""Saves the model, training and evaluation history (if any) to disk
-
-        Parameters
-        ----------
-        path: str
-            path to the directory where the model and the feature importance
-            attribute will be saved.
-        save_state_dict: bool, default = False
-            Boolean indicating whether to save directly the model or the
-            model's state dictionary
-        model_filename: str, Optional, default = "cd_model.pt"
-            filename where the model weights will be store
-        """
-        save_dir = Path(path)
-        history_dir = save_dir / "history"
-        history_dir.mkdir(exist_ok=True, parents=True)
-
-        # the trainer is run with the History Callback by default
-        with open(history_dir / "train_eval_history.json", "w") as teh:
-            json.dump(self.history, teh)  # type: ignore[attr-defined]
-
-        has_lr_history = any(
-            [clbk.__class__.__name__ == "LRHistory" for clbk in self.callbacks]
-        )
-        if self.lr_scheduler is not None and has_lr_history:
-            with open(history_dir / "lr_history.json", "w") as lrh:
-                json.dump(self.lr_history, lrh)  # type: ignore[attr-defined]
-
-        model_path = save_dir / model_filename
-        if save_state_dict:
-            torch.save(self.cd_model.state_dict(), model_path)
-        else:
-            torch.save(self.cd_model, model_path)
-
     def _train_step(self, X_tab: Tensor, batch_idx: int) -> float:
         X = X_tab.to(self.device)

@@ -337,7 +294,7 @@ def _train_eval_split(
             train_set = TensorDataset(torch.from_numpy(X))
             eval_set = TensorDataset(torch.from_numpy(X_tab_val))
         elif val_split is not None:
-            X_tr, X_tab_val = train_test_split(
+            X_tr, X_tab_val = train_test_split(  # type: ignore
                 X, test_size=val_split, random_state=self.seed
             )
             train_set = TensorDataset(torch.from_numpy(X_tr))
