
Commit 29ae4c2

feat(ppsci): support data_efficient_nopt for training and testing
1 parent e5b39df · commit 29ae4c2

3 files changed: +20 -12 lines


examples/data_efficient_nopt/config/operators_poisson.yaml

Lines changed: 4 additions & 4 deletions
@@ -31,7 +31,7 @@ default: &DEFAULT
   entity: 'entity_name'
   project: 'proj_name'
   group: 'poisson'
-  log_to_wandb: !!bool True
+  log_to_wandb: !!bool False
   distill: !!bool False
   subsample: 1
   exp_dir: './exp/'
@@ -84,7 +84,7 @@ poisson-64-scale-e5_15: &poisson_64_e5_15
   scales_path: '/home/aistudio/data_efficient_nopt/data/possion_64/poisson_64_e5_15_train_scale.npy'
   train_rand_idx_path: '/home/aistudio/data_efficient_nopt/data/possion_64/train_rand_idx.npy'
   batch_size: 128
-  log_to_wandb: !!bool True
+  log_to_wandb: !!bool False
   learning_rate: 1E-3

   mode_cut: 32
@@ -107,7 +107,7 @@ pois-64-pretrain-e1_20: &pois_64_e1_20_pt
   scales_path: '/home/aistudio/data_efficient_nopt/data/possion_64/poisson_64_e1_20_train_scale.npy'
   train_rand_idx_path: '/home/aistudio/data_efficient_nopt/data/possion_64/train_rand_idx.npy'
   batch_size: 128
-  log_to_wandb: !!bool True
+  log_to_wandb: !!bool False
   mode_cut: 32
   embed_cut: 64
   fc_cut: 2
@@ -128,7 +128,7 @@ pois-64-finetune-e5_15: &pois_64_e5_15_ft
   scales_path: '/home/aistudio/data_efficient_nopt/data/possion_64/poisson_64_e5_15_train_scale.npy'
   train_rand_idx_path: '/home/aistudio/data_efficient_nopt/data/possion_64/train_rand_idx.npy'
   batch_size: 128
-  log_to_wandb: !!bool True
+  log_to_wandb: !!bool False
   mode_cut: 32
   embed_cut: 64
   fc_cut: 2
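Note on the change above (not part of the diff): the entity, project, and group keys only matter when log_to_wandb is enabled, so flipping the default to False turns Weights & Biases logging off unless a run opts back in. A minimal, hypothetical sketch of the gating such a flag typically controls; the params object and the helper name are assumptions, not code from this repository:

import wandb

def maybe_init_wandb(params):
    # Hypothetical helper: only start a W&B run when the config opts in.
    if params.log_to_wandb:
        wandb.init(
            entity=params.entity,    # 'entity_name' in the config above
            project=params.project,  # 'proj_name'
            group=params.group,      # 'poisson'
        )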

examples/data_efficient_nopt/config/vmae_config_pretrain.yaml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 basic_config: &basic_config
   # Run settings
-  log_to_wandb: !!bool True # Use wandb integration
+  log_to_wandb: !!bool False # Use wandb integration
   log_to_screen: !!bool True # Log progress to screen.
   save_checkpoint: !!bool True # Save checkpoints
   checkpoint_save_interval: 100 # Save every # epochs - also saves "best" according to val loss
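As in the previous file, only the default value of log_to_wandb changes; the explicit !!bool tag simply forces the scalar to parse as a boolean. A quick standalone check with generic PyYAML (an assumption about the loader; the repository may read its configs differently):

import yaml

snippet = "log_to_wandb: !!bool False"
cfg = yaml.safe_load(snippet)           # the !!bool tag yields a real Python bool
print(cfg)                              # {'log_to_wandb': False}
print(type(cfg["log_to_wandb"]))        # <class 'bool'>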

examples/data_efficient_nopt/pretrain_basic.py

Lines changed: 15 additions & 7 deletions
@@ -91,7 +91,7 @@ def add_weight_decay(model, weight_decay=1e-5, inner_lr=1e-3, skip_list=()):
     decay = []
     no_decay = []
     for name, param in model.named_parameters():
-        if not param.requires_grad:
+        if param.stop_gradient:
             continue
         if len(param.squeeze().shape) <= 1 or name in skip_list:
             no_decay.append(param)
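The requires_grad to stop_gradient swap is the PyTorch-to-Paddle translation of the same check, with inverted polarity: in Paddle, stop_gradient=True marks a parameter as excluded from autograd, so "if param.stop_gradient: continue" skips frozen parameters exactly as "if not param.requires_grad: continue" did in torch. A minimal illustration (the Linear layer is only a stand-in, not the model from this example):

import paddle

layer = paddle.nn.Linear(4, 4)
weight = layer.weight
print(weight.stop_gradient)   # False: gradients are computed for trainable parameters
weight.stop_gradient = True   # freeze it, analogous to requires_grad = False in torch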
@@ -230,24 +230,32 @@ def initialize_optimizer(self, params):
         if params.optimizer == "adam":
             if self.params.learning_rate < 0:
                 self.optimizer = DAdaptAdam(
-                    parameters, lr=1.0, growth_rate=1.05, log_every=100, decouple=True
+                    parameters,
+                    learning_rate=1.0,
+                    growth_rate=1.05,
+                    log_every=100,
+                    decouple=True,
                 )
             else:
-                self.optimizer = optim.AdamW(parameters, lr=params.learning_rate)
+                self.optimizer = optim.AdamW(
+                    parameters=parameters, learning_rate=params.learning_rate
+                )
         elif params.optimizer == "adan":
             # if self.params.learning_rate < 0:
-            #     self.optimizer = DAdaptAdan(parameters, lr=1., growth_rate=1.05, log_every=100)
+            #     self.optimizer = DAdaptAdan(parameters, learning_rate=1., growth_rate=1.05, log_every=100)
             # else:
-            #     self.optimizer = Adan(parameters, lr=params.learning_rate)
+            #     self.optimizer = Adan(parameters, learning_rate=params.learning_rate)
             raise NotImplementedError("Adan not implemented yet")
         elif params.optimizer == "sgd":
             self.optimizer = optim.SGD(
-                self.model.parameters(), lr=params.learning_rate, momentum=0.9
+                parameters=self.model.parameters(),
+                learning_rate=params.learning_rate,
+                momentum=0.9,
             )
         else:
             raise ValueError(f"Optimizer {params.optimizer} not supported")
         self.gscaler = amp.GradScaler(
-            enabled=(self.mp_type == paddle.float16 and params.enable_amp)
+            enable=(self.mp_type == paddle.float16 and params.enable_amp)
         )

     def initialize_scheduler(self, params):
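The remaining renames follow Paddle's optimizer and AMP APIs, which take parameters= and learning_rate= keywords (rather than torch's positional params and lr) and spell the GradScaler switch enable rather than enabled. A small self-contained sketch of the same construction (the tiny Linear model is a placeholder, not the trainer's real network):

import paddle

model = paddle.nn.Linear(8, 8)                   # placeholder model
optimizer = paddle.optimizer.AdamW(
    parameters=model.parameters(),               # keyword is `parameters`
    learning_rate=1e-3,                          # keyword is `learning_rate`, not `lr`
)
scaler = paddle.amp.GradScaler(enable=True)      # Paddle's flag is `enable`, not `enabled`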
