|
30 | 30 | from einops import rearrange
|
31 | 31 | from omegaconf import DictConfig
|
32 | 32 | from ruamel.yaml import YAML
|
33 |
| -from ruamel.yaml.comments import CommentedMap as ruamelDict |
34 | 33 | from scipy.stats import linregress
|
35 | 34 | from tqdm import tqdm
|
36 | 35 |
|
37 |
| -from ppsci.arch.data_efficient_nopt_model import YParams |
38 | 36 | from ppsci.arch.data_efficient_nopt_model import add_weight_decay
|
39 | 37 | from ppsci.arch.data_efficient_nopt_model import build_fno
|
40 | 38 | from ppsci.arch.data_efficient_nopt_model import fno_pretrain as fno
|
@@ -527,8 +525,16 @@ def train(self):
|
527 | 525 |
|
528 | 526 |
|
529 | 527 | def train(config: DictConfig):
|
530 |
| - params = YParams(config.train_config, config.config, config.mode) |
| 528 | + params = YAML() |
| 529 | + params._config_name = config.config |
| 530 | + params.params = {} |
| 531 | + params.mode = config.mode |
531 | 532 | params.use_ddp = config.use_ddp
|
| 533 | + for key, val in config.train_config[config.config].items(): |
| 534 | + val = None if val == "None" else val |
| 535 | + params.params[key] = val |
| 536 | + params.__setattr__(key, val) |
| 537 | + |
532 | 538 | local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
533 | 539 | global_rank = int(os.environ.get("RANK", 0))
|
534 | 540 | world_size = int(os.environ.get("WORLD_SIZE", 1))
|
@@ -559,13 +565,6 @@ def train(config: DictConfig):
|
559 | 565 | params.name = str(config.run_name)
|
560 | 566 | params.log_to_screen = (global_rank == 0) and params.log_to_screen
|
561 | 567 |
|
562 |
| - if global_rank == 0: |
563 |
| - hparams = ruamelDict() |
564 |
| - yaml = YAML() |
565 |
| - for key, value in params.params.items(): |
566 |
| - hparams[str(key)] = str(value) |
567 |
| - with open(os.path.join(exp_dir, "hyperparams.yaml"), "w") as hpfile: |
568 |
| - yaml.dump(hparams, hpfile) |
569 | 568 | trainer = Trainer(params, global_rank, local_rank, device, sweep_id=config.sweep_id)
|
570 | 569 | if config.sweep_id and trainer.global_rank == 0:
|
571 | 570 | print(config.sweep_id, trainer.params.entity, trainer.params.project)
|
|
0 commit comments