Skip to content

Commit d4e806f

Browse files
committed
fix: UTs
1 parent 03f5392 commit d4e806f

File tree

3 files changed

+72
-73
lines changed

3 files changed

+72
-73
lines changed

tests/metrics/conftest.py

Lines changed: 69 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,23 @@
88
DirectPredictRecord(
99
id=1,
1010
model_name="test_dp",
11-
task_name="ANI",
11+
task_name="AQM",
12+
create_time=None,
13+
energy_rmse=0.467693,
14+
energy_mae=0.340148,
15+
energy_rmse_natoms=0.0376691,
16+
energy_mae_natoms=0.0227907,
17+
force_rmse=0.437947,
18+
force_mae=0.253533,
19+
virial_rmse=None,
20+
virial_mae=None,
21+
virial_rmse_natoms=None,
22+
virial_mae_natoms=None,
23+
),
24+
DirectPredictRecord(
25+
id=2,
26+
model_name="test_dp",
27+
task_name="H_nature_2022",
1228
create_time=None,
1329
energy_rmse=0.467693,
1430
energy_mae=0.340148,
@@ -54,7 +70,7 @@
5470
virial_mae_natoms=1.47443,
5571
),
5672
DirectPredictRecord(
57-
id=3,
73+
id=4,
5874
model_name="test_dp",
5975
task_name="MoS2",
6076
create_time=None,
@@ -68,11 +84,10 @@
6884
virial_mae=1.52447,
6985
virial_rmse_natoms=0.109107,
7086
virial_mae_natoms=0.0660086,
71-
),
72-
DirectPredictRecord(
87+
),DirectPredictRecord(
7388
id=5,
7489
model_name="test_dp",
75-
task_name="MD22",
90+
task_name="AIMD_Chig",
7691
create_time=None,
7792
energy_rmse=0.13579,
7893
energy_mae=0.10127,
@@ -85,6 +100,22 @@
85100
virial_rmse_natoms=None,
86101
virial_mae_natoms=None,
87102
),
103+
DirectPredictRecord(
104+
id=7,
105+
model_name="test_dp",
106+
task_name="Carbon_growth",
107+
create_time=None,
108+
energy_rmse=17.596,
109+
energy_mae=14.1638,
110+
energy_rmse_natoms=0.451179,
111+
energy_mae_natoms=0.363173,
112+
force_rmse=0.221132,
113+
force_mae=0.133015,
114+
virial_rmse=2.74979,
115+
virial_mae=1.52447,
116+
virial_rmse_natoms=0.109107,
117+
virial_mae_natoms=0.0660086,
118+
),
88119
DirectPredictRecord(
89120
id=8,
90121
model_name="test_dp",
@@ -133,22 +164,38 @@
133164
virial_rmse_natoms=None,
134165
virial_mae_natoms=None,
135166
),
136-
# DirectPredictRecord(
137-
# id=12,
138-
# model_name="test_dp",
139-
# task_name="Cu_MgO_catalysts",
140-
# create_time=None,
141-
# energy_rmse=0.267982,
142-
# energy_mae=0.153377,
143-
# energy_rmse_natoms=0.0035446,
144-
# energy_mae_natoms=0.00229624,
145-
# force_rmse=0.0584197,
146-
# force_mae=0.038047,
147-
# virial_rmse=None,
148-
# virial_mae=None,
149-
# virial_rmse_natoms=None,
150-
# virial_mae_natoms=None,
151-
# ),
167+
DirectPredictRecord(
168+
id=11,
169+
model_name="test_dp",
170+
task_name="CompressBi",
171+
create_time=None,
172+
energy_rmse=0.267982,
173+
energy_mae=0.153377,
174+
energy_rmse_natoms=0.0035446,
175+
energy_mae_natoms=0.00229624,
176+
force_rmse=0.0584197,
177+
force_mae=0.038047,
178+
virial_rmse=2.74979,
179+
virial_mae=1.52447,
180+
virial_rmse_natoms=0.109107,
181+
virial_mae_natoms=0.0660086,
182+
),
183+
DirectPredictRecord(
184+
id=12,
185+
model_name="test_dp",
186+
task_name="In2O3_CO2",
187+
create_time=None,
188+
energy_rmse=0.267982,
189+
energy_mae=0.153377,
190+
energy_rmse_natoms=0.0035446,
191+
energy_mae_natoms=0.00229624,
192+
force_rmse=0.0584197,
193+
force_mae=0.038047,
194+
virial_rmse=2.74979,
195+
virial_mae=1.52447,
196+
virial_rmse_natoms=0.109107,
197+
virial_mae_natoms=0.0660086,
198+
),
152199
DirectPredictRecord(
153200
id=13,
154201
model_name="test_dp",
@@ -197,55 +244,7 @@
197244
virial_rmse_natoms=None,
198245
virial_mae_natoms=None,
199246
),
200-
## Deprecated
201-
# DirectPredictRecord(
202-
# id=4,
203-
# model_name="test_dp",
204-
# task_name="HEMC_HEMB",
205-
# create_time=None,
206-
# energy_rmse=7.87692,
207-
# energy_mae=4.38965,
208-
# energy_rmse_natoms=0.154871,
209-
# energy_mae_natoms=0.0900861,
210-
# force_rmse=0.19703,
211-
# force_mae=0.121378,
212-
# virial_rmse=6.12989,
213-
# virial_mae=2.92621,
214-
# virial_rmse_natoms=0.127266,
215-
# virial_mae_natoms=0.0619396,
216-
# ),
217-
# DirectPredictRecord(
218-
# id=16,
219-
# model_name="test_dp",
220-
# task_name="WBM_downsampled",
221-
# create_time=None,
222-
# energy_rmse=0.194829,
223-
# energy_mae=0.0604359,
224-
# energy_rmse_natoms=0.0318466,
225-
# energy_mae_natoms=0.00875549,
226-
# force_rmse=None,
227-
# force_mae=None,
228-
# virial_rmse=None,
229-
# virial_mae=None,
230-
# virial_rmse_natoms=None,
231-
# virial_mae_natoms=None,
232-
# ),
233-
# DirectPredictRecord(
234-
# id=17,
235-
# model_name="test_dp",
236-
# task_name="Subalex_9k",
237-
# create_time=None,
238-
# energy_rmse=1.90841,
239-
# energy_mae=0.268596,
240-
# energy_rmse_natoms=0.234027,
241-
# energy_mae_natoms=0.0286509,
242-
# force_rmse=0.624174,
243-
# force_mae=0.0437039,
244-
# virial_rmse=4.16581,
245-
# virial_mae=0.373998,
246-
# virial_rmse_natoms=0.382751,
247-
# virial_mae_natoms=0.0371473,
248-
# ),
247+
249248
]
250249

