34
34
from ruamel .yaml .comments import CommentedMap as ruamelDict
35
35
from scipy .stats import linregress
36
36
from tqdm import tqdm
37
- from utils import logging_utils
38
- from visualdl import LogWriter
39
37
40
38
from ppsci .arch .data_efficient_nopt_model import YParams
41
39
from ppsci .arch .data_efficient_nopt_model import build_fno
@@ -384,7 +382,6 @@ def train_one_epoch(self):
384
382
inp = rearrange (inp , "b t c h w -> t b c h w" )
385
383
inp_blur = rearrange (inp_blur , "b t c h w -> t b c h w" )
386
384
387
- logwriter = LogWriter (logdir = "./runs/data_effient_nopt" )
388
385
data_time += time .time () - data_start
389
386
dtime = time .time () - data_start
390
387
@@ -489,11 +486,6 @@ def train_one_epoch(self):
489
486
f"Epoch { self .epoch } Batch { batch_idx } Train Loss { log_nrmse .item ()} "
490
487
)
491
488
if self .log_to_screen :
492
- logwriter .add_scalar (
493
- "train_avg_loss" ,
494
- value = log_nrmse .item (),
495
- step = self .iters + steps - 1 ,
496
- )
497
489
print (
498
490
"Total Times. Global step: {}, Batch: {}, Rank: {}, Data Shape: {}, Data time: {}, Forward: {}, Backward: {}, Optimizer: {}" .format (
499
491
self .iters + steps - 1 ,
@@ -666,8 +658,8 @@ def train(cfg: DictConfig):
666
658
device = f"gpu:{ local_rank } " if paddle .device .cuda .device_count () >= 1 else "cpu"
667
659
paddle .set_device (device )
668
660
669
- params [ " batch_size" ] = int (params .batch_size // world_size )
670
- params [ " startEpoch" ] = 0
661
+ params . batch_size = int (params .batch_size // world_size )
662
+ params . startEpoch = 0
671
663
if cfg .sweep_id :
672
664
jid = os .environ ["SLURM_JOBID" ]
673
665
expDir = os .path .join (
@@ -676,39 +668,23 @@ def train(cfg: DictConfig):
676
668
else :
677
669
expDir = os .path .join (params .exp_dir , cfg .config , str (cfg .run_name ))
678
670
679
- params [ " old_exp_dir" ] = expDir
680
- params [ " experiment_dir" ] = os .path .abspath (expDir )
681
- params [ " checkpoint_path" ] = os .path .join (expDir , "training_checkpoints/ckpt.tar" )
682
- params [ " best_checkpoint_path" ] = os .path .join (
671
+ params . old_exp_dir = expDir
672
+ params . experiment_dir = os .path .abspath (expDir )
673
+ params . checkpoint_path = os .path .join (expDir , "training_checkpoints/ckpt.tar" )
674
+ params . best_checkpoint_path = os .path .join (
683
675
expDir , "training_checkpoints/best_ckpt.tar"
684
676
)
685
- params [ " old_checkpoint_path" ] = os .path .join (
677
+ params . old_checkpoint_path = os .path .join (
686
678
params .old_exp_dir , "training_checkpoints/best_ckpt.tar"
687
679
)
688
680
689
681
if global_rank == 0 :
690
682
if not os .path .isdir (expDir ):
691
683
os .makedirs (expDir )
692
684
os .makedirs (os .path .join (expDir , "training_checkpoints/" ))
693
- params ["resuming" ] = True if os .path .isfile (params .checkpoint_path ) else False
694
-
695
- params ["name" ] = str (cfg .run_name )
696
- if global_rank == 0 :
697
- logging_utils .log_to_file (
698
- logger_name = None , log_filename = os .path .join (expDir , "out.log" )
699
- )
700
- logging_utils .log_versions ()
701
- params .log ()
702
-
703
- if global_rank == 0 :
704
- logging_utils .log_to_file (
705
- logger_name = None , log_filename = os .path .join (expDir , "out.log" )
706
- )
707
- logging_utils .log_versions ()
708
- params .log ()
709
-
710
- params ["log_to_wandb" ] = (global_rank == 0 ) and params ["log_to_wandb" ]
711
- params ["log_to_screen" ] = (global_rank == 0 ) and params ["log_to_screen" ]
685
+ params .resuming = True if os .path .isfile (params .checkpoint_path ) else False
686
+ params .name = str (cfg .run_name )
687
+ params .log_to_screen = (global_rank == 0 ) and params .log_to_screen
712
688
713
689
if global_rank == 0 :
714
690
hparams = ruamelDict ()
@@ -728,7 +704,7 @@ def train(cfg: DictConfig):
728
704
729
705
@paddle .no_grad ()
730
706
def get_pred (cfg ):
731
- with open (cfg .eval_config , "r" ) as stream :
707
+ with open (cfg .infer_config , "r" ) as stream :
732
708
config = yaml .load (stream , yaml .FullLoader )
733
709
if cfg .ckpt_path :
734
710
save_dir = os .path .join ("/" .join (cfg .ckpt_path .split ("/" )[:- 1 ]), "results_icl" )
0 commit comments