
Commit f17d1b3

[Enh] Enhance config module (#899)

* Support default config content in the config module and remove the deprecated AttrDict series code
* Update corresponding unit tests
* Update develop code

1 parent 8155d71 commit f17d1b3

File tree

17 files changed: +300 −375 lines

ppsci/arch/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -81,7 +81,7 @@ def build_model(cfg):
     """Build model

     Args:
-        cfg (AttrDict): Arch config.
+        cfg (DictConfig): Arch config.

     Returns:
         nn.Layer: Model.
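
This docstring change reflects the commit-wide type swap: configs are now omegaconf DictConfig objects instead of the removed in-house AttrDict. The swap is transparent to attribute-style callers, as a minimal sketch shows (the key names below are illustrative, not from this commit):

from omegaconf import DictConfig, OmegaConf

# Illustrative config; the key names are made up for demonstration.
cfg: DictConfig = OmegaConf.create({"MODEL": {"num_layers": 5, "hidden_size": 20}})

assert cfg.MODEL.num_layers == 5          # attribute access, as with AttrDict
assert cfg["MODEL"]["hidden_size"] == 20  # plain mapping access also works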

ppsci/constraint/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -42,7 +42,7 @@ def build_constraint(cfg, equation_dict, geom_dict):
     """Build constraint(s).

     Args:
-        cfg (List[AttrDict]): Constraint config list.
+        cfg (List[DictConfig]): Constraint config list.
         equation_dict (Dict[str, Equation]): Equation(s) in dict.
         geom_dict (Dict[str, Geometry]): Geometry(ies) in dict.

ppsci/data/dataset/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -78,7 +78,7 @@ def build_dataset(cfg) -> "io.Dataset":
     """Build dataset

     Args:
-        cfg (List[AttrDict]): dataset config list.
+        cfg (List[DictConfig]): dataset config list.

     Returns:
         Dict[str, io.Dataset]: dataset.

ppsci/equation/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -54,7 +54,7 @@ def build_equation(cfg):
     """Build equation(s)

     Args:
-        cfg (List[AttrDict]): Equation(s) config list.
+        cfg (List[DictConfig]): Equation(s) config list.

     Returns:
         Dict[str, Equation]: Equation(s) in dict.

ppsci/geometry/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -54,7 +54,7 @@ def build_geometry(cfg):
     """Build geometry(ies)

     Args:
-        cfg (List[AttrDict]): Geometry config list.
+        cfg (List[DictConfig]): Geometry config list.

     Returns:
         Dict[str, Geometry]: Geometry(ies) in dict.

ppsci/loss/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -53,7 +53,7 @@ def build_loss(cfg):
     """Build loss.

     Args:
-        cfg (AttrDict): Loss config.
+        cfg (DictConfig): Loss config.
     Returns:
         Loss: Callable loss object.
     """

ppsci/loss/mtl/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -35,7 +35,7 @@ def build_mtl_aggregator(cfg):
     """Build loss aggregator with multi-task learning method.

     Args:
-        cfg (AttrDict): Aggregator config.
+        cfg (DictConfig): Aggregator config.
     Returns:
         Loss: Callable loss aggregator object.
     """

ppsci/metric/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -43,7 +43,7 @@ def build_metric(cfg):
     """Build metric.

     Args:
-        cfg (List[AttrDict]): List of metric config.
+        cfg (List[DictConfig]): List of metric config.

     Returns:
         Dict[str, Metric]: Dict of callable metric object.

ppsci/optimizer/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -39,7 +39,7 @@ def build_lr_scheduler(cfg, epochs, iters_per_epoch):
     """Build learning rate scheduler.

     Args:
-        cfg (AttrDict): Learning rate scheduler config.
+        cfg (DictConfig): Learning rate scheduler config.
         epochs (int): Total epochs.
         iters_per_epoch (int): Number of iterations of one epoch.

@@ -57,7 +57,7 @@ def build_optimizer(cfg, model_list, epochs, iters_per_epoch):
     """Build optimizer and learning rate scheduler

     Args:
-        cfg (AttrDict): Learning rate scheduler config.
+        cfg (DictConfig): Learning rate scheduler config.
         model_list (Tuple[nn.Layer, ...]): Tuple of model(s).
         epochs (int): Total epochs.
         iters_per_epoch (int): Number of iterations of one epoch.
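
Both hunk headers above show that the builders keep explicit epochs and iters_per_epoch arguments alongside the now-DictConfig config. A hedged caller sketch follows; the scheduler name and keys inside the config are illustrative assumptions, since the concrete schema is defined elsewhere in ppsci.optimizer:

from omegaconf import OmegaConf

import ppsci

# Illustrative scheduler config; the exact keys expected by
# build_lr_scheduler live elsewhere in ppsci.optimizer and may differ.
lr_cfg = OmegaConf.create({"Cosine": {"learning_rate": 1e-3}})

lr_scheduler = ppsci.optimizer.build_lr_scheduler(
    lr_cfg, epochs=200, iters_per_epoch=100
)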

ppsci/solver/solver.py

Lines changed: 98 additions & 43 deletions

@@ -158,12 +158,18 @@ def __init__(
         cfg: Optional[DictConfig] = None,
     ):
         self.cfg = cfg
+        if isinstance(cfg, DictConfig):
+            # (Recommended) Params can be passed within cfg
+            # rather than passed to 'Solver.__init__' one-by-one.
+            self._parse_params_from_cfg(cfg)
+
         # set model
         self.model = model
         # set constraint
         self.constraint = constraint
         # set output directory
-        self.output_dir = output_dir
+        if not cfg:
+            self.output_dir = output_dir

         # set optimizer
         self.optimizer = optimizer
@@ -192,19 +198,20 @@ def __init__(
         )

         # set training hyper-parameter
-        self.epochs = epochs
-        self.iters_per_epoch = iters_per_epoch
-        # set update_freq for gradient accumulation
-        self.update_freq = update_freq
-        # set checkpoint saving frequency
-        self.save_freq = save_freq
-        # set logging frequency
-        self.log_freq = log_freq
-
-        # set evaluation hyper-parameter
-        self.eval_during_train = eval_during_train
-        self.start_eval_epoch = start_eval_epoch
-        self.eval_freq = eval_freq
+        if not cfg:
+            self.epochs = epochs
+            self.iters_per_epoch = iters_per_epoch
+            # set update_freq for gradient accumulation
+            self.update_freq = update_freq
+            # set checkpoint saving frequency
+            self.save_freq = save_freq
+            # set logging frequency
+            self.log_freq = log_freq
+
+            # set evaluation hyper-parameter
+            self.eval_during_train = eval_during_train
+            self.start_eval_epoch = start_eval_epoch
+            self.eval_freq = eval_freq

         # initialize training log(training loss, time cost, etc.) recorder during one epoch
         self.train_output_info: Dict[str, misc.AverageMeter] = {}
@@ -221,46 +228,45 @@ def __init__(
             "reader_cost": misc.AverageMeter("reader_cost", ".5f", postfix="s"),
         }

-        # fix seed for reproducibility
-        self.seed = seed
-
         # set running device
-        if device != "cpu" and paddle.device.get_device() == "cpu":
+        if not cfg:
+            self.device = device
+        if self.device != "cpu" and paddle.device.get_device() == "cpu":
             logger.warning(f"Set device({device}) to 'cpu' for only cpu available.")
-            device = "cpu"
-        self.device = paddle.set_device(device)
+            self.device = "cpu"
+        self.device = paddle.set_device(self.device)

         # set equations for physics-driven or data-physics hybrid driven task, such as PINN
         self.equation = equation

-        # set geometry for generating data
-        self.geom = {} if geom is None else geom
-
         # set validator
         self.validator = validator

         # set visualizer
         self.visualizer = visualizer

         # set automatic mixed precision(AMP) configuration
-        self.use_amp = use_amp
-        self.amp_level = amp_level
+        if not cfg:
+            self.use_amp = use_amp
+            self.amp_level = amp_level
         self.scaler = amp.GradScaler(True) if self.use_amp else None

         # whether calculate metrics by each batch during evaluation, mainly for memory efficiency
-        self.compute_metric_by_batch = compute_metric_by_batch
+        if not cfg:
+            self.compute_metric_by_batch = compute_metric_by_batch
         if validator is not None:
             for metric in itertools.chain(
                 *[_v.metric.values() for _v in self.validator.values()]
             ):
-                if metric.keep_batch ^ compute_metric_by_batch:
+                if metric.keep_batch ^ self.compute_metric_by_batch:
                     raise ValueError(
                         f"{misc.typename(metric)}.keep_batch should be "
-                        f"{compute_metric_by_batch} when compute_metric_by_batch="
-                        f"{compute_metric_by_batch}."
+                        f"{self.compute_metric_by_batch} when compute_metric_by_batch="
+                        f"{self.compute_metric_by_batch}."
                     )
         # whether set `stop_gradient=True` for every Tensor if no differentiation involved during evaluation
-        self.eval_with_no_grad = eval_with_no_grad
+        if not cfg:
+            self.eval_with_no_grad = eval_with_no_grad

         self.rank = dist.get_rank()
         self.world_size = dist.get_world_size()
@@ -278,34 +284,37 @@ def __init__(
         # set moving average model(optional)
         self.ema_model = None
         if self.cfg and any(key in self.cfg.TRAIN for key in ["ema", "swa"]):
-            if "ema" in self.cfg.TRAIN:
-                self.avg_freq = self.cfg.TRAIN.ema.avg_freq
+            if "ema" in self.cfg.TRAIN and cfg.TRAIN.ema.get("use_ema", False):
                 self.ema_model = ema.ExponentialMovingAverage(
                     self.model, self.cfg.TRAIN.ema.decay
                 )
-            elif "swa" in self.cfg.TRAIN:
-                self.avg_freq = self.cfg.TRAIN.swa.avg_freq
+            elif "swa" in self.cfg.TRAIN and cfg.TRAIN.swa.get("use_swa", False):
                 self.ema_model = ema.StochasticWeightAverage(self.model)

         # load pretrained model, usually used for transfer learning
-        self.pretrained_model_path = pretrained_model_path
-        if pretrained_model_path is not None:
-            save_load.load_pretrain(self.model, pretrained_model_path, self.equation)
+        if not cfg:
+            self.pretrained_model_path = pretrained_model_path
+        if self.pretrained_model_path is not None:
+            save_load.load_pretrain(
+                self.model, self.pretrained_model_path, self.equation
+            )

         # initialize an dict for tracking best metric during training
         self.best_metric = {
             "metric": float("inf"),
             "epoch": 0,
         }
         # load model checkpoint, usually used for resume training
-        if checkpoint_path is not None:
-            if pretrained_model_path is not None:
+        if not cfg:
+            self.checkpoint_path = checkpoint_path
+        if self.checkpoint_path is not None:
+            if self.pretrained_model_path is not None:
                 logger.warning(
                     "Detected 'pretrained_model_path' is given, weights in which might be"
                     "overridden by weights loaded from given 'checkpoint_path'."
                 )
             loaded_metric = save_load.load_checkpoint(
-                checkpoint_path,
+                self.checkpoint_path,
                 self.model,
                 self.optimizer,
                 self.scaler,
@@ -366,7 +375,9 @@ def dist_wrapper(model: nn.Layer) -> paddle.DataParallel:

         # set VisualDL tool
         self.vdl_writer = None
-        if use_vdl:
+        if not cfg:
+            self.use_vdl = use_vdl
+        if self.use_vdl:
             with misc.RankZeroOnly(self.rank) as is_master:
                 if is_master:
                     self.vdl_writer = vdl.LogWriter(osp.join(output_dir, "vdl"))
@@ -377,7 +388,9 @@ def dist_wrapper(model: nn.Layer) -> paddle.DataParallel:

         # set WandB tool
         self.wandb_writer = None
-        if use_wandb:
+        if not cfg:
+            self.use_wandb = use_wandb
+        if self.use_wandb:
             try:
                 import wandb
             except ModuleNotFoundError:
@@ -390,7 +403,9 @@ def dist_wrapper(model: nn.Layer) -> paddle.DataParallel:

         # set TensorBoardX tool
         self.tbd_writer = None
-        if use_tbd:
+        if not cfg:
+            self.use_tbd = use_tbd
+        if self.use_tbd:
             try:
                 import tensorboardX
             except ModuleNotFoundError:
@@ -984,3 +999,43 @@ def plot_loss_history(
             smooth_step=smooth_step,
             use_semilogy=use_semilogy,
         )
+
+    def _parse_params_from_cfg(self, cfg: DictConfig):
+        """
+        Parse hyper-parameters from DictConfig.
+        """
+        self.output_dir = cfg.output_dir
+        self.log_freq = cfg.log_freq
+        self.use_tbd = cfg.use_tbd
+        self.use_vdl = cfg.use_vdl
+        self.wandb_config = cfg.wandb_config
+        self.use_wandb = cfg.use_wandb
+        self.device = cfg.device
+        self.to_static = cfg.to_static
+
+        self.use_amp = cfg.use_amp
+        self.amp_level = cfg.amp_level
+
+        self.epochs = cfg.TRAIN.epochs
+        self.iters_per_epoch = cfg.TRAIN.iters_per_epoch
+        self.update_freq = cfg.TRAIN.update_freq
+        self.save_freq = cfg.TRAIN.save_freq
+        self.eval_during_train = cfg.TRAIN.eval_during_train
+        self.start_eval_epoch = cfg.TRAIN.start_eval_epoch
+        self.eval_freq = cfg.TRAIN.eval_freq
+        self.checkpoint_path = cfg.TRAIN.checkpoint_path
+
+        if "ema" in cfg.TRAIN and cfg.TRAIN.ema.get("use_ema", False):
+            self.avg_freq = cfg.TRAIN.ema.avg_freq
+        elif "swa" in cfg.TRAIN and cfg.TRAIN.swa.get("use_swa", False):
+            self.avg_freq = cfg.TRAIN.swa.avg_freq
+
+        self.compute_metric_by_batch = cfg.EVAL.compute_metric_by_batch
+        self.eval_with_no_grad = cfg.EVAL.eval_with_no_grad
+
+        if cfg.mode == "train":
+            self.pretrained_model_path = cfg.TRAIN.pretrained_model_path
+        elif cfg.mode == "eval":
+            self.pretrained_model_path = cfg.EVAL.pretrained_model_path
+        elif cfg.mode in ["export", "infer"]:
+            self.pretrained_model_path = cfg.INFER.pretrained_model_path
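
The new _parse_params_from_cfg method doubles as documentation of the config layout Solver reads when a DictConfig is passed. Below is a minimal sketch of the recommended cfg-driven construction: the key names mirror the method above, while all values, the MLP architecture, and the Adam setup are illustrative assumptions, not part of this commit.

from omegaconf import OmegaConf

import ppsci

# Key names mirror Solver._parse_params_from_cfg; values are illustrative.
cfg = OmegaConf.create(
    {
        "mode": "train",
        "output_dir": "./output_example",
        "log_freq": 20,
        "device": "gpu",
        "to_static": False,
        "use_vdl": False,
        "use_tbd": False,
        "use_wandb": False,
        "wandb_config": None,
        "use_amp": False,
        "amp_level": "O1",
        "TRAIN": {
            "epochs": 200,
            "iters_per_epoch": 100,
            "update_freq": 1,
            "save_freq": 10,
            "eval_during_train": True,
            "start_eval_epoch": 1,
            "eval_freq": 1,
            "checkpoint_path": None,
            "pretrained_model_path": None,
            # per this commit, EMA now requires an explicit opt-in via 'use_ema'
            "ema": {"use_ema": True, "decay": 0.9, "avg_freq": 1},
        },
        "EVAL": {
            "compute_metric_by_batch": False,
            "eval_with_no_grad": True,
        },
    }
)

model = ppsci.arch.MLP(("x", "y"), ("u",), 5, 20)  # illustrative architecture
optimizer = ppsci.optimizer.Adam(1e-3)(model)      # illustrative optimizer

# (Recommended) pass hyper-parameters through cfg rather than one-by-one kwargs.
solver = ppsci.solver.Solver(model, optimizer=optimizer, cfg=cfg)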
