@@ -158,12 +158,18 @@ def __init__(
         cfg: Optional[DictConfig] = None,
     ):
         self.cfg = cfg
+        if isinstance(cfg, DictConfig):
+            # (Recommended)Params can be passed within cfg
+            # rather than passed to 'Solver.__init__' one-by-one.
+            self._parse_params_from_cfg(cfg)
+
         # set model
         self.model = model
         # set constraint
         self.constraint = constraint
         # set output directory
-        self.output_dir = output_dir
+        if not cfg:
+            self.output_dir = output_dir

         # set optimizer
         self.optimizer = optimizer
@@ -192,19 +198,20 @@ def __init__(
         )

         # set training hyper-parameter
-        self.epochs = epochs
-        self.iters_per_epoch = iters_per_epoch
-        # set update_freq for gradient accumulation
-        self.update_freq = update_freq
-        # set checkpoint saving frequency
-        self.save_freq = save_freq
-        # set logging frequency
-        self.log_freq = log_freq
-
-        # set evaluation hyper-parameter
-        self.eval_during_train = eval_during_train
-        self.start_eval_epoch = start_eval_epoch
-        self.eval_freq = eval_freq
+        if not cfg:
+            self.epochs = epochs
+            self.iters_per_epoch = iters_per_epoch
+            # set update_freq for gradient accumulation
+            self.update_freq = update_freq
+            # set checkpoint saving frequency
+            self.save_freq = save_freq
+            # set logging frequency
+            self.log_freq = log_freq
+
+            # set evaluation hyper-parameter
+            self.eval_during_train = eval_during_train
+            self.start_eval_epoch = start_eval_epoch
+            self.eval_freq = eval_freq

         # initialize training log(training loss, time cost, etc.) recorder during one epoch
         self.train_output_info: Dict[str, misc.AverageMeter] = {}
@@ -221,46 +228,45 @@ def __init__(
             "reader_cost": misc.AverageMeter("reader_cost", ".5f", postfix="s"),
         }

-        # fix seed for reproducibility
-        self.seed = seed
-
         # set running device
-        if device != "cpu" and paddle.device.get_device() == "cpu":
+        if not cfg:
+            self.device = device
+        if self.device != "cpu" and paddle.device.get_device() == "cpu":
             logger.warning(f"Set device({device}) to 'cpu' for only cpu available.")
-            device = "cpu"
-        self.device = paddle.set_device(device)
+            self.device = "cpu"
+        self.device = paddle.set_device(self.device)

         # set equations for physics-driven or data-physics hybrid driven task, such as PINN
         self.equation = equation

-        # set geometry for generating data
-        self.geom = {} if geom is None else geom
-
         # set validator
         self.validator = validator

         # set visualizer
         self.visualizer = visualizer

         # set automatic mixed precision(AMP) configuration
-        self.use_amp = use_amp
-        self.amp_level = amp_level
+        if not cfg:
+            self.use_amp = use_amp
+            self.amp_level = amp_level
         self.scaler = amp.GradScaler(True) if self.use_amp else None

         # whether calculate metrics by each batch during evaluation, mainly for memory efficiency
-        self.compute_metric_by_batch = compute_metric_by_batch
+        if not cfg:
+            self.compute_metric_by_batch = compute_metric_by_batch
         if validator is not None:
             for metric in itertools.chain(
                 *[_v.metric.values() for _v in self.validator.values()]
             ):
-                if metric.keep_batch ^ compute_metric_by_batch:
+                if metric.keep_batch ^ self.compute_metric_by_batch:
                     raise ValueError(
                         f"{misc.typename(metric)}.keep_batch should be "
-                        f"{compute_metric_by_batch} when compute_metric_by_batch="
-                        f"{compute_metric_by_batch}."
+                        f"{self.compute_metric_by_batch} when compute_metric_by_batch="
+                        f"{self.compute_metric_by_batch}."
                     )
         # whether set `stop_gradient=True` for every Tensor if no differentiation involved during evaluation
-        self.eval_with_no_grad = eval_with_no_grad
+        if not cfg:
+            self.eval_with_no_grad = eval_with_no_grad

         self.rank = dist.get_rank()
         self.world_size = dist.get_world_size()
@@ -278,34 +284,37 @@ def __init__(
         # set moving average model(optional)
         self.ema_model = None
         if self.cfg and any(key in self.cfg.TRAIN for key in ["ema", "swa"]):
-            if "ema" in self.cfg.TRAIN:
-                self.avg_freq = self.cfg.TRAIN.ema.avg_freq
+            if "ema" in self.cfg.TRAIN and cfg.TRAIN.ema.get("use_ema", False):
                 self.ema_model = ema.ExponentialMovingAverage(
                     self.model, self.cfg.TRAIN.ema.decay
                 )
-            elif "swa" in self.cfg.TRAIN:
-                self.avg_freq = self.cfg.TRAIN.swa.avg_freq
+            elif "swa" in self.cfg.TRAIN and cfg.TRAIN.swa.get("use_swa", False):
                 self.ema_model = ema.StochasticWeightAverage(self.model)

         # load pretrained model, usually used for transfer learning
-        self.pretrained_model_path = pretrained_model_path
-        if pretrained_model_path is not None:
-            save_load.load_pretrain(self.model, pretrained_model_path, self.equation)
+        if not cfg:
+            self.pretrained_model_path = pretrained_model_path
+        if self.pretrained_model_path is not None:
+            save_load.load_pretrain(
+                self.model, self.pretrained_model_path, self.equation
+            )

         # initialize an dict for tracking best metric during training
         self.best_metric = {
             "metric": float("inf"),
             "epoch": 0,
         }
         # load model checkpoint, usually used for resume training
-        if checkpoint_path is not None:
-            if pretrained_model_path is not None:
+        if not cfg:
+            self.checkpoint_path = checkpoint_path
+        if self.checkpoint_path is not None:
+            if self.pretrained_model_path is not None:
                 logger.warning(
                     "Detected 'pretrained_model_path' is given, weights in which might be"
                     "overridden by weights loaded from given 'checkpoint_path'."
                 )
             loaded_metric = save_load.load_checkpoint(
-                checkpoint_path,
+                self.checkpoint_path,
                 self.model,
                 self.optimizer,
                 self.scaler,
@@ -366,7 +375,9 @@ def dist_wrapper(model: nn.Layer) -> paddle.DataParallel:

         # set VisualDL tool
         self.vdl_writer = None
-        if use_vdl:
+        if not cfg:
+            self.use_vdl = use_vdl
+        if self.use_vdl:
             with misc.RankZeroOnly(self.rank) as is_master:
                 if is_master:
                     self.vdl_writer = vdl.LogWriter(osp.join(output_dir, "vdl"))
@@ -377,7 +388,9 @@ def dist_wrapper(model: nn.Layer) -> paddle.DataParallel:

         # set WandB tool
         self.wandb_writer = None
-        if use_wandb:
+        if not cfg:
+            self.use_wandb = use_wandb
+        if self.use_wandb:
             try:
                 import wandb
             except ModuleNotFoundError:
@@ -390,7 +403,9 @@ def dist_wrapper(model: nn.Layer) -> paddle.DataParallel:

         # set TensorBoardX tool
         self.tbd_writer = None
-        if use_tbd:
+        if not cfg:
+            self.use_tbd = use_tbd
+        if self.use_tbd:
             try:
                 import tensorboardX
             except ModuleNotFoundError:
@@ -984,3 +999,43 @@ def plot_loss_history(
             smooth_step=smooth_step,
             use_semilogy=use_semilogy,
         )
+
+    def _parse_params_from_cfg(self, cfg: DictConfig):
+        """
+        Parse hyper-parameters from DictConfig.
+        """
+        self.output_dir = cfg.output_dir
+        self.log_freq = cfg.log_freq
+        self.use_tbd = cfg.use_tbd
+        self.use_vdl = cfg.use_vdl
+        self.wandb_config = cfg.wandb_config
+        self.use_wandb = cfg.use_wandb
+        self.device = cfg.device
+        self.to_static = cfg.to_static
+
+        self.use_amp = cfg.use_amp
+        self.amp_level = cfg.amp_level
+
+        self.epochs = cfg.TRAIN.epochs
+        self.iters_per_epoch = cfg.TRAIN.iters_per_epoch
+        self.update_freq = cfg.TRAIN.update_freq
+        self.save_freq = cfg.TRAIN.save_freq
+        self.eval_during_train = cfg.TRAIN.eval_during_train
+        self.start_eval_epoch = cfg.TRAIN.start_eval_epoch
+        self.eval_freq = cfg.TRAIN.eval_freq
+        self.checkpoint_path = cfg.TRAIN.checkpoint_path
+
+        if "ema" in cfg.TRAIN and cfg.TRAIN.ema.get("use_ema", False):
+            self.avg_freq = cfg.TRAIN.ema.avg_freq
+        elif "swa" in cfg.TRAIN and cfg.TRAIN.swa.get("use_swa", False):
+            self.avg_freq = cfg.TRAIN.swa.avg_freq
+
+        self.compute_metric_by_batch = cfg.EVAL.compute_metric_by_batch
+        self.eval_with_no_grad = cfg.EVAL.eval_with_no_grad
+
+        if cfg.mode == "train":
+            self.pretrained_model_path = cfg.TRAIN.pretrained_model_path
+        elif cfg.mode == "eval":
+            self.pretrained_model_path = cfg.EVAL.pretrained_model_path
+        elif cfg.mode in ["export", "infer"]:
+            self.pretrained_model_path = cfg.INFER.pretrained_model_path
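For context, here is a minimal usage sketch of the two construction paths this diff enables: passing a Hydra DictConfig (the recommended path, where hyper-parameters are read by `_parse_params_from_cfg`) versus passing hyper-parameters to `Solver.__init__` one-by-one. The config file path, config keys, model architecture, and output directory below are hypothetical placeholders, not part of this PR; the config is assumed to define the fields that `_parse_params_from_cfg` reads (output_dir, TRAIN.epochs, EVAL.*, etc.).

import hydra
from omegaconf import DictConfig

import ppsci


@hydra.main(version_base=None, config_path="./conf", config_name="example.yaml")
def main(cfg: DictConfig):
    # hypothetical small model and optimizer, just to make the sketch self-contained
    model = ppsci.arch.MLP(("x",), ("u",), 3, 16)
    optimizer = ppsci.optimizer.Adam(1e-3)(model)

    # Recommended path after this PR: epochs, output_dir, use_amp, ... are
    # parsed from `cfg` by `_parse_params_from_cfg`, so they are not repeated here.
    solver = ppsci.solver.Solver(model, optimizer=optimizer, cfg=cfg)

    # Legacy path, still supported when `cfg` is not given: pass the
    # hyper-parameters to `Solver.__init__` one-by-one.
    solver = ppsci.solver.Solver(
        model,
        optimizer=optimizer,
        output_dir="./output_example",
        epochs=20,
        iters_per_epoch=100,
    )


if __name__ == "__main__":
    main()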