@@ -1,7 +1,9 @@
 import os
 import sys
+import json
 import warnings
 from abc import ABC, abstractmethod
+from pathlib import Path
 
 import numpy as np
 import torch
@@ -31,6 +33,11 @@
 from pytorch_widedeep.preprocessing.tab_preprocessor import TabPreprocessor
 
 
+# There is quite a lot of code repetition between the
+# BaseContrastiveDenoisingTrainer and the BaseEncoderDecoderTrainer. Given
+# how differently they are instantiated, I am happy to tolerate this
+# repetition. However, if the code base grows, it might be worth refactoring
+# this code.
 class BaseContrastiveDenoisingTrainer(ABC):
     def __init__(
         self,
@@ -96,45 +103,82 @@ def pretrain(
     ):
         raise NotImplementedError("Trainer.pretrain method not implemented")
 
-    @abstractmethod
     def save(
         self,
         path: str,
         save_state_dict: bool,
+        save_optimizer: bool,
         model_filename: str,
     ):
-        raise NotImplementedError("Trainer.save method not implemented")
+        r"""Saves the model, training and evaluation history (if any) to disk
 
-    def _set_loss_fn(self, **kwargs):
-        if self.loss_type in ["contrastive", "both"]:
-            temperature = kwargs.get("temperature", 0.1)
-            reduction = kwargs.get("reduction", "mean")
-            self.contrastive_loss = InfoNCELoss(temperature, reduction)
+        Parameters
+        ----------
+        path: str
+            path to the directory where the model and the training and
+            evaluation history will be saved.
+        save_state_dict: bool, default = False
+            Boolean indicating whether to save the full model or only its
+            state dictionary
+        save_optimizer: bool, default = False
+            Boolean indicating whether to also save the optimizer
+        model_filename: str, Optional, default = "cd_model.pt"
+            filename where the model weights will be stored
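+
+        Examples
+        --------
+        >>> # minimal usage sketch: the trainer instance, directory name and
+        >>> # filename below are illustrative, not library defaults
+        >>> trainer.save(
+        ...     path="pretrained_weights",
+        ...     save_state_dict=True,
+        ...     save_optimizer=False,
+        ...     model_filename="cd_model.pt",
+        ... )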
+        """
 
-        if self.loss_type in ["denoising", "both"]:
-            lambda_cat = kwargs.get("lambda_cat", 1.0)
-            lambda_cont = kwargs.get("lambda_cont", 1.0)
-            reduction = kwargs.get("reduction", "mean")
-            self.denoising_loss = DenoisingLoss(lambda_cat, lambda_cont, reduction)
+        self._save_history(path)
 
-    def _compute_loss(
-        self,
-        g_projs: Optional[Tuple[Tensor, Tensor]],
-        x_cat_and_cat_: Optional[Tuple[Tensor, Tensor]],
-        x_cont_and_cont_: Optional[Tuple[Tensor, Tensor]],
-    ) -> Tensor:
-        contrastive_loss = (
-            self.contrastive_loss(g_projs)
-            if self.loss_type in ["contrastive", "both"]
-            else torch.tensor(0.0)
+        self._save_model_and_optimizer(
+            path, save_state_dict, save_optimizer, model_filename
         )
-        denoising_loss = (
-            self.denoising_loss(x_cat_and_cat_, x_cont_and_cont_)
-            if self.loss_type in ["denoising", "both"]
-            else torch.tensor(0.0)
+
+    def _save_history(self, path: str):
+        # 'history' here refers to both the training/evaluation history and
+        # the lr history
+        save_dir = Path(path)
+        history_dir = save_dir / "history"
+        history_dir.mkdir(exist_ok=True, parents=True)
+
+        # the trainer is run with the History callback by default
+        with open(history_dir / "train_eval_history.json", "w") as teh:
+            json.dump(self.history, teh)  # type: ignore[attr-defined]
+
+        has_lr_history = any(
+            clbk.__class__.__name__ == "LRHistory" for clbk in self.callbacks
         )
+        if self.lr_scheduler is not None and has_lr_history:
+            with open(history_dir / "lr_history.json", "w") as lrh:
+                json.dump(self.lr_history, lrh)  # type: ignore[attr-defined]
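+        # Illustrative note: given the code above, save("pretrained_dir", ...)
+        # would produce (assuming an LRHistory callback is used)
+        #     pretrained_dir/history/train_eval_history.json
+        #     pretrained_dir/history/lr_history.json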
 
-        return contrastive_loss + denoising_loss
+    def _save_model_and_optimizer(
+        self,
+        path: str,
+        save_state_dict: bool,
+        save_optimizer: bool,
+        model_filename: str,
+    ):
+        model_path = Path(path) / model_filename
+        if save_state_dict and save_optimizer:
+            torch.save(
+                {
+                    "model_state_dict": self.cd_model.state_dict(),
+                    "optimizer_state_dict": self.optimizer.state_dict(),
+                },
+                model_path,
+            )
+        elif save_state_dict and not save_optimizer:
+            torch.save(self.cd_model.state_dict(), model_path)
+        elif not save_state_dict and save_optimizer:
+            torch.save(
+                {
+                    "model": self.cd_model,
+                    "optimizer": self.optimizer,  # this can be a MultipleOptimizer
+                },
+                model_path,
+            )
+        else:
+            torch.save(self.cd_model, model_path)
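+        # Loading sketch (illustrative, not part of this class): a checkpoint
+        # written with save_state_dict=True and save_optimizer=True could be
+        # restored with
+        #     checkpoint = torch.load(model_path)
+        #     model.load_state_dict(checkpoint["model_state_dict"])
+        #     optimizer.load_state_dict(checkpoint["optimizer_state_dict"])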
 
     def _set_reduce_on_plateau_criterion(
         self, lr_scheduler, reducelronplateau_criterion
@@ -233,6 +277,37 @@ def _set_device_and_num_workers(**kwargs):
         num_workers = kwargs.get("num_workers", default_num_workers)
         return device, num_workers
 
+    def _set_loss_fn(self, **kwargs):
+        if self.loss_type in ["contrastive", "both"]:
+            temperature = kwargs.get("temperature", 0.1)
+            reduction = kwargs.get("reduction", "mean")
+            self.contrastive_loss = InfoNCELoss(temperature, reduction)
+
+        if self.loss_type in ["denoising", "both"]:
+            lambda_cat = kwargs.get("lambda_cat", 1.0)
+            lambda_cont = kwargs.get("lambda_cont", 1.0)
+            reduction = kwargs.get("reduction", "mean")
+            self.denoising_loss = DenoisingLoss(lambda_cat, lambda_cont, reduction)
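+        # e.g. (illustrative values) _set_loss_fn(temperature=0.05,
+        # lambda_cat=2.0) overrides the InfoNCE temperature and the
+        # categorical denoising weight; unspecified kwargs keep the
+        # defaults above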
+
+    def _compute_loss(
+        self,
+        g_projs: Optional[Tuple[Tensor, Tensor]],
+        x_cat_and_cat_: Optional[Tuple[Tensor, Tensor]],
+        x_cont_and_cont_: Optional[Tuple[Tensor, Tensor]],
+    ) -> Tensor:
+        contrastive_loss = (
+            self.contrastive_loss(g_projs)
+            if self.loss_type in ["contrastive", "both"]
+            else torch.tensor(0.0)
+        )
+        denoising_loss = (
+            self.denoising_loss(x_cat_and_cat_, x_cont_and_cont_)
+            if self.loss_type in ["denoising", "both"]
+            else torch.tensor(0.0)
+        )
+
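+        # note (illustrative): with loss_type="contrastive" the denoising term
+        # above is a zero tensor, so the sum below reduces to the InfoNCE loss
+        # alone, and vice versa for loss_type="denoising"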
+        return contrastive_loss + denoising_loss
+
     @staticmethod
     def _check_model_is_supported(model: ModelWithAttention):
         if model.__class__.__name__ == "TabPerceiver":