Commit 4bc35dc

exclude some parameters from weight decay
1 parent bba0c07 commit 4bc35dc

File tree

1 file changed (+42, -3 lines)


mmlearn/tasks/contrastive_pretraining.py

Lines changed: 42 additions & 3 deletions
@@ -1,5 +1,6 @@
 """Contrastive pretraining task."""

+import inspect
 import itertools
 from dataclasses import dataclass
 from functools import partial
@@ -502,16 +503,54 @@ def on_test_epoch_end(self) -> None:
         """Compute and log epoch-level metrics at the end of the test epoch."""
         self._on_eval_epoch_end("test")

-    def configure_optimizers(self) -> OptimizerLRScheduler:
+    def configure_optimizers(self) -> OptimizerLRScheduler:  # noqa: PLR0912
         """Configure the optimizer and learning rate scheduler."""
         if self.optimizer is None:
             rank_zero_warn(
                 "Optimizer not provided. Training will continue without an optimizer. "
                 "LR scheduler will not be used.",
             )
             return None
-        # TODO: add mechanism to exclude certain parameters from weight decay
-        optimizer = self.optimizer(self.parameters())
+
+        weight_decay: Optional[float] = self.optimizer.keywords.get(
+            "weight_decay", None
+        )
+        if weight_decay is None:  # try getting default value
+            kw_param = inspect.signature(self.optimizer.func).parameters.get(
+                "weight_decay"
+            )
+            if kw_param is not None and kw_param.default != inspect.Parameter.empty:
+                weight_decay = kw_param.default
+
+        parameters = [param for param in self.parameters() if param.requires_grad]
+
+        if weight_decay is not None:
+            decay_params = []
+            no_decay_params = []
+
+            for param in self.parameters():
+                if not param.requires_grad:
+                    continue
+
+                if param.ndim < 2:  # includes all bias and normalization parameters
+                    no_decay_params.append(param)
+                else:
+                    decay_params.append(param)
+
+            parameters = [
+                {
+                    "params": decay_params,
+                    "weight_decay": weight_decay,
+                    "name": "weight_decay_params",
+                },
+                {
+                    "params": no_decay_params,
+                    "weight_decay": 0.0,
+                    "name": "no_weight_decay_params",
+                },
+            ]
+
+        optimizer = self.optimizer(parameters)
         if not isinstance(optimizer, torch.optim.Optimizer):
             raise TypeError(
                 "Expected optimizer to be an instance of `torch.optim.Optimizer`, "
