
Commit cdab9e3

chore: more torch optimizers
1 parent 8ca5acf

File tree

1 file changed: +18 −6 lines changed

src/backends/torch/torchsolver.cc

Lines changed: 18 additions & 6 deletions
@@ -97,23 +97,28 @@ namespace dd
                 _params, torch::optim::AdamOptions(_base_lr)
                              .betas(std::make_tuple(_beta1, _beta2))
                              .weight_decay(_weight_decay)));
-        this->_logger->info("base_lr: {}", _base_lr);
+      }
+    else if (_solver_type == "ADAMW")
+      {
+        _optimizer
+            = std::unique_ptr<torch::optim::Optimizer>(new torch::optim::AdamW(
+                _params, torch::optim::AdamWOptions(_base_lr)
+                             .betas(std::make_tuple(_beta1, _beta2))
+                             .weight_decay(_weight_decay)));
       }
     else if (_solver_type == "RMSPROP")
       {
         _optimizer = std::unique_ptr<torch::optim::Optimizer>(
             new torch::optim::RMSprop(
                 _params, torch::optim::RMSpropOptions(_base_lr).weight_decay(
                              _weight_decay)));
-        this->_logger->info("base_lr: {}", _base_lr);
       }
     else if (_solver_type == "ADAGRAD")
       {
         _optimizer = std::unique_ptr<torch::optim::Optimizer>(
             new torch::optim::Adagrad(
                 _params, torch::optim::AdagradOptions(_base_lr).weight_decay(
                              _weight_decay)));
-        this->_logger->info("base_lr: {}", _base_lr);
       }
     else if (_solver_type == "RANGER" || _solver_type == "RANGER_PLUS")
       {
@@ -131,7 +136,6 @@ namespace dd
                     .adamp(_adamp)
                     .lsteps(_lsteps)
                     .lalpha(_lalpha)));
-        this->_logger->info("base_lr: {}", _base_lr);
         this->_logger->info("beta_1: {}", _beta1);
         this->_logger->info("beta_2: {}", _beta2);
         this->_logger->info("weight_decay: {}", _weight_decay);
@@ -162,7 +166,6 @@ namespace dd
                     .lookahead(_lookahead)
                     .lsteps(_lsteps)
                     .lalpha(_lalpha)));
-        this->_logger->info("base_lr: {}", _base_lr);
         this->_logger->info("momentum: {}", _momentum);
         this->_logger->info("weight_decay: {}", _weight_decay);
         this->_logger->info("lookahead: {}", _lookahead);
@@ -180,7 +183,6 @@ namespace dd
         _optimizer
             = std::unique_ptr<torch::optim::Optimizer>(new torch::optim::SGD(
                 _params, torch::optim::SGDOptions(_base_lr)));
-        this->_logger->info("base_lr: {}", _base_lr);
       }
     this->_logger->info("clip: {}", _clip);
     if (_clip)
@@ -199,6 +201,8 @@ namespace dd
       }
     if (_sam)
       this->_logger->info("using Sharpness Aware Minimization (SAM)");
+    this->_logger->info("using optimizer " + _solver_type);
+    this->_logger->info("base_lr: {}", _base_lr);
   }
 
   void TorchSolver::sam_first_step()
@@ -417,6 +421,14 @@ namespace dd
         options.betas(std::make_tuple(_beta1, _beta2));
         options.weight_decay(_weight_decay);
       }
+    else if (_solver_type == "ADAMW")
+      {
+        auto &options = static_cast<torch::optim::AdamWOptions &>(
+            param_group.options());
+        options.lr(_base_lr);
+        options.betas(std::make_tuple(_beta1, _beta2));
+        options.weight_decay(_weight_decay);
+      }
     else if (_solver_type == "RMSPROP")
       {
         auto &options = static_cast<torch::optim::RMSpropOptions &>(

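The second new hunk applies the same treatment to the hyperparameter-update path: each param group's options are downcast to the concrete AdamWOptions and rewritten in place, so a learning-rate policy can change lr without rebuilding the optimizer. A hedged standalone sketch of that pattern, with illustrative values:

#include <torch/torch.h>

int main()
{
  torch::nn::Linear model(4, 2);
  torch::optim::AdamW optimizer(model->parameters(),
                                torch::optim::AdamWOptions(1e-3));

  // Same in-place update as the new ADAMW branch: downcast each param
  // group's options to AdamWOptions and overwrite the hyperparameters.
  for (auto &param_group : optimizer.param_groups())
    {
      auto &options
          = static_cast<torch::optim::AdamWOptions &>(param_group.options());
      options.lr(1e-4); // e.g. a decayed learning rate
      options.betas(std::make_tuple(0.9, 0.999));
      options.weight_decay(1e-2);
    }
}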