18 | 18 | from pytabkit.models.nn_models import rtdl_num_embeddings |
19 | 19 | from pytabkit.models.nn_models.base import Fitter |
20 | 20 | from pytabkit.models.nn_models.models import PreprocessingFactory |
21 | | -from pytabkit.models.nn_models.tabm import Model |
| 21 | +from pytabkit.models.nn_models.tabm import Model, make_parameter_groups |
22 | 22 | from pytabkit.models.training.logging import Logger |
23 | 23 |
24 | 24 |
@@ -56,6 +56,8 @@ def fit(self, ds: DictDataset, idxs_list: List[SplitIdxs], interface_resources: |
56 | 56 | allow_amp = self.config.get('allow_amp', False) |
57 | 57 | n_blocks = self.config.get('n_blocks', 'auto') |
58 | 58 | num_emb_n_bins = self.config.get('num_emb_n_bins', 48) |
| 59 | + # default to True for backward compatibility (the original TabM setup shares batches) |
| 60 | + share_training_batches = self.config.get('share_training_batches', True) |
59 | 61 |
60 | 62 | weight_decay = self.config.get('weight_decay', 0.0) |
61 | 63 | gradient_clipping_norm = self.config.get('gradient_clipping_norm', None) |
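For context on the new `share_training_batches` option: with `True` (the original TabM behavior) all `k` submodels are trained on identical mini-batches, while `False` draws an independent permutation of the training set for every submodel. The comment sketch below is illustrative only and reuses the surrounding names `n_train` and `k`:

```python
# Illustrative contrast of the two sampling regimes (not part of the diff):
#   share_training_batches=True  -> one shared permutation per epoch:
#       idx = torch.randperm(n_train)                  # all k members see the same objects
#   share_training_batches=False -> one permutation per ensemble member:
#       idx = torch.rand((k, n_train)).argsort(dim=1)  # row i is member i's visiting order
```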
@@ -180,8 +182,9 @@ def fit(self, ds: DictDataset, idxs_list: List[SplitIdxs], interface_resources: |
180 | 182 | ), |
181 | 183 | arch_type=arch_type, |
182 | 184 | k=tabm_k, |
| 185 | + share_training_batches=share_training_batches, |
183 | 186 | ).to(device) |
184 | | - optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay) |
| 187 | + optimizer = torch.optim.AdamW(make_parameter_groups(model), lr=lr, weight_decay=weight_decay) |
185 | 188 |
186 | 189 |
187 | 190 | if compile_model: |
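Replacing `model.parameters()` with `make_parameter_groups(model)` lets AdamW apply weight decay selectively, which matches the TabM reference training setup where parameters such as biases and normalization weights are typically exempt. The helper itself is imported from `pytabkit.models.nn_models.tabm` above; the snippet below is only a generic sketch of that pattern (note the `_sketch` suffix), not the actual implementation:

```python
from torch import nn

def make_parameter_groups_sketch(model: nn.Module) -> list[dict]:
    """Generic stand-in: split parameters into a decayed and a non-decayed group."""
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # Biases and normalization parameters are commonly excluded from weight decay.
        if name.endswith('bias') or 'norm' in name.lower():
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {'params': decay},                          # uses the optimizer-level weight_decay
        {'params': no_decay, 'weight_decay': 0.0},
    ]

# optimizer = torch.optim.AdamW(make_parameter_groups_sketch(model), lr=lr, weight_decay=weight_decay)
```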
@@ -210,8 +213,11 @@ def loss_fn(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: |
210 | 213 | # TabM produces k predictions per object. Each of them must be trained separately. |
211 | 214 | # (regression) y_pred.shape == (batch_size, k) |
212 | 215 | # (classification) y_pred.shape == (batch_size, k, n_classes) |
213 | | - k = y_pred.shape[-1 if task_type == 'regression' else -2] |
214 | | - return base_loss_fn(y_pred.flatten(0, 1), y_true.repeat_interleave(k)) |
| 216 | + k = y_pred.shape[1] |
| 217 | + return base_loss_fn( |
| 218 | + y_pred.flatten(0, 1), |
| 219 | + y_true.repeat_interleave(k) if model.share_training_batches else y_true.squeeze(-1), |
| 220 | + ) |
215 | 221 |
216 | 222 | @evaluation_mode() |
217 | 223 | def evaluate(part: str) -> float: |
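To make the reshaping in the updated `loss_fn` easier to follow, here is a shape walk-through with assumed sizes `B` (batch size), `k` (ensemble size), and `C` (number of classes); that the model regroups a flat non-shared batch into per-member sub-batches during training mirrors the TabM reference behavior and should be verified against the imported `Model`:

```python
# Shape walk-through (illustrative; targets are assumed to be stored as a column
# vector, which is what the squeeze(-1) suggests):
#
# share_training_batches=True:
#   batch_idx: (B,)              y_pred: (B, k[, C]) -> flatten(0, 1) -> (B*k[, C])
#   Y_train[batch_idx]: (B, 1)   -> repeat_interleave(k) flattens to (B*k,)
#
# share_training_batches=False:
#   batch_idx: (B*k,)            the model regroups this into (B, k, ...) while training
#   Y_train[batch_idx]: (B*k, 1) -> squeeze(-1) -> (B*k,), one target per prediction
```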
@@ -270,17 +276,22 @@ def evaluate(part: str) -> float: |
270 | 276 | if self.config.get('verbosity', 0) >= 1: |
271 | 277 | from tqdm.std import tqdm |
272 | 278 | else: |
273 | | - tqdm = lambda arr, desc, total: arr |
| 279 | + tqdm = lambda arr, desc: arr |
274 | 280 | except ImportError: |
275 | | - tqdm = lambda arr, desc, total: arr |
| 281 | + tqdm = lambda arr, desc: arr |
276 | 282 |
277 | 283 | logger.log(1, '-' * 88 + '\n') |
278 | 284 | for epoch in range(n_epochs): |
279 | | - for batch_idx in tqdm( |
280 | | - torch.randperm(len(data['train']['y']), device=device).split(batch_size), |
281 | | - desc=f'Epoch {epoch}', |
282 | | - total=epoch_size, |
283 | | - ): |
| 285 | + batches = ( |
| 286 | + torch.randperm(n_train, device=device).split(batch_size) |
| 287 | + if model.share_training_batches |
| 288 | + else [ |
| 289 | + x.transpose(0, 1).flatten() |
| 290 | + for x in torch.rand((model.k, n_train), device=device).argsort(dim=1).split(batch_size, dim=1) |
| 291 | + ] |
| 292 | + ) |
| 293 | + |
| 294 | + for batch_idx in tqdm(batches, desc=f"Epoch {epoch}"): |
284 | 295 | model.train() |
285 | 296 | optimizer.zero_grad() |
286 | 297 | loss = loss_fn(apply_model('train', batch_idx), Y_train[batch_idx]) |
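The non-shared branch above builds one permutation per ensemble member and interleaves the chunks, so each optimizer step still touches `batch_size * k` rows. A small self-contained demo of that index construction, with toy sizes that are not tied to pytabkit:

```python
import torch

k, n_train, batch_size = 2, 6, 3  # toy sizes for illustration only

# One independent permutation of the training indices per ensemble member.
perms = torch.rand((k, n_train)).argsort(dim=1)  # shape (k, n_train)

# Split along the object axis into chunks of batch_size columns per member, then
# transpose + flatten so every flat batch interleaves the k independent sub-batches.
batches = [chunk.transpose(0, 1).flatten() for chunk in perms.split(batch_size, dim=1)]

print(len(batches))      # 2 optimizer steps per epoch in this toy setup
print(batches[0].shape)  # torch.Size([6]) == batch_size * k indices per step
```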