Skip to content

Commit b4be50b

Browse files
committed
refactor: deberta_v3_large_lr_scheduler
1 parent 419cf6c commit b4be50b

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

pytorch_optimizer/experimental/deberta_v3_lr_scheduler.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,31 +5,33 @@
55

66
def deberta_v3_large_lr_scheduler(
77
model: nn.Module,
8+
layer_low_threshold: int = 195,
9+
layer_middle_threshold: int = 323,
810
head_param_start: int = 390,
911
base_lr: float = 2e-5,
1012
head_lr: float = 1e-4,
1113
wd: float = 1e-2,
1214
) -> PARAMETERS:
13-
"""DeBERTa-v3 large layer-wise lr scheduler
14-
Reference : https://github.com/gilfernandes/commonlit.
15+
"""DeBERTa-v3 large layer-wise lr scheduler.
16+
17+
Reference : https://github.com/gilfernandes/commonlit.
1518
1619
:param model: nn.Module. model based on Hugging Face Transformers.
17-
:param head_param_start: int. where the backbone ends (head starts)
18-
:param base_lr: float. base lr
19-
:param head_lr: float. head_lr
20-
:param wd: float. weight decay
20+
:param layer_low_threshold: int. start of the 12 layers.
21+
:param layer_middle_threshold: int. end of the 24 layers.
22+
:param head_param_start: int. where the backbone ends (head starts).
23+
:param base_lr: float. base lr.
24+
:param head_lr: float. head_lr.
25+
:param wd: float. weight decay.
2126
"""
2227
named_parameters = list(model.named_parameters())
2328

2429
backbone_parameters = named_parameters[:head_param_start]
25-
regressor_parameters = named_parameters[head_param_start:]
26-
27-
regressor_group = [params for (_, params) in regressor_parameters]
30+
head_parameters = named_parameters[head_param_start:]
2831

29-
parameters = [{'params': regressor_group, 'lr': head_lr}]
32+
head_group = [params for (_, params) in head_parameters]
3033

31-
layer_low_threshold: int = 195 # start of the 12 layers
32-
layer_middle_threshold: int = 323 # end of the 24 layers
34+
parameters = [{'params': head_group, 'lr': head_lr}]
3335

3436
for layer_num, (name, params) in enumerate(backbone_parameters):
3537
weight_decay: float = 0.0 if ('bias' in name) or ('LayerNorm.weight' in name) else wd

0 commit comments

Comments
 (0)