Skip to content

Commit 26b0aa0

Browse files
restore some config
1 parent f2f6b81 commit 26b0aa0

File tree

6 files changed

+31
-12
lines changed

6 files changed

+31
-12
lines changed

examples/NLS-MB/NLS-MB_optical_rogue_wave.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,6 @@ def train(cfg: DictConfig):
261261
eval_freq=cfg.TRAIN.lbfgs.eval_freq,
262262
equation=equation,
263263
validator=validator,
264-
cfg=cfg,
265264
)
266265
# train model
267266
solver.train()

examples/NLS-MB/NLS-MB_optical_soliton.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,6 @@ def train(cfg: DictConfig):
238238
eval_freq=cfg.TRAIN.lbfgs.eval_freq,
239239
equation=equation,
240240
validator=validator,
241-
cfg=cfg,
242241
)
243242
# train model
244243
solver.train()

examples/aneurysm/aneurysm_flow.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,10 +196,15 @@ def output_transform_p(self, in_, out):
196196
solver = ppsci.solver.Solver(
197197
model,
198198
constraint,
199-
optimizer=optimizer,
199+
cfg.output_dir,
200+
optimizer,
201+
log_freq=cfg.log_freq,
202+
epochs=cfg.TRAIN.epochs,
200203
iters_per_epoch=int(x.shape[0] / cfg.TRAIN.batch_size),
204+
save_freq=cfg.save_freq,
201205
equation=equation,
202-
cfg=cfg,
206+
pretrained_model_path=cfg.TRAIN.pretrained_model_path,
207+
checkpoint_path=cfg.TRAIN.checkpoint_path,
203208
)
204209
solver.train()
205210

examples/fourcastnet/train_finetune.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,10 +162,15 @@ def train(cfg: DictConfig):
162162
solver = ppsci.solver.Solver(
163163
model,
164164
constraint,
165-
optimizer=optimizer,
165+
cfg.output_dir,
166+
optimizer,
167+
epochs=cfg.TRAIN.epochs,
166168
iters_per_epoch=ITERS_PER_EPOCH,
169+
eval_during_train=True,
167170
validator=validator,
168-
cfg=cfg,
171+
pretrained_model_path=cfg.TRAIN.pretrained_model_path,
172+
compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
173+
eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
169174
)
170175
# train model
171176
solver.train()

examples/fourcastnet/train_precip.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,10 +159,15 @@ def train(cfg: DictConfig):
159159
solver = ppsci.solver.Solver(
160160
model,
161161
constraint,
162-
iters_per_epoch=ITERS_PER_EPOCH,
163-
optimizer=optimizer,
162+
cfg.output_dir,
163+
optimizer,
164+
lr_scheduler,
165+
cfg.TRAIN.epochs,
166+
ITERS_PER_EPOCH,
167+
eval_during_train=True,
164168
validator=validator,
165-
cfg=cfg,
169+
compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
170+
eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
166171
)
167172
# train model
168173
solver.train()

examples/fourcastnet/train_pretrain.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,16 @@ def train(cfg: DictConfig):
152152
solver = ppsci.solver.Solver(
153153
model,
154154
constraint,
155-
iters_per_epoch=ITERS_PER_EPOCH,
156-
optimizer=optimizer,
155+
cfg.output_dir,
156+
optimizer,
157+
lr_scheduler,
158+
cfg.TRAIN.epochs,
159+
ITERS_PER_EPOCH,
160+
eval_during_train=True,
161+
seed=cfg.seed,
157162
validator=validator,
158-
cfg=cfg,
163+
compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
164+
eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
159165
)
160166
# train model
161167
solver.train()

0 commit comments

Comments
 (0)