Commit 0e7ca8a

refactor: head_lr
1 parent 584057c commit 0e7ca8a

File tree

1 file changed (+4, -4 lines)


pytorch_optimizer/experimental/deberta_v3_lr_scheduler.py

Lines changed: 4 additions & 4 deletions
@@ -9,7 +9,7 @@ def deberta_v3_large_lr_scheduler(
     model: nn.Module,
     head_param_start: int = 390,
     base_lr: float = 2e-5,
-    last_lr: Optional[float] = None,
+    head_lr: Optional[float] = None,
     wd: float = 1e-2,
 ) -> PARAMETERS:
     """DeBERTa-v3 large layer-wise lr scheduler
@@ -18,7 +18,7 @@ def deberta_v3_large_lr_scheduler(
     :param model: nn.Module. model. based on Huggingface Transformers.
     :param head_param_start: int. where the backbone ends (head starts)
     :param base_lr: float. base lr
-    :param last_lr: float. last lr
+    :param head_lr: float. head_lr
     :param wd: float. weight decay
     """
     named_parameters = list(model.named_parameters())
@@ -29,8 +29,8 @@ def deberta_v3_large_lr_scheduler(
     regressor_group = [params for (_, params) in regressor_parameters]
 
     parameters = []
-    if last_lr is not None:
-        parameters.append({'params': regressor_group, 'lr': last_lr})
+    if head_lr is not None:
+        parameters.append({'params': regressor_group, 'lr': head_lr})
     else:
         parameters.append({'params': regressor_group})
 
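For context, a minimal usage sketch of the renamed argument. It assumes deberta_v3_large_lr_scheduler returns a list of parameter-group dicts that a torch optimizer accepts (consistent with the PARAMETERS return type in the diff); the model checkpoint name and the choice of AdamW are illustrative, not part of this commit.

import torch
from transformers import AutoModel

from pytorch_optimizer.experimental.deberta_v3_lr_scheduler import (
    deberta_v3_large_lr_scheduler,
)

# Illustrative checkpoint; any DeBERTa-v3-large-style Huggingface model.
model = AutoModel.from_pretrained('microsoft/deberta-v3-large')

# Backbone parameters get layer-wise lrs derived from base_lr; the head
# (parameters from index head_param_start onward) gets head_lr when given,
# otherwise it falls back to the optimizer's default lr.
parameters = deberta_v3_large_lr_scheduler(
    model,
    head_param_start=390,
    base_lr=2e-5,
    head_lr=1e-4,  # renamed from last_lr in this commit
    wd=1e-2,
)

# lr here is the default for any group that does not set its own lr.
optimizer = torch.optim.AdamW(parameters, lr=2e-5)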