251250
RECORDS_CALCULATOR = [

tests/metrics/test_post_process.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def test_process_results_for_one_model(
2828
in caplog.text
2929
)
3030
assert result["generalizability_force_field_results"]["Weighted"] is None
31-
assert result["generalizability_force_field_results"]["ANI"]["energy_rmse"] == 467.7
31+
assert result["generalizability_force_field_results"]["MoS2"]["energy_rmse"] == 500.8
3232

3333
# Find differences between the calculator tasks and results
3434
calculator_task_differences = (

tests/metrics/test_visualization.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ def test_aggregate_ood_results_for_one_model(
1414
model.show_calculator_task = False
1515
aggregator = ResultsFetcher()
1616
result = aggregator.aggregate_ood_results_for_one_model(model=model)
17-
np.testing.assert_almost_equal(result["Molecules"], desired=0.22748765, decimal=5)
18-
np.testing.assert_almost_equal(result["Inorganic Materials"], 0.2972349, decimal=5)
17+
np.testing.assert_almost_equal(result["Molecules"], desired=0.28470115, decimal=5)
18+
np.testing.assert_almost_equal(result["Inorganic Materials"],desired=0.24101483, decimal=5)
1919
assert result["Catalysis"] is None
2020
with caplog.at_level(logging.WARNING):
2121
assert (

0 commit comments

Comments
 (0)