Commit 14cdc71

feat(ppsci): support data_effient_nopt
1 parent 6763666 commit 14cdc71

2 files changed: +11 -74 lines

examples/data_efficient_nopt/data_efficient_nopt.py

11 additions, 35 deletions
@@ -34,8 +34,6 @@
 from ruamel.yaml.comments import CommentedMap as ruamelDict
 from scipy.stats import linregress
 from tqdm import tqdm
-from utils import logging_utils
-from visualdl import LogWriter

 from ppsci.arch.data_efficient_nopt_model import YParams
 from ppsci.arch.data_efficient_nopt_model import build_fno
@@ -384,7 +382,6 @@ def train_one_epoch(self):
             inp = rearrange(inp, "b t c h w -> t b c h w")
             inp_blur = rearrange(inp_blur, "b t c h w -> t b c h w")

-            logwriter = LogWriter(logdir="./runs/data_effient_nopt")
             data_time += time.time() - data_start
             dtime = time.time() - data_start
@@ -489,11 +486,6 @@ def train_one_epoch(self):
                 f"Epoch {self.epoch} Batch {batch_idx} Train Loss {log_nrmse.item()}"
             )
             if self.log_to_screen:
-                logwriter.add_scalar(
-                    "train_avg_loss",
-                    value=log_nrmse.item(),
-                    step=self.iters + steps - 1,
-                )
                 print(
                     "Total Times. Global step: {}, Batch: {}, Rank: {}, Data Shape: {}, Data time: {}, Forward: {}, Backward: {}, Optimizer: {}".format(
                         self.iters + steps - 1,
@@ -666,8 +658,8 @@ def train(cfg: DictConfig):
     device = f"gpu:{local_rank}" if paddle.device.cuda.device_count() >= 1 else "cpu"
     paddle.set_device(device)

-    params["batch_size"] = int(params.batch_size // world_size)
-    params["startEpoch"] = 0
+    params.batch_size = int(params.batch_size // world_size)
+    params.startEpoch = 0
     if cfg.sweep_id:
         jid = os.environ["SLURM_JOBID"]
         expDir = os.path.join(
@@ -676,39 +668,23 @@ def train(cfg: DictConfig):
     else:
         expDir = os.path.join(params.exp_dir, cfg.config, str(cfg.run_name))

-    params["old_exp_dir"] = expDir
-    params["experiment_dir"] = os.path.abspath(expDir)
-    params["checkpoint_path"] = os.path.join(expDir, "training_checkpoints/ckpt.tar")
-    params["best_checkpoint_path"] = os.path.join(
+    params.old_exp_dir = expDir
+    params.experiment_dir = os.path.abspath(expDir)
+    params.checkpoint_path = os.path.join(expDir, "training_checkpoints/ckpt.tar")
+    params.best_checkpoint_path = os.path.join(
         expDir, "training_checkpoints/best_ckpt.tar"
     )
-    params["old_checkpoint_path"] = os.path.join(
+    params.old_checkpoint_path = os.path.join(
         params.old_exp_dir, "training_checkpoints/best_ckpt.tar"
     )

     if global_rank == 0:
         if not os.path.isdir(expDir):
             os.makedirs(expDir)
             os.makedirs(os.path.join(expDir, "training_checkpoints/"))
-    params["resuming"] = True if os.path.isfile(params.checkpoint_path) else False
-
-    params["name"] = str(cfg.run_name)
-    if global_rank == 0:
-        logging_utils.log_to_file(
-            logger_name=None, log_filename=os.path.join(expDir, "out.log")
-        )
-        logging_utils.log_versions()
-        params.log()
-
-    if global_rank == 0:
-        logging_utils.log_to_file(
-            logger_name=None, log_filename=os.path.join(expDir, "out.log")
-        )
-        logging_utils.log_versions()
-        params.log()
-
-    params["log_to_wandb"] = (global_rank == 0) and params["log_to_wandb"]
-    params["log_to_screen"] = (global_rank == 0) and params["log_to_screen"]
+    params.resuming = True if os.path.isfile(params.checkpoint_path) else False
+    params.name = str(cfg.run_name)
+    params.log_to_screen = (global_rank == 0) and params.log_to_screen

     if global_rank == 0:
         hparams = ruamelDict()
@@ -728,7 +704,7 @@ def train(cfg: DictConfig):

 @paddle.no_grad()
 def get_pred(cfg):
-    with open(cfg.eval_config, "r") as stream:
+    with open(cfg.infer_config, "r") as stream:
         config = yaml.load(stream, yaml.FullLoader)
     if cfg.ckpt_path:
         save_dir = os.path.join("/".join(cfg.ckpt_path.split("/")[:-1]), "results_icl")
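
Note on the access-style change in train(): hyperparameters are now written as attributes (params.batch_size, params.startEpoch, ...) instead of dict items (params["batch_size"], ...). The update_params method visible in the ppsci/arch/data_efficient_nopt_model.py diff below writes each key to both the backing dict and an attribute, so mixing the two styles for writes can leave dict and attribute values out of sync. The snippet below is a minimal stand-in, not the real YParams, sketched only to illustrate that hazard; the per-rank batch-size arithmetic (global batch size floor-divided by world_size) is unchanged by this commit.

    # Illustrative stand-in only -- NOT the real YParams from
    # ppsci.arch.data_efficient_nopt_model. It mirrors the pattern visible in
    # update_params (write to the backing dict AND set an attribute) to show
    # why mixing params["key"] writes with params.key reads can drift apart.
    class ParamsSketch:
        def __init__(self, initial):
            self.params = dict(initial)
            for key, val in self.params.items():
                setattr(self, key, val)

        def __getitem__(self, key):
            return self.params[key]

        def __setitem__(self, key, val):
            # Only the dict is updated here; whether the real YParams also
            # refreshes the attribute on item assignment is not visible in
            # this diff.
            self.params[key] = val


    world_size = 4

    p = ParamsSketch({"batch_size": 32})
    p["batch_size"] = int(p.batch_size // world_size)  # old style: dict now 8
    print(p.batch_size)                                 # 32 -- attribute is stale

    q = ParamsSketch({"batch_size": 32})
    q.batch_size = int(q.batch_size // world_size)      # new style used by the commit
    print(q.batch_size)                                 # 8

With a single write style, reads through either the attribute or (if supported) the dict cannot silently disagree, which is presumably why the commit standardizes on attribute access.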

ppsci/arch/data_efficient_nopt_model.py

Lines changed: 0 additions & 39 deletions
@@ -31,7 +31,6 @@
 except ImportError:
     pass

-import logging
 import math
 import os
 from typing import List
@@ -2694,36 +2693,6 @@ def load(module, prefix=""):
 _format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"


-def config_logger(log_level=logging.INFO):
-    logging.basicConfig(format=_format, level=log_level)
-
-
-def log_to_file(
-    logger_name=None, log_level=logging.INFO, log_filename="tensorflow.log"
-):
-
-    if not os.path.exists(os.path.dirname(log_filename)):
-        os.makedirs(os.path.dirname(log_filename))
-
-    if logger_name is not None:
-        log = logging.getLogger(logger_name)
-    else:
-        log = logging.getLogger()
-
-    fh = logging.FileHandler(log_filename)
-    fh.setLevel(log_level)
-    fh.setFormatter(logging.Formatter(_format))
-    log.addHandler(fh)
-
-
-def log_versions():
-    import paddle
-
-    logging.info("--------------- Versions ---------------")
-    logging.info("Paddle: " + str(paddle.__version__))
-    logging.info("----------------------------------------")
-
-
 class LossMSE:
     """mse loss"""

@@ -3278,11 +3247,3 @@ def update_params(self, config):
         for key, val in config.items():
             self.params[key] = val
             self.__setattr__(key, val)
-
-    def log(self):
-        logging.info("------------------ Configuration ------------------")
-        logging.info("Configuration file: " + str(self._yaml_filename))
-        logging.info("Configuration name: " + str(self._config_name))
-        for key, val in self.params.items():
-            logging.info(str(key) + " " + str(val))
-        logging.info("---------------------------------------------------")
