Skip to content

Commit 14f572e

Browse files
remove redundant seed fixing code and logger init code, and remove this behaviour to yaml
1 parent c30b6e0 commit 14f572e

File tree

91 files changed

+132
-811
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

91 files changed

+132
-811
lines changed

examples/RegAE/RegAE.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,23 +14,15 @@
1414

1515
from __future__ import annotations
1616

17-
from os import path as osp
18-
1917
import hydra
2018
import paddle
2119
from omegaconf import DictConfig
2220
from paddle.nn import functional as F
2321

2422
import ppsci
25-
from ppsci.utils import logger
2623

2724

2825
def train(cfg: DictConfig):
29-
# set random seed for reproducibility
30-
ppsci.utils.misc.set_random_seed(cfg.seed)
31-
# initialize logger
32-
logger.init_logger("ppsci", osp.join(cfg.output_dir, "train.log"), "info")
33-
3426
# set model
3527
model = ppsci.arch.AutoEncoder(**cfg.MODEL)
3628

@@ -114,11 +106,6 @@ def loss_expr(output_dict, label_dict, weight_dict=None):
114106

115107

116108
def evaluate(cfg: DictConfig):
117-
# set random seed for reproducibility
118-
ppsci.utils.misc.set_random_seed(cfg.seed)
119-
# initialize logger
120-
logger.init_logger("ppsci", osp.join(cfg.output_dir, "eval.log"), "info")
121-
122109
# set model
123110
model = ppsci.arch.AutoEncoder(**cfg.MODEL)
124111

examples/RegAE/conf/RegAE.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ hydra:
1414
- mode
1515
- output_dir
1616
- log_freq
17+
callbacks:
18+
init_callback:
19+
_target_: ppsci.utils.callbacks.InitCallback
1720
sweep:
1821
# output directory for multirun
1922
dir: ${hydra.run.dir}

examples/amgnet/amgnet_airfoil.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
from __future__ import annotations
1616

17-
from os import path as osp
1817
from typing import TYPE_CHECKING
1918
from typing import Dict
2019
from typing import List
@@ -53,11 +52,6 @@ def eval_rmse_func(
5352

5453

5554
def train(cfg: DictConfig):
56-
# set random seed for reproducibility
57-
ppsci.utils.misc.set_random_seed(cfg.seed)
58-
# initialize logger
59-
logger.init_logger("ppsci", osp.join(cfg.output_dir, "train.log"), "info")
60-
6155
# set airfoil model
6256
model = ppsci.arch.AMGNet(**cfg.MODEL)
6357

@@ -76,7 +70,6 @@ def train(cfg: DictConfig):
7670
"drop_last": False,
7771
"shuffle": True,
7872
},
79-
"num_workers": 1,
8073
}
8174

8275
# set constraint
@@ -102,11 +95,6 @@ def train(cfg: DictConfig):
10295
"mesh_graph_path": cfg.EVAL_MESH_GRAPH_PATH,
10396
},
10497
"batch_size": cfg.EVAL.batch_size,
105-
"sampler": {
106-
"name": "BatchSampler",
107-
"drop_last": False,
108-
"shuffle": False,
109-
},
11098
}
11199
rmse_validator = ppsci.validate.SupervisedValidator(
112100
eval_dataloader_cfg,
@@ -152,11 +140,6 @@ def train(cfg: DictConfig):
152140

153141

154142
def evaluate(cfg: DictConfig):
155-
# set random seed for reproducibility
156-
ppsci.utils.misc.set_random_seed(cfg.seed)
157-
# initialize logger
158-
logger.init_logger("ppsci", osp.join(cfg.output_dir, "eval.log"), "info")
159-
160143
# set airfoil model
161144
model = ppsci.arch.AMGNet(**cfg.MODEL)
162145

@@ -170,11 +153,6 @@ def evaluate(cfg: DictConfig):
170153
"mesh_graph_path": cfg.EVAL_MESH_GRAPH_PATH,
171154
},
172155
"batch_size": cfg.EVAL.batch_size,
173-
"sampler": {
174-
"name": "BatchSampler",
175-
"drop_last": False,
176-
"shuffle": False,
177-
},
178156
}
179157
rmse_validator = ppsci.validate.SupervisedValidator(
180158
eval_dataloader_cfg,

examples/amgnet/amgnet_cylinder.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
from __future__ import annotations
1616

17-
from os import path as osp
1817
from typing import TYPE_CHECKING
1918
from typing import Dict
2019
from typing import List
@@ -53,11 +52,6 @@ def eval_rmse_func(
5352

5453

5554
def train(cfg: DictConfig):
56-
# set random seed for reproducibility
57-
ppsci.utils.misc.set_random_seed(cfg.seed)
58-
# initialize logger
59-
logger.init_logger("ppsci", osp.join(cfg.output_dir, "train.log"), "info")
60-
6155
# set cylinder model
6256
model = ppsci.arch.AMGNet(**cfg.MODEL)
6357

@@ -76,7 +70,6 @@ def train(cfg: DictConfig):
7670
"drop_last": False,
7771
"shuffle": True,
7872
},
79-
"num_workers": 1,
8073
}
8174

