Skip to content

Commit a21443c

Browse files
committed
feat(ppsci): support data_effient_nopt
1 parent 9930bcb commit a21443c

File tree

2 files changed

+9
-41
lines changed

2 files changed

+9
-41
lines changed

examples/data_efficient_nopt/data_efficient_nopt.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,9 @@
3030
from einops import rearrange
3131
from omegaconf import DictConfig
3232
from ruamel.yaml import YAML
33-
from ruamel.yaml.comments import CommentedMap as ruamelDict
3433
from scipy.stats import linregress
3534
from tqdm import tqdm
3635

37-
from ppsci.arch.data_efficient_nopt_model import YParams
3836
from ppsci.arch.data_efficient_nopt_model import add_weight_decay
3937
from ppsci.arch.data_efficient_nopt_model import build_fno
4038
from ppsci.arch.data_efficient_nopt_model import fno_pretrain as fno
@@ -527,8 +525,16 @@ def train(self):
527525

528526

529527
def train(config: DictConfig):
530-
params = YParams(config.train_config, config.config, config.mode)
528+
params = YAML()
529+
params._config_name = config.config
530+
params.params = {}
531+
params.mode = config.mode
531532
params.use_ddp = config.use_ddp
533+
for key, val in config.train_config[config.config].items():
534+
val = None if val == "None" else val
535+
params.params[key] = val
536+
params.__setattr__(key, val)
537+
532538
local_rank = int(os.environ.get("LOCAL_RANK", 0))
533539
global_rank = int(os.environ.get("RANK", 0))
534540
world_size = int(os.environ.get("WORLD_SIZE", 1))
@@ -559,13 +565,6 @@ def train(config: DictConfig):
559565
params.name = str(config.run_name)
560566
params.log_to_screen = (global_rank == 0) and params.log_to_screen
561567

562-
if global_rank == 0:
563-
hparams = ruamelDict()
564-
yaml = YAML()
565-
for key, value in params.params.items():
566-
hparams[str(key)] = str(value)
567-
with open(os.path.join(exp_dir, "hyperparams.yaml"), "w") as hpfile:
568-
yaml.dump(hparams, hpfile)
569568
trainer = Trainer(params, global_rank, local_rank, device, sweep_id=config.sweep_id)
570569
if config.sweep_id and trainer.global_rank == 0:
571570
print(config.sweep_id, trainer.params.entity, trainer.params.project)

ppsci/arch/data_efficient_nopt_model.py

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -873,34 +873,3 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Te
873873

874874
img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
875875
return img
876-
877-
878-
class YParams:
879-
"""Yaml file parser"""
880-
881-
def __init__(self, yaml_params, config_name, mode):
882-
self._config_name = config_name
883-
self.params = {}
884-
self.mode = mode
885-
886-
for key, val in yaml_params[config_name].items():
887-
if val == "None":
888-
val = None
889-
890-
self.params[key] = val
891-
self.__setattr__(key, val)
892-
893-
def __getitem__(self, key):
894-
return self.params[key]
895-
896-
def __setitem__(self, key, val):
897-
self.params[key] = val
898-
self.__setattr__(key, val)
899-
900-
def __contains__(self, key):
901-
return key in self.params
902-
903-
def update_params(self, config):
904-
for key, val in config.items():
905-
self.params[key] = val
906-
self.__setattr__(key, val)

0 commit comments

Comments
 (0)