|
1 | | -from typing import Dict, List |
| 1 | +from typing import Any, Dict, List |
2 | 2 |
|
3 | 3 | import torch.nn as nn |
4 | 4 |
|
5 | 5 |
|
def adjust_optim_params(
    model: nn.Module, optim_params: Dict[str, Dict[str, Any]]
) -> List[Dict[str, Dict]]:
    """Adjust the learning parameters for optimizer.

    Builds one param-group dict per named parameter of ``model``. A key of
    ``optim_params`` applies to a parameter when it occurs as a substring of
    the parameter's name (e.g. ``"encoder"`` matches
    ``"encoder.layer1.weight"``). When several keys match the same parameter,
    entries later in ``optim_params`` overwrite earlier ones.

    Parameters
    ----------
    model : nn.Module
        The encoder-decoder segmentation model.
    optim_params : Dict[str, Dict[str, Any]]
        Optimizer params like learning rates, weight decays etc. for different
        parts of the network. E.g.
        ``{"encoder": {"weight_decay": 0.1, "lr": 0.1}, "sem": {"lr": 0.1}}``

    Returns
    -------
    List[Dict[str, Dict]]:
        A list of kwargs (str, Dict pairs) containing the model params,
        suitable for passing to a ``torch.optim`` optimizer constructor.
    """
    adjust_params = []
    # Iterate the generator directly; no need to materialize it as a list.
    for name, parameters in model.named_parameters():
        opts: Dict[str, Any] = {}

        for block, block_params in optim_params.items():
            # Substring match on the qualified parameter name.
            if block in name:
                opts.update(block_params)

        adjust_params.append({"params": parameters, **opts})

    return adjust_params
0 commit comments