8275
# set constraint
@@ -102,11 +95,6 @@ def train(cfg: DictConfig):
10295
"mesh_graph_path": cfg.EVAL_MESH_GRAPH_PATH,
10396
},
10497
"batch_size": cfg.EVAL.batch_size,
105-
"sampler": {
106-
"name": "BatchSampler",
107-
"drop_last": False,
108-
"shuffle": False,
109-
},
11098
}
11199
rmse_validator = ppsci.validate.SupervisedValidator(
112100
eval_dataloader_cfg,
@@ -152,11 +140,6 @@ def train(cfg: DictConfig):
152140

153141

154142
def evaluate(cfg: DictConfig):
155-
# set random seed for reproducibility
156-
ppsci.utils.misc.set_random_seed(cfg.seed)
157-
# initialize logger
158-
logger.init_logger("ppsci", osp.join(cfg.output_dir, "eval.log"), "info")
159-
160143
# set airfoil model
161144
model = ppsci.arch.AMGNet(**cfg.MODEL)
162145

@@ -170,11 +153,6 @@ def evaluate(cfg: DictConfig):
170153
"mesh_graph_path": cfg.EVAL_MESH_GRAPH_PATH,
171154
},
172155
"batch_size": cfg.EVAL.batch_size,
173-
"sampler": {
174-
"name": "BatchSampler",
175-
"drop_last": False,
176-
"shuffle": False,
177-
},
178156
}
179157
rmse_validator = ppsci.validate.SupervisedValidator(
180158
eval_dataloader_cfg,

examples/amgnet/conf/amgnet_airfoil.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ hydra:
1414
- mode
1515
- output_dir
1616
- log_freq
17+
callbacks:
18+
init_callback:
19+
_target_: ppsci.utils.callbacks.InitCallback
1720
sweep:
1821
# output directory for multirun
1922
dir: ${hydra.run.dir}

examples/amgnet/conf/amgnet_cylinder.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ hydra:
1414
- mode
1515
- output_dir
1616
- log_freq
17+
callbacks:
18+
init_callback:
19+
_target_: ppsci.utils.callbacks.InitCallback
1720
sweep:
1821
# output directory for multirun
1922
dir: ${hydra.run.dir}

examples/biharmonic2d/biharmonic2d.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,6 @@ def plotting(figname, output_dir, data, griddata_points, griddata_xi, boundary):
6969

7070

7171
def train(cfg: DictConfig):
72-
# set random seed for reproducibility
73-
ppsci.utils.misc.set_random_seed(cfg.seed)
74-
# initialize logger
75-
logger.init_logger("ppsci", osp.join(cfg.output_dir, f"{cfg.mode}.log"), "info")
76-
7772
# set models
7873
disp_net = ppsci.arch.MLP(**cfg.MODEL)
7974

@@ -268,11 +263,6 @@ def train(cfg: DictConfig):
268263

269264

270265
def evaluate(cfg: DictConfig):
271-
# set random seed for reproducibility
272-
ppsci.utils.misc.set_random_seed(cfg.seed)
273-
# initialize logger
274-
logger.init_logger("ppsci", osp.join(cfg.output_dir, f"{cfg.mode}.log"), "info")
275-
276266
# set models
277267
disp_net = ppsci.arch.MLP(**cfg.MODEL)
278268

examples/biharmonic2d/conf/biharmonic2d.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ hydra:
1414
- mode
1515
- output_dir
1616
- log_freq
17+
callbacks:
18+
init_callback:
19+
_target_: ppsci.utils.callbacks.InitCallback
1720
sweep:
1821
# output directory for multirun
1922
dir: ${hydra.run.dir}

examples/bracket/bracket.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -277,11 +277,6 @@ def train(cfg: DictConfig):
277277
"input": input_dict,
278278
"label": label_dict,
279279
},
280-
"sampler": {
281-
"name": "BatchSampler",
282-
"drop_last": False,
283-
"shuffle": False,
284-
},
285280
}
286281
sup_validator = ppsci.validate.SupervisedValidator(
287282
{**eval_dataloader_cfg, "batch_size": cfg.EVAL.batch_size.sup_validator},

examples/bubble/bubble.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,6 @@
3232

3333

3434
def train(cfg: DictConfig):
35-
# set random seed for reproducibility
36-
ppsci.utils.misc.set_random_seed(cfg.seed)
37-
# initialize logger
38-
logger.init_logger("ppsci", osp.join(cfg.output_dir, f"{cfg.mode}.log"), "info")
39-
4035
# load Data
4136
data = scipy.io.loadmat(cfg.DATA_PATH)
4237
# normalize data
@@ -171,11 +166,6 @@ def transform_out(in_, out):
171166
"label": test_label,
172167
},
173168
"batch_size": cfg.TRAIN.batch_size.mse_validator,
174-
"sampler": {
175-
"name": "BatchSampler",
176-
"drop_last": False,
177-
"shuffle": False,
178-
},
179169
},
180170
ppsci.loss.MSELoss("mean"),
181171
metric={"MSE": ppsci.metric.MSE()},
@@ -249,11 +239,6 @@ def transform_out(in_, out):
249239

250240

251241
def evaluate(cfg: DictConfig):
252-
# set random seed for reproducibility
253-
ppsci.utils.misc.set_random_seed(cfg.seed)
254-
# initialize logger
255-
logger.init_logger("ppsci", osp.join(cfg.output_dir, f"{cfg.mode}.log"), "info")
256-
257242
# load Data
258243
data = scipy.io.loadmat(cfg.DATA_PATH)
259244
# normalize data
@@ -343,11 +328,6 @@ def transform_out(in_, out):
343328
"label": test_label,
344329
},
345330
"batch_size": cfg.TRAIN.batch_size.mse_validator,
346-
"sampler": {
347-
"name": "BatchSampler",
348-
"drop_last": False,
349-
"shuffle": False,
350-
},
351331
},
352332
ppsci.loss.MSELoss("mean"),
353333
metric={"MSE": ppsci.metric.MSE()},

0 commit comments

Comments
 (0)