Skip to content

Commit aab124f

Browse files
authored
Fix a potential slice bug in se_t descriptor (#1087)
* fix a potential slice bug in se_t * fix UT error * address comments
1 parent c824ff6 commit aab124f

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

deepmd/descriptor/se_t.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,7 @@ def _filter(self,
469469
inputs_i = tf.slice (inputs,
470470
[ 0, start_index_i *4],
471471
[-1, self.sel_a[type_i] *4] )
472+
start_index_j = start_index_i
472473
start_index_i += self.sel_a[type_i]
473474
nei_type_i = self.sel_a[type_i]
474475
shape_i = inputs_i.get_shape().as_list()
@@ -477,7 +478,6 @@ def _filter(self,
477478
env_i = tf.reshape(inputs_i, [-1, nei_type_i, 4])
478479
# with natom x nei_type_i x 3
479480
env_i = tf.slice(env_i, [0, 0, 1], [-1, -1, -1])
480-
start_index_j = 0
481481
for type_j in range(type_i, self.ntypes):
482482
# with natom x (nei_type_j x 4)
483483
inputs_j = tf.slice (inputs,

source/tests/test_model_se_t.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,9 @@ def test_model(self):
100100
np.savetxt('e.out', e.reshape([1, -1]))
101101
np.savetxt('f.out', f.reshape([1, -1]), delimiter = ',')
102102
np.savetxt('v.out', v.reshape([1, -1]), delimiter = ',')
103-
refe = [4.826771866004193612e+01]
104-
reff = [5.355088169393570574e+00,5.606772412401632266e+00,2.703270748296462966e-01,5.381408138049708967e+00,5.261355614357515975e+00,-4.079549918988090162e-01,-5.182324474551911919e+00,3.695481388907447262e-01,-5.238474288082559799e-02,1.665564584447352670e-01,-5.955401876564963892e+00,-2.217626865156164251e-01,-5.967343479332643419e+00,9.073821102416884665e-02,3.703103995504785639e-01,2.466151879965444438e-01,-5.373012500109097367e+00,4.146494691512622732e-02]
105-
refv = [-1.336768232407933077e+01,4.818050125305787801e-01,3.589284283410607568e-01,4.818050125305786691e-01,-1.225345559839458964e+01,-1.701405121682751653e-01,3.589284283410607568e-01,-1.701405121682752486e-01,-3.428455515842296353e-02]
103+
refe = [4.8436558582194039e+01]
104+
reff = [5.2896335066946598e+00,5.5778402259211131e+00,2.6839994229557251e-01,5.3528786387686784e+00,5.2477755362164968e+00,-4.0486366542657343e-01,-5.1297084055340498e+00,3.4607112287117253e-01,-5.1800783428369482e-02,1.5557068351407846e-01,-5.9071343228741506e+00,-2.2012359669589748e-01,-5.9156735320857488e+00,8.8397615509389127e-02,3.6701215949753935e-01,2.4729910864238122e-01,-5.3529501776440211e+00,4.1375943757728552e-02]
105+
refv = [-1.3159448660141607e+01,4.6952048725161544e-01,3.5482003698976106e-01,4.6952048725161577e-01,-1.2178990983673918e+01,-1.6867277410496895e-01,3.5482003698976106e-01,-1.6867277410496900e-01,-3.3986741457321945e-02]
106106
refe = np.reshape(refe, [-1])
107107
reff = np.reshape(reff, [-1])
108108
refv = np.reshape(refv, [-1])

0 commit comments

Comments
 (0)