Skip to content

Commit ce524c5

Browse files
fix config unitest
1 parent 26b0aa0 commit ce524c5

File tree

3 files changed

+38
-54
lines changed

3 files changed

+38
-54
lines changed

ppsci/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from ppsci.utils import config # isort:skip # noqa: F401
1516
from ppsci.utils import ema
1617
from ppsci.utils import initializer
1718
from ppsci.utils import logger

ppsci/utils/config.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121

2222
from typing_extensions import Literal
2323

24-
from ppsci.utils import misc
25-
2624
__all__ = []
2725

2826
if importlib.util.find_spec("pydantic") is not None:
@@ -314,7 +312,7 @@ def use_wandb_check(cls, v, info: ValidationInfo):
314312
if v and not isinstance(info.data["wandb_config"], dict):
315313
raise ValueError(
316314
"'wandb_config' should be a dict when 'use_wandb' is True, "
317-
f"but got {misc.typename(info.data['wandb_config'])}"
315+
f"but got {info.data['wandb_config'].__class__.__name__}"
318316
)
319317
return v
320318

test/utils/test_config.py

Lines changed: 36 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,17 @@
1-
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2-
3-
# Licensed under the Apache License, Version 2.0 (the "License");
4-
# you may not use this file except in compliance with the License.
5-
# You may obtain a copy of the License at
6-
7-
# http://www.apache.org/licenses/LICENSE-2.0
8-
9-
# Unless required by applicable law or agreed to in writing, software
10-
# distributed under the License is distributed on an "AS IS" BASIS,
11-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-
# See the License for the specific language governing permissions and
13-
# limitations under the License.
1+
import os
142

153
import hydra
164
import paddle
175
import pytest
18-
from omegaconf import DictConfig
6+
import yaml
197

20-
paddle.seed(1024)
8+
# 假设你的回调类在这个路径下
9+
from ppsci.utils.callbacks import InitCallback
2110

11+
# 设置 Paddle 的 seed
12+
paddle.seed(1024)
2213

14+
# 测试函数不需要装饰器
2315
@pytest.mark.parametrize(
2416
"epochs,mode,seed",
2517
[
@@ -28,42 +20,35 @@
2820
(10, "eval", -1),
2921
],
3022
)
31-
def test_invalid_epochs(
32-
epochs,
33-
mode,
34-
seed,
35-
):
36-
@hydra.main(version_base=None, config_path="./", config_name="test_config.yaml")
37-
def main(cfg: DictConfig):
38-
pass
39-
40-
# sys.exit will be called when validation error in pydantic, so there we use
41-
# SystemExit instead of other type of errors.
42-
with pytest.raises(SystemExit):
43-
cfg_dict = dict(
44-
{
45-
"TRAIN": {
46-
"epochs": epochs,
47-
},
48-
"mode": mode,
49-
"seed": seed,
50-
"hydra": {
51-
"callbacks": {
52-
"init_callback": {
53-
"_target_": "ppsci.utils.callbacks.InitCallback"
54-
}
55-
}
56-
},
23+
def test_invalid_epochs(tmpdir, epochs, mode, seed):
24+
cfg_dict = {
25+
"hydra": {
26+
"callbacks": {
27+
"init_callback": {"_target_": "ppsci.utils.callbacks.InitCallback"}
5728
}
58-
)
59-
# print(cfg_dict)
60-
import yaml
61-
62-
with open("test_config.yaml", "w") as f:
63-
yaml.dump(dict(cfg_dict), f)
64-
65-
main()
66-
67-
29+
},
30+
"mode": mode,
31+
"seed": seed,
32+
"TRAIN": {
33+
"epochs": epochs,
34+
},
35+
}
36+
# 创建一个临时的配置文件
37+
dir_ = os.path.dirname(__file__)
38+
config_abs_path = os.path.join(dir_, "test_config.yaml")
39+
with open(config_abs_path, "w") as f:
40+
f.write(yaml.dump(cfg_dict))
41+
42+
# 使用 hydra 的 compose API 来创建配置,而不是使用 main
43+
with hydra.initialize(config_path="./", version_base=None):
44+
cfg = hydra.compose(config_name="test_config.yaml")
45+
# 手动触发回调
46+
with pytest.raises(SystemExit) as exec_info:
47+
InitCallback().on_job_start(config=cfg)
48+
assert exec_info.value.code == 2
49+
# 你现在可以根据需要对 cfg 进行断言或进一步处理
50+
51+
52+
# 这部分通常不需要,除非你想直接从脚本运行测试
6853
if __name__ == "__main__":
6954
pytest.main()

0 commit comments

Comments
 (0)