 from ppsci.arch.data_efficient_nopt_model import YParams
 from ppsci.arch.data_efficient_nopt_model import build_fno
-from ppsci.arch.data_efficient_nopt_model import build_vmae
 from ppsci.arch.data_efficient_nopt_model import fno_pretrain as fno
 from ppsci.arch.data_efficient_nopt_model import gaussian_blur
 from ppsci.data.dataset.data_efficient_nopt_dataset import MixedDatasetLoader
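For orientation, a minimal usage sketch of the imports that survive this change. The YAML path and the `YParams(yaml_file, config_name)` signature are assumptions (FourCastNet-style config handling); only the import paths and the `build_fno(params)` call are taken from the diff.

```python
# Hypothetical usage sketch; the config path and YParams signature are assumptions.
from ppsci.arch.data_efficient_nopt_model import YParams
from ppsci.arch.data_efficient_nopt_model import build_fno

params = YParams("conf/fno_finetune.yaml", "finetune")  # assumed signature
model = build_fno(params)  # build_fno(params) matches the call shown in the diff
print(type(model))
```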
@@ -211,8 +210,8 @@ def initialize_model(self, params):
         elif self.params.mode == "finetune":
             logger.info("Using Build FNO")
             self.model = build_fno(params)
-        elif self.params.model_type == "vmae":
-            self.model = build_vmae(params)
+        else:
+            raise NotImplementedError("Only support FNO for now")

         if dist.is_initialized():
             self.model = paddle.DataParallel(
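A minimal sketch of the dispatch after this change, assuming the earlier pretrain branch (which uses `fno_pretrain`) is unchanged above the shown `elif`; only the FNO path remains, and anything else now fails loudly before the model is wrapped for distributed training.

```python
import paddle
import paddle.distributed as dist
from ppsci.arch.data_efficient_nopt_model import build_fno

def initialize_model(trainer, params):
    # Sketch only: the pretrain branch that precedes this elif is omitted here.
    if params.mode == "finetune":
        trainer.model = build_fno(params)
    else:
        # The vmae branch was removed; any non-FNO configuration is rejected.
        raise NotImplementedError("Only support FNO for now")

    # Wrap for data-parallel training only when a process group is initialized.
    if dist.is_initialized():
        trainer.model = paddle.DataParallel(trainer.model)
```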
@@ -230,14 +229,6 @@ def initialize_optimizer(self, params):
             self.optimizer = optim.AdamW(
                 parameters=parameters, learning_rate=params.learning_rate
             )
-        elif params.optimizer == "adan":
-            raise NotImplementedError("Adan not implemented yet")
-        elif params.optimizer == "sgd":
-            self.optimizer = optim.SGD(
-                parameters=self.model.parameters(),
-                learning_rate=params.learning_rate,
-                momentum=0.9,
-            )
         else:
             raise ValueError(f"Optimizer {params.optimizer} not supported")
         self.gscaler = amp.GradScaler(
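After this change only AdamW remains. A sketch of the reduced branch, assuming an `if params.optimizer == "adamw":` guard and an `enable_amp`-style flag for the GradScaler; neither is visible in the hunk.

```python
import paddle.amp as amp
import paddle.optimizer as optim

def initialize_optimizer(trainer, params):
    parameters = trainer.model.parameters()
    # Assumed guard; the hunk only shows the AdamW body and the final else.
    if params.optimizer == "adamw":
        trainer.optimizer = optim.AdamW(
            parameters=parameters, learning_rate=params.learning_rate
        )
    else:
        # adan and sgd were dropped, so anything but adamw is an error now.
        raise ValueError(f"Optimizer {params.optimizer} not supported")
    # Loss scaler for mixed-precision training; the flag name is an assumption.
    trainer.gscaler = amp.GradScaler(enable=getattr(params, "enable_amp", False))
```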
@@ -609,13 +600,7 @@ def train(config: DictConfig):
     params.batch_size = int(params.batch_size // world_size)
     params.startEpoch = 0
-    if config.sweep_id:
-        jid = os.environ["SLURM_JOBID"]
-        exp_dir = os.path.join(
-            params.exp_dir, config.sweep_id, config.config, str(config.run_name), jid
-        )
-    else:
-        exp_dir = os.path.join(params.exp_dir, config.config, str(config.run_name))
+    exp_dir = os.path.join(params.exp_dir, config.config, str(config.run_name))

     params.old_exp_dir = exp_dir
     params.experiment_dir = os.path.abspath(exp_dir)
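A worked example of the simplified experiment-directory layout. The `exp_dir`, `config`, and `run_name` values are made up purely to show the resulting path; the sweep_id/SLURM_JOBID components no longer appear in it.

```python
import os
from types import SimpleNamespace

# Illustrative stand-ins for the params object and the DictConfig passed to train().
params = SimpleNamespace(exp_dir="outputs/data_efficient_nopt")
config = SimpleNamespace(config="fno_finetune", run_name=0)

exp_dir = os.path.join(params.exp_dir, config.config, str(config.run_name))
print(exp_dir)  # outputs/data_efficient_nopt/fno_finetune/0
```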
@@ -627,10 +612,9 @@ def train(config: DictConfig):
         params.old_exp_dir, "training_checkpoints/best_ckpt.tar"
     )

-    if global_rank == 0:
-        if not os.path.isdir(exp_dir):
-            os.makedirs(exp_dir)
-            os.makedirs(os.path.join(exp_dir, "training_checkpoints/"))
+    if global_rank == 0 and not os.path.isdir(exp_dir):
+        os.makedirs(exp_dir)
+        os.makedirs(os.path.join(exp_dir, "training_checkpoints/"))
     params.resuming = True if os.path.isfile(params.checkpoint_path) else False
     params.name = str(config.run_name)
     params.log_to_screen = (global_rank == 0) and params.log_to_screen
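An equivalent sketch of the rank-0 directory setup above, shown as an alternative rather than what the patch does: `os.makedirs` with `exist_ok=True` creates both directory levels in one call and makes the `isdir` guard unnecessary.

```python
import os

def make_experiment_dirs(exp_dir: str, global_rank: int) -> None:
    # Only rank 0 touches the filesystem; exist_ok=True makes re-runs safe and
    # creates the parent exp_dir along with training_checkpoints/ in one call.
    if global_rank == 0:
        os.makedirs(os.path.join(exp_dir, "training_checkpoints"), exist_ok=True)
```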