
Commit 13fa494

feat(ppsci): remove redundant codes
1 parent 488ea8c commit 13fa494

File tree

3 files changed: +100 -2670 lines changed


examples/data_efficient_nopt/data_efficient_nopt.py

Lines changed: 6 additions & 22 deletions
@@ -36,7 +36,6 @@
 
 from ppsci.arch.data_efficient_nopt_model import YParams
 from ppsci.arch.data_efficient_nopt_model import build_fno
-from ppsci.arch.data_efficient_nopt_model import build_vmae
 from ppsci.arch.data_efficient_nopt_model import fno_pretrain as fno
 from ppsci.arch.data_efficient_nopt_model import gaussian_blur
 from ppsci.data.dataset.data_efficient_nopt_dataset import MixedDatasetLoader
@@ -211,8 +210,8 @@ def initialize_model(self, params):
         elif self.params.mode == "finetune":
             logger.info("Using Build FNO")
             self.model = build_fno(params)
-        elif self.params.model_type == "vmae":
-            self.model = build_vmae(params)
+        else:
+            raise NotImplementedError("Only support FNO for now")
 
         if dist.is_initialized():
             self.model = paddle.DataParallel(
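
Review note: after this hunk, initialize_model supports only the FNO path and raises for anything else. A minimal standalone sketch of that control flow, where build_fno is stubbed so the snippet runs without ppsci installed and "finetune" stands in for self.params.mode:

# Sketch of the post-change branching in initialize_model; the vmae path is gone.
def build_fno(params):
    # stub for ppsci.arch.data_efficient_nopt_model.build_fno
    return f"fno(layers={params['layers']})"

mode = "finetune"  # placeholder for self.params.mode
if mode == "finetune":
    model = build_fno(params={"layers": 4})
else:
    raise NotImplementedError("Only support FNO for now")
print(model)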
@@ -230,14 +229,6 @@ def initialize_optimizer(self, params):
             self.optimizer = optim.AdamW(
                 parameters=parameters, learning_rate=params.learning_rate
             )
-        elif params.optimizer == "adan":
-            raise NotImplementedError("Adan not implemented yet")
-        elif params.optimizer == "sgd":
-            self.optimizer = optim.SGD(
-                parameters=self.model.parameters(),
-                learning_rate=params.learning_rate,
-                momentum=0.9,
-            )
         else:
             raise ValueError(f"Optimizer {params.optimizer} not supported")
         self.gscaler = amp.GradScaler(
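
Review note: the net effect on initialize_optimizer is that only AdamW remains and every other choice raises a ValueError. A minimal runnable sketch with paddle, in which the linear layer and hyperparameters are placeholders for the real FNO model and params:

import paddle

model = paddle.nn.Linear(4, 4)   # placeholder for the FNO model
optimizer_name = "adamw"         # placeholder for params.optimizer
learning_rate = 1e-3             # placeholder for params.learning_rate

if optimizer_name == "adamw":
    optimizer = paddle.optimizer.AdamW(
        parameters=model.parameters(), learning_rate=learning_rate
    )
else:
    # the "adan" and "sgd" branches were removed in this commit
    raise ValueError(f"Optimizer {optimizer_name} not supported")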
@@ -609,13 +600,7 @@ def train(config: DictConfig):
 
     params.batch_size = int(params.batch_size // world_size)
     params.startEpoch = 0
-    if config.sweep_id:
-        jid = os.environ["SLURM_JOBID"]
-        exp_dir = os.path.join(
-            params.exp_dir, config.sweep_id, config.config, str(config.run_name), jid
-        )
-    else:
-        exp_dir = os.path.join(params.exp_dir, config.config, str(config.run_name))
+    exp_dir = os.path.join(params.exp_dir, config.config, str(config.run_name))
 
     params.old_exp_dir = exp_dir
     params.experiment_dir = os.path.abspath(exp_dir)
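
Review note: with the sweep_id/SLURM branch removed, the experiment directory is always exp_dir/config/run_name. A small illustration with placeholder values:

import os

exp_root = "./exp"            # placeholder for params.exp_dir
config_name = "fno_finetune"  # placeholder for config.config
run_name = "run_0"            # placeholder for config.run_name

exp_dir = os.path.join(exp_root, config_name, str(run_name))
print(os.path.abspath(exp_dir))  # e.g. /abs/path/exp/fno_finetune/run_0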
@@ -627,10 +612,9 @@
         params.old_exp_dir, "training_checkpoints/best_ckpt.tar"
     )
 
-    if global_rank == 0:
-        if not os.path.isdir(exp_dir):
-            os.makedirs(exp_dir)
-            os.makedirs(os.path.join(exp_dir, "training_checkpoints/"))
+    if global_rank == 0 and not os.path.isdir(exp_dir):
+        os.makedirs(exp_dir)
+        os.makedirs(os.path.join(exp_dir, "training_checkpoints/"))
     params.resuming = True if os.path.isfile(params.checkpoint_path) else False
     params.name = str(config.run_name)
     params.log_to_screen = (global_rank == 0) and params.log_to_screen
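
Review note: the two nested rank-0 checks collapse into a single condition. Reconstructed from the hunk above with placeholder values; note the commit keeps the explicit isdir() guard rather than switching to exist_ok=True:

import os

global_rank = 0                       # placeholder for the distributed rank
exp_dir = "./exp/fno_finetune/run_0"  # placeholder experiment directory

if global_rank == 0 and not os.path.isdir(exp_dir):
    os.makedirs(exp_dir)
    os.makedirs(os.path.join(exp_dir, "training_checkpoints/"))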
