@@ -97,23 +97,28 @@ namespace dd
                 _params, torch::optim::AdamOptions (_base_lr)
                              .betas (std::make_tuple (_beta1, _beta2))
                              .weight_decay (_weight_decay)));
-        this->_logger->info ("base_lr: {}", _base_lr);
+      }
+    else if (_solver_type == "ADAMW")
+      {
+        _optimizer
+            = std::unique_ptr<torch::optim::Optimizer>(new torch::optim::AdamW (
+                _params, torch::optim::AdamWOptions (_base_lr)
+                             .betas (std::make_tuple (_beta1, _beta2))
+                             .weight_decay (_weight_decay)));
       }
     else if (_solver_type == "RMSPROP")
       {
         _optimizer = std::unique_ptr<torch::optim::Optimizer>(
             new torch::optim::RMSprop (
                 _params, torch::optim::RMSpropOptions (_base_lr).weight_decay (
                     _weight_decay)));
-        this->_logger->info ("base_lr: {}", _base_lr);
       }
     else if (_solver_type == "ADAGRAD")
      {
         _optimizer = std::unique_ptr<torch::optim::Optimizer>(
             new torch::optim::Adagrad (
                 _params, torch::optim::AdagradOptions (_base_lr).weight_decay (
                     _weight_decay)));
-        this->_logger->info ("base_lr: {}", _base_lr);
       }
     else if (_solver_type == "RANGER" || _solver_type == "RANGER_PLUS")
       {
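
(For reference: the new ADAMW branch instantiates libtorch's AdamW, i.e. Adam
with decoupled weight decay. Below is a minimal standalone sketch of the same
construction pattern, outside DeepDetect; the toy model and the hyperparameter
values are hypothetical, chosen only for illustration.)

    #include <torch/torch.h>

    int main ()
    {
      torch::nn::Linear model (16, 4); // hypothetical toy model

      // Same pattern as the ADAMW branch above: the options object
      // carries lr, betas and the (decoupled) weight decay.
      torch::optim::AdamW optimizer (
          model->parameters (),
          torch::optim::AdamWOptions (1e-3)
              .betas (std::make_tuple (0.9, 0.999))
              .weight_decay (1e-2));

      // One training step to show the optimizer in use.
      auto input = torch::randn ({ 8, 16 });
      auto loss = model->forward (input).sum ();
      optimizer.zero_grad ();
      loss.backward ();
      optimizer.step ();
      return 0;
    }
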
@@ -131,7 +136,6 @@ namespace dd
                              .adamp (_adamp)
                              .lsteps (_lsteps)
                              .lalpha (_lalpha)));
-        this->_logger->info ("base_lr: {}", _base_lr);
         this->_logger->info ("beta_1: {}", _beta1);
         this->_logger->info ("beta_2: {}", _beta2);
         this->_logger->info ("weight_decay: {}", _weight_decay);
@@ -162,7 +166,6 @@ namespace dd
                              .lookahead (_lookahead)
                              .lsteps (_lsteps)
                              .lalpha (_lalpha)));
-        this->_logger->info ("base_lr: {}", _base_lr);
         this->_logger->info ("momentum: {}", _momentum);
         this->_logger->info ("weight_decay: {}", _weight_decay);
         this->_logger->info ("lookahead: {}", _lookahead);
@@ -180,7 +183,6 @@ namespace dd
         _optimizer
             = std::unique_ptr<torch::optim::Optimizer>(new torch::optim::SGD (
                 _params, torch::optim::SGDOptions (_base_lr)));
-        this->_logger->info ("base_lr: {}", _base_lr);
       }
     this->_logger->info ("clip: {}", _clip);
     if (_clip)
@@ -199,6 +201,8 @@ namespace dd
       }
     if (_sam)
       this->_logger->info ("using Sharpness Aware Minimization (SAM)");
+    this->_logger->info ("using optimizer " + _solver_type);
+    this->_logger->info ("base_lr: {}", _base_lr);
   }

   void TorchSolver::sam_first_step ()
@@ -417,6 +421,14 @@ namespace dd
             options.betas (std::make_tuple (_beta1, _beta2));
             options.weight_decay (_weight_decay);
           }
+        else if (_solver_type == "ADAMW")
+          {
+            auto &options = static_cast<torch::optim::AdamWOptions &>(
+                param_group.options ());
+            options.lr (_base_lr);
+            options.betas (std::make_tuple (_beta1, _beta2));
+            options.weight_decay (_weight_decay);
+          }
         else if (_solver_type == "RMSPROP")
           {
             auto &options = static_cast<torch::optim::RMSpropOptions &>(
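
(The second hunk covers the case where the optimizer already exists and its
hyperparameters are re-applied per parameter group, for instance after the
learning rate changes. A minimal sketch of that pattern against a plain
libtorch AdamW instance; `optimizer` is the object from the sketch above and
the new values are hypothetical.)

    // Re-apply options to every param group of an existing AdamW optimizer.
    for (auto &param_group : optimizer.param_groups ())
      {
        auto &options = static_cast<torch::optim::AdamWOptions &>(
            param_group.options ());
        options.lr (5e-4); // hypothetical new learning rate
        options.betas (std::make_tuple (0.9, 0.999));
        options.weight_decay (1e-2);
      }

The static_cast is safe here because the solver type string determines which
concrete options subclass was installed when the optimizer was constructed.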