@@ -46,10 +46,31 @@ def _subprocess_run(command):
4646# - type embedding FP32, se_atten FP64
4747# - type embedding FP32, se_atten FP32
4848tests = [
49- {"se_atten precision" : "float64" , "type embedding precision" : "float64" },
50- {"se_atten precision" : "float64" , "type embedding precision" : "float32" },
51- {"se_atten precision" : "float32" , "type embedding precision" : "float64" },
52- {"se_atten precision" : "float32" , "type embedding precision" : "float32" },
49+ {
50+ "se_atten precision" : "float64" ,
51+ "type embedding precision" : "float64" ,
52+ "smooth_type_embdding" : True ,
53+ },
54+ {
55+ "se_atten precision" : "float64" ,
56+ "type embedding precision" : "float64" ,
57+ "smooth_type_embdding" : False ,
58+ },
59+ {
60+ "se_atten precision" : "float64" ,
61+ "type embedding precision" : "float32" ,
62+ "smooth_type_embdding" : True ,
63+ },
64+ {
65+ "se_atten precision" : "float32" ,
66+ "type embedding precision" : "float64" ,
67+ "smooth_type_embdding" : True ,
68+ },
69+ {
70+ "se_atten precision" : "float32" ,
71+ "type embedding precision" : "float32" ,
72+ "smooth_type_embdding" : True ,
73+ },
5374]
5475
5576
@@ -73,6 +94,9 @@ def _init_models():
7394 jdata ["model" ]["descriptor" ]["stripped_type_embedding" ] = True
7495 jdata ["model" ]["descriptor" ]["sel" ] = 120
7596 jdata ["model" ]["descriptor" ]["attn_layer" ] = 0
97+ jdata ["model" ]["descriptor" ]["smooth_type_embdding" ] = tests [i ][
98+ "smooth_type_embdding"
99+ ]
76100 jdata ["model" ]["type_embedding" ] = {}
77101 jdata ["model" ]["type_embedding" ]["precision" ] = tests [i ][
78102 "type embedding precision"
@@ -479,9 +503,15 @@ def test_1frame(self):
479503 self .assertEqual (ff1 .shape , (nframes , natoms , 3 ))
480504 self .assertEqual (vv1 .shape , (nframes , 9 ))
481505 # check values
482- np .testing .assert_almost_equal (ff0 , ff1 , default_places )
483- np .testing .assert_almost_equal (ee0 , ee1 , default_places )
484- np .testing .assert_almost_equal (vv0 , vv1 , default_places )
506+ np .testing .assert_almost_equal (
507+ ff0 , ff1 , default_places , err_msg = str (tests [i ])
508+ )
509+ np .testing .assert_almost_equal (
510+ ee0 , ee1 , default_places , err_msg = str (tests [i ])
511+ )
512+ np .testing .assert_almost_equal (
513+ vv0 , vv1 , default_places , err_msg = str (tests [i ])
514+ )
485515
486516 def test_1frame_atm (self ):
487517 for i in range (len (tests )):
0 commit comments