|
8 | 8 | from dance.settings import DANCEDIR |
9 | 9 |
|
10 | 10 | sys.path.append(str(DANCEDIR)) |
11 | | -from examples.atlas.get_result_web import check_exist, check_identical_strings, spilt_web |
| 11 | +from examples.atlas.get_result_web import check_exist, check_identical_strings, spilt_web, write_ans |
12 | 12 |
|
13 | 13 |
|
14 | 14 | # 测试 check_identical_strings 函数 |
@@ -90,72 +90,62 @@ def mock_settings(tmp_path, monkeypatch): |
90 | 90 |
|
91 | 91 |
|
92 | 92 | def test_write_ans(mock_settings): |
93 | | - # 使用mock_settings而不是创建新的临时目录 |
94 | 93 | sweep_results_dir = mock_settings / "sweep_results" |
95 | 94 | sweep_results_dir.mkdir(parents=True) |
| 95 | + output_file = sweep_results_dir / "heart_ans.csv" |
96 | 96 |
|
97 | | - # 创建测试数据 |
| 97 | + # 创建初始数据 |
98 | 98 | existing_data = pd.DataFrame({ |
99 | | - 'Dataset_id': ['dataset1', 'dataset2', 'dataset3'], |
100 | | - 'method1': ['url1', 'url2', 'url3'], |
101 | | - 'method1_best_yaml': ['yaml1', 'yaml2', 'yaml3'], |
102 | | - 'method1_best_res': [0.8, 0.9, 0.7] |
| 99 | + 'Dataset_id': ['dataset1', 'dataset2'], |
| 100 | + 'cta_actinn': ['url1', 'url2'], |
| 101 | + 'cta_actinn_best_yaml': ['yaml1', 'yaml2'], |
| 102 | + 'cta_actinn_best_res': [0.8, 0.7] |
103 | 103 | }) |
| 104 | + existing_data.to_csv(output_file) |
104 | 105 |
|
| 106 | + # 测试数据:包含较低分数和较高分数的情况 |
105 | 107 | new_data = pd.DataFrame({ |
106 | | - 'Dataset_id': ['dataset2', 'dataset3', 'dataset4'], # 部分重叠的数据 |
107 | | - 'method1': ['url2_new', 'url3_new', 'url4'], |
108 | | - 'method1_best_yaml': ['yaml2_new', 'yaml3_new', 'yaml4'], |
109 | | - 'method1_best_res': [0.9, 0.7, 0.85] # dataset2和dataset3的结果与现有数据相同 |
| 108 | + 'Dataset_id': ['dataset1', 'dataset2'], |
| 109 | + 'cta_actinn': ['url1_new', 'url2_new'], |
| 110 | + 'cta_actinn_best_yaml': ['yaml1_new', 'yaml2_new'], |
| 111 | + 'cta_actinn_best_res': [0.9, 0.6] # dataset1更高分数,dataset2更低分数 |
110 | 112 | }) |
111 | 113 |
|
112 | | - # 写入现有数据 |
113 | | - output_file = sweep_results_dir / "heart_ans.csv" |
114 | | - existing_data.to_csv(output_file) |
115 | | - |
116 | | - # 测试写入新数据 |
117 | | - from examples.atlas.get_result_web import write_ans |
118 | 114 | write_ans("heart", new_data, output_file) |
119 | 115 |
|
120 | | - # 读取合并后的结果 |
121 | | - merged_df = pd.read_csv(output_file, index_col=0) |
122 | | - |
123 | 116 | # 验证结果 |
124 | | - assert len(merged_df) == 4 # 应该有4个唯一的Dataset_id |
125 | | - assert 'dataset4' in merged_df.index # 新数据被添加 |
126 | | - assert merged_df.loc['dataset2', 'method1'] == 'url2_new' # 更新了已存在的数据 |
127 | | - |
128 | | - # 测试结果冲突的情况 |
129 | | - conflicting_data = pd.DataFrame({ |
130 | | - 'Dataset_id': ['dataset1'], |
131 | | - 'method1': ['url1_new'], |
132 | | - 'method1_best_yaml': ['yaml1_new'], |
133 | | - 'method1_best_res': [0.95] # 不同的结果值 |
134 | | - }) |
| 117 | + result_df = pd.read_csv(output_file) |
| 118 | + |
| 119 | + # 验证高分数更新成功 |
| 120 | + dataset1_row = result_df[result_df['Dataset_id'] == 'dataset1'].iloc[0] |
| 121 | + assert dataset1_row['cta_actinn_best_res'] == 0.9 |
| 122 | + assert dataset1_row['cta_actinn'] == 'url1_new' |
| 123 | + assert dataset1_row['cta_actinn_best_yaml'] == 'yaml1_new' |
135 | 124 |
|
136 | | - # 验证冲突数据会引发异常 |
137 | | - with pytest.raises(ValueError, match="结果冲突"): |
138 | | - write_ans("heart", conflicting_data) |
| 125 | + # 验证低分数保持不变 |
| 126 | + dataset2_row = result_df[result_df['Dataset_id'] == 'dataset2'].iloc[0] |
| 127 | + assert dataset2_row['cta_actinn_best_res'] == 0.7 |
| 128 | + assert dataset2_row['cta_actinn'] == 'url2' |
| 129 | + assert dataset2_row['cta_actinn_best_yaml'] == 'yaml2' |
139 | 130 |
|
140 | 131 |
|
141 | 132 | # 测试完全新的数据写入(文件不存在的情况) |
142 | 133 | def test_write_ans_new_file(mock_settings): |
143 | 134 | # 使用mock_settings而不是创建新的临时目录 |
144 | 135 | sweep_results_dir = mock_settings / "sweep_results" |
145 | 136 | sweep_results_dir.mkdir(parents=True) |
| 137 | + output_file = sweep_results_dir / "new_heart_ans.csv" |
146 | 138 |
|
147 | 139 | new_data = pd.DataFrame({ |
148 | 140 | 'Dataset_id': ['dataset1', 'dataset2'], |
149 | | - 'method1': ['url1', 'url2'], |
150 | | - 'method1_best_yaml': ['yaml1', 'yaml2'], |
151 | | - 'method1_best_res': [0.8, 0.9] |
| 141 | + 'cta_actinn': ['url1', 'url2'], |
| 142 | + 'cta_actinn_best_yaml': ['yaml1', 'yaml2'], |
| 143 | + 'cta_actinn_best_res': [0.8, 0.9] |
152 | 144 | }) |
153 | 145 |
|
154 | 146 | # 测试写入新文件 |
155 | | - from examples.atlas.get_result_web import write_ans |
156 | 147 |
|
157 | 148 | # 验证文件被创建并包含正确的数据 |
158 | | - output_file = sweep_results_dir / "heart_ans.csv" |
159 | 149 | write_ans("heart", new_data, output_file) |
160 | 150 | assert output_file.exists() |
161 | 151 |
|
|
0 commit comments