Skip to content

Commit fc09e77

Browse files
wanghan-iapcmHan Wangpre-commit-ci[bot]
authored
unittest for the compression of smooth se_atten descriptor (#2916)
Co-authored-by: Han Wang <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 8bc4e3f commit fc09e77

File tree

2 files changed

+39
-7
lines changed

2 files changed

+39
-7
lines changed

source/tests/test_model_compression_se_atten.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,31 @@ def _subprocess_run(command):
4646
# - type embedding FP32, se_atten FP64
4747
# - type embedding FP32, se_atten FP32
4848
tests = [
49-
{"se_atten precision": "float64", "type embedding precision": "float64"},
50-
{"se_atten precision": "float64", "type embedding precision": "float32"},
51-
{"se_atten precision": "float32", "type embedding precision": "float64"},
52-
{"se_atten precision": "float32", "type embedding precision": "float32"},
49+
{
50+
"se_atten precision": "float64",
51+
"type embedding precision": "float64",
52+
"smooth_type_embdding": True,
53+
},
54+
{
55+
"se_atten precision": "float64",
56+
"type embedding precision": "float64",
57+
"smooth_type_embdding": False,
58+
},
59+
{
60+
"se_atten precision": "float64",
61+
"type embedding precision": "float32",
62+
"smooth_type_embdding": True,
63+
},
64+
{
65+
"se_atten precision": "float32",
66+
"type embedding precision": "float64",
67+
"smooth_type_embdding": True,
68+
},
69+
{
70+
"se_atten precision": "float32",
71+
"type embedding precision": "float32",
72+
"smooth_type_embdding": True,
73+
},
5374
]
5475

5576

@@ -73,6 +94,9 @@ def _init_models():
7394
jdata["model"]["descriptor"]["stripped_type_embedding"] = True
7495
jdata["model"]["descriptor"]["sel"] = 120
7596
jdata["model"]["descriptor"]["attn_layer"] = 0
97+
jdata["model"]["descriptor"]["smooth_type_embdding"] = tests[i][
98+
"smooth_type_embdding"
99+
]
76100
jdata["model"]["type_embedding"] = {}
77101
jdata["model"]["type_embedding"]["precision"] = tests[i][
78102
"type embedding precision"
@@ -479,9 +503,15 @@ def test_1frame(self):
479503
self.assertEqual(ff1.shape, (nframes, natoms, 3))
480504
self.assertEqual(vv1.shape, (nframes, 9))
481505
# check values
482-
np.testing.assert_almost_equal(ff0, ff1, default_places)
483-
np.testing.assert_almost_equal(ee0, ee1, default_places)
484-
np.testing.assert_almost_equal(vv0, vv1, default_places)
506+
np.testing.assert_almost_equal(
507+
ff0, ff1, default_places, err_msg=str(tests[i])
508+
)
509+
np.testing.assert_almost_equal(
510+
ee0, ee1, default_places, err_msg=str(tests[i])
511+
)
512+
np.testing.assert_almost_equal(
513+
vv0, vv1, default_places, err_msg=str(tests[i])
514+
)
485515

486516
def test_1frame_atm(self):
487517
for i in range(len(tests)):

source/tests/test_model_se_atten.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -753,6 +753,8 @@ def test_smoothness_of_stripped_type_embedding_smooth_model(self):
753753
jdata["model"]["descriptor"]["stripped_type_embedding"] = True
754754
jdata["model"]["descriptor"]["smooth_type_embdding"] = True
755755
jdata["model"]["descriptor"]["attn_layer"] = 1
756+
jdata["model"]["descriptor"]["rcut"] = 6.0
757+
jdata["model"]["descriptor"]["rcut_smth"] = 4.0
756758
descrpt = DescrptSeAtten(**jdata["model"]["descriptor"], uniform_seed=True)
757759
jdata["model"]["fitting_net"]["descrpt"] = descrpt
758760
fitting = EnerFitting(**jdata["model"]["fitting_net"], uniform_seed=True)

0 commit comments

Comments
 (0)