|
5 | 5 |
|
6 | 6 | def deberta_v3_large_lr_scheduler( |
7 | 7 | model: nn.Module, |
| 8 | + layer_low_threshold: int = 195, |
| 9 | + layer_middle_threshold: int = 323, |
8 | 10 | head_param_start: int = 390, |
9 | 11 | base_lr: float = 2e-5, |
10 | 12 | head_lr: float = 1e-4, |
11 | 13 | wd: float = 1e-2, |
12 | 14 | ) -> PARAMETERS: |
13 | | - """DeBERTa-v3 large layer-wise lr scheduler |
14 | | - Reference : https://github.com/gilfernandes/commonlit. |
| 15 | + """DeBERTa-v3 large layer-wise lr scheduler. |
| 16 | +
|
| 17 | + Reference : https://github.com/gilfernandes/commonlit. |
15 | 18 |
|
16 | 19 | :param model: nn.Module. model. based on Huggingface Transformers. |
17 | | - :param head_param_start: int. where the backbone ends (head starts) |
18 | | - :param base_lr: float. base lr |
19 | | - :param head_lr: float. head_lr |
20 | | - :param wd: float. weight decay |
| 20 | + :param layer_low_threshold: int. start of the 12 layers. |
| 21 | + :param layer_middle_threshold: int. end of the 24 layers. |
| 22 | + :param head_param_start: int. where the backbone ends (head starts). |
| 23 | + :param base_lr: float. base lr. |
| 24 | + :param head_lr: float. head_lr. |
| 25 | + :param wd: float. weight decay. |
21 | 26 | """ |
22 | 27 | named_parameters = list(model.named_parameters()) |
23 | 28 |
|
24 | 29 | backbone_parameters = named_parameters[:head_param_start] |
25 | | - regressor_parameters = named_parameters[head_param_start:] |
26 | | - |
27 | | - regressor_group = [params for (_, params) in regressor_parameters] |
| 30 | + head_parameters = named_parameters[head_param_start:] |
28 | 31 |
|
29 | | - parameters = [{'params': regressor_group, 'lr': head_lr}] |
| 32 | + head_group = [params for (_, params) in head_parameters] |
30 | 33 |
|
31 | | - layer_low_threshold: int = 195 # start of the 12 layers |
32 | | - layer_middle_threshold: int = 323 # end of the 24 layers |
| 34 | + parameters = [{'params': head_group, 'lr': head_lr}] |
33 | 35 |
|
34 | 36 | for layer_num, (name, params) in enumerate(backbone_parameters): |
35 | 37 | weight_decay: float = 0.0 if ('bias' in name) or ('LayerNorm.weight' in name) else wd |
|
0 commit comments