@@ -1,5 +1,6 @@
 """Contrastive pretraining task."""
 
+import inspect
 import itertools
 from dataclasses import dataclass
 from functools import partial
@@ -502,16 +503,54 @@ def on_test_epoch_end(self) -> None:
         """Compute and log epoch-level metrics at the end of the test epoch."""
         self._on_eval_epoch_end("test")
 
-    def configure_optimizers(self) -> OptimizerLRScheduler:
+    def configure_optimizers(self) -> OptimizerLRScheduler:  # noqa: PLR0912
         """Configure the optimizer and learning rate scheduler."""
         if self.optimizer is None:
             rank_zero_warn(
                 "Optimizer not provided. Training will continue without an optimizer. "
                 "LR scheduler will not be used.",
             )
             return None
-        # TODO: add mechanism to exclude certain parameters from weight decay
-        optimizer = self.optimizer(self.parameters())
+
+        weight_decay: Optional[float] = self.optimizer.keywords.get(
+            "weight_decay", None
+        )
+        if weight_decay is None:  # try getting default value
+            kw_param = inspect.signature(self.optimizer.func).parameters.get(
+                "weight_decay"
+            )
+            if kw_param is not None and kw_param.default != inspect.Parameter.empty:
+                weight_decay = kw_param.default
+
+        parameters = [param for param in self.parameters() if param.requires_grad]
+
+        if weight_decay is not None:
+            decay_params = []
+            no_decay_params = []
+
+            for param in self.parameters():
+                if not param.requires_grad:
+                    continue
+
+                if param.ndim < 2:  # includes all bias and normalization parameters
+                    no_decay_params.append(param)
+                else:
+                    decay_params.append(param)
+
+            parameters = [
+                {
+                    "params": decay_params,
+                    "weight_decay": weight_decay,
+                    "name": "weight_decay_params",
+                },
+                {
+                    "params": no_decay_params,
+                    "weight_decay": 0.0,
+                    "name": "no_weight_decay_params",
+                },
+            ]
+
+        optimizer = self.optimizer(parameters)
         if not isinstance(optimizer, torch.optim.Optimizer):
             raise TypeError(
                 "Expected optimizer to be an instance of `torch.optim.Optimizer`, "
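Note on the change: `configure_optimizers` now resolves `weight_decay` from the `functools.partial`-wrapped optimizer (first from `partial.keywords`, then from the signature default via `inspect`) and routes every trainable parameter with fewer than two dimensions (biases and normalization scales/offsets) into a parameter group with `weight_decay=0.0`. Below is a minimal standalone sketch of the same idea, assuming a `partial`-wrapped `torch.optim.AdamW` and a toy `nn.Sequential` model; the model and factory names are illustrative, not part of this diff.

import inspect
from functools import partial

import torch
from torch import nn

# Toy model and partial-wrapped optimizer factory (illustrative, not from the diff).
model = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32), nn.Linear(32, 4))
optimizer_factory = partial(torch.optim.AdamW, lr=1e-3, weight_decay=0.05)

# Resolve weight decay: explicit partial keyword first, then the signature default.
weight_decay = optimizer_factory.keywords.get("weight_decay")
if weight_decay is None:
    sig_param = inspect.signature(optimizer_factory.func).parameters.get("weight_decay")
    if sig_param is not None and sig_param.default is not inspect.Parameter.empty:
        weight_decay = sig_param.default

# Split trainable parameters: ndim < 2 covers biases and normalization parameters.
decay_params, no_decay_params = [], []
for param in model.parameters():
    if not param.requires_grad:
        continue
    (no_decay_params if param.ndim < 2 else decay_params).append(param)

param_groups = [
    {"params": decay_params, "weight_decay": weight_decay},
    {"params": no_decay_params, "weight_decay": 0.0},
]
optimizer = optimizer_factory(param_groups)  # AdamW applies decay per group

The `ndim < 2` heuristic mirrors the comment in the diff: it catches biases and norm weights without matching on parameter names, at the cost of also excluding any other 1-D parameters from decay.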