
Commit d54fe0a

Add partial wd
1 parent 1b50708 commit d54fe0a

1 file changed: 28 additions, 1 deletion

main_training_mamba.py

Lines changed: 28 additions & 1 deletion
@@ -107,8 +107,35 @@ def main(**kwargs):
         model = torch.compile(model)
 
     # Optimizer
+    # optimizer = optim.AdamW(
+    #     model.parameters(), lr=cfg.learning_rate, betas=(0.9, 0.95), weight_decay=0.1
+    # )
+    params_with_decay = []
+    params_without_decay = []
+    for name, param in model.named_parameters():
+        print(f'{name=}')
+        if 'A_log' in name or 'D' in name or 'dt_bias' in name:
+            params_without_decay.append(param)
+        else:
+            params_with_decay.append(param)
+
+
+    print(f'{params_with_decay=}')
+    print(f'{params_without_decay=}')
+
     optimizer = optim.AdamW(
-        model.parameters(), lr=cfg.learning_rate, betas=(0.9, 0.95), weight_decay=0.1
+        [
+            {
+                "params": params_with_decay,
+                "weight_decay": 0.1,
+            },
+            {
+                "params": params_without_decay,
+                "weight_decay": 0.,
+            },
+        ],
+        betas = (0.9, 0.95),
+        lr = cfg.learning_rate,
     )
 
     # optionally load from checkpoint (when continue pretraining)
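
The net effect of the commit: instead of applying weight_decay=0.1 uniformly via model.parameters(), the optimizer now takes two parameter groups, and any parameter whose name contains A_log, D, or dt_bias (the Mamba state-space tensors) is placed in a group with weight decay disabled. Below is a minimal, self-contained sketch of the same pattern; the ToySSMBlock module, its parameter shapes, and the literal learning rate are placeholders for illustration only (the training script uses the model it builds from cfg and reads cfg.learning_rate).

import torch
from torch import nn, optim

class ToySSMBlock(nn.Module):
    # Placeholder module whose parameter names mimic a Mamba block
    # (A_log, D, dt_bias), so the split below has something to match.
    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(16, 32)
        self.A_log = nn.Parameter(torch.zeros(32))
        self.D = nn.Parameter(torch.ones(32))
        self.dt_bias = nn.Parameter(torch.zeros(32))
        self.out_proj = nn.Linear(32, 16)

model = ToySSMBlock()

# Substring match on parameter names, as in the commit: names containing
# any of these keys are excluded from weight decay.
no_decay_keys = ("A_log", "D", "dt_bias")
params_with_decay, params_without_decay = [], []
for name, param in model.named_parameters():
    if any(key in name for key in no_decay_keys):
        params_without_decay.append(param)
    else:
        params_with_decay.append(param)

# Two parameter groups: regular weights decay at 0.1, excluded tensors at 0.0.
optimizer = optim.AdamW(
    [
        {"params": params_with_decay, "weight_decay": 0.1},
        {"params": params_without_decay, "weight_decay": 0.0},
    ],
    lr=3e-4,  # placeholder; the script uses cfg.learning_rate
    betas=(0.9, 0.95),
)

Note that the membership test is a plain substring check, so 'D' in name also matches any other parameter name containing a capital D; the commit uses the same check.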
