Skip to content

Commit 9dcc816

Browse files
authored
Fix shape mismatch when type_embedding is enabled and type_one_side i… (#1074)
* Fix shape mismatch when type_embedding is enabled and type_one_side is disabled. * Add UT to cover the case where batch_size is larger than 1. * Fix random issue in unit tests.
1 parent 03a05bd commit 9dcc816

File tree

3 files changed

+40
-21
lines changed

3 files changed

+40
-21
lines changed

deepmd/descriptor/se_a.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -667,17 +667,36 @@ def _concat_type_embedding(
667667
nframes,
668668
natoms,
669669
type_embedding,
670-
):
670+
):
671+
'''Concatenate `type_embedding` of neighbors and `xyz_scatter`.
672+
If not self.type_one_side, concatenate `type_embedding` of center atoms as well.
673+
674+
Parameters
675+
----------
676+
xyz_scatter:
677+
shape is [nframes*natoms[0]*self.nnei, 1]
678+
nframes:
679+
shape is []
680+
natoms:
681+
shape is [1+1+self.ntypes]
682+
type_embedding:
683+
shape is [self.ntypes, Y] where Y=jdata['type_embedding']['neuron'][-1]
684+
685+
Returns
686+
-------
687+
embedding:
688+
environment of each atom represented by embedding.
689+
'''
671690
te_out_dim = type_embedding.get_shape().as_list()[-1]
672-
nei_embed = tf.nn.embedding_lookup(type_embedding,tf.cast(self.nei_type,dtype=tf.int32)) #nnei*nchnl
673-
nei_embed = tf.tile(nei_embed,(nframes*natoms[0],1))
691+
nei_embed = tf.nn.embedding_lookup(type_embedding,tf.cast(self.nei_type,dtype=tf.int32)) # shape is [self.nnei, 1+te_out_dim]
692+
nei_embed = tf.tile(nei_embed,(nframes*natoms[0],1)) # shape is [nframes*natoms[0]*self.nnei, te_out_dim]
674693
nei_embed = tf.reshape(nei_embed,[-1,te_out_dim])
675-
embedding_input = tf.concat([xyz_scatter,nei_embed],1)
694+
embedding_input = tf.concat([xyz_scatter,nei_embed],1) # shape is [nframes*natoms[0]*self.nnei, 1+te_out_dim]
676695
if not self.type_one_side:
677-
atm_embed = embed_atom_type(self.ntypes, natoms, type_embedding)
678-
atm_embed = tf.tile(atm_embed,(1,self.nnei))
679-
atm_embed = tf.reshape(atm_embed,[-1,te_out_dim])
680-
embedding_input = tf.concat([embedding_input,atm_embed],1)
696+
atm_embed = embed_atom_type(self.ntypes, natoms, type_embedding) # shape is [natoms[0], te_out_dim]
697+
atm_embed = tf.tile(atm_embed,(nframes,self.nnei)) # shape is [nframes*natoms[0], self.nnei*te_out_dim]
698+
atm_embed = tf.reshape(atm_embed,[-1,te_out_dim]) # shape is [nframes*natoms[0]*self.nnei, te_out_dim]
699+
embedding_input = tf.concat([embedding_input,atm_embed],1) # shape is [nframes*natoms[0]*self.nnei, 1+te_out_dim+te_out_dim]
681700
return embedding_input
682701

683702

source/tests/common.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,9 @@
33
import pathlib
44

55
from deepmd.env import tf
6-
from deepmd.env import GLOBAL_TF_FLOAT_PRECISION
76
from deepmd.env import GLOBAL_NP_FLOAT_PRECISION
8-
from deepmd.env import GLOBAL_ENER_FLOAT_PRECISION
97
from deepmd.common import j_loader as dp_j_loader
8+
from deepmd.utils import random as dp_random
109

1110
if GLOBAL_NP_FLOAT_PRECISION == np.float32 :
1211
global_default_fv_hh = 1e-2
@@ -26,8 +25,8 @@ def del_data():
2625
if os.path.isdir('system'):
2726
shutil.rmtree('system')
2827

29-
def gen_data() :
30-
tmpdata = Data(rand_pert = 0.1, seed = 1)
28+
def gen_data(nframes = 1) :
29+
tmpdata = Data(rand_pert = 0.1, seed = 1, nframes = nframes)
3130
sys = dpdata.LabeledSystem()
3231
sys.data['atom_names'] = ['foo', 'bar']
3332
sys.data['coords'] = tmpdata.coord
@@ -47,14 +46,15 @@ class Data():
4746
def __init__ (self,
4847
rand_pert = 0.1,
4948
seed = 1,
50-
box_scale = 20) :
49+
box_scale = 20,
50+
nframes = 1):
5151
coord = [[0.0, 0.0, 0.1], [1.1, 0.0, 0.1], [0.0, 1.1, 0.1],
5252
[4.0, 0.0, 0.0], [5.1, 0.0, 0.0], [4.0, 1.1, 0.0]]
53-
self.nframes = 1
53+
self.nframes = nframes
5454
self.coord = np.array(coord)
5555
self.coord = self._copy_nframes(self.coord)
56-
np.random.seed(seed)
57-
self.coord += rand_pert * np.random.random(self.coord.shape)
56+
dp_random.seed(seed)
57+
self.coord += rand_pert * dp_random.random(self.coord.shape)
5858
self.fparam = np.array([[0.1, 0.2]])
5959
self.aparam = np.tile(self.fparam, [1, 6])
6060
self.fparam = self._copy_nframes(self.fparam)
@@ -69,7 +69,7 @@ def __init__ (self,
6969
self.coord = self.coord.reshape([self.nframes, -1, 3])
7070
self.coord = self.coord[:,self.idx_map,:]
7171
self.coord = self.coord.reshape([self.nframes, -1])
72-
self.efield = np.random.random(self.coord.shape)
72+
self.efield = dp_random.random(self.coord.shape)
7373
self.atype = self.atype[self.idx_map]
7474
self.datype = self._copy_nframes(self.atype)
7575

@@ -128,7 +128,7 @@ def get_test_box_data (self,
128128
coord0_, box0_, type0_ = self.get_data()
129129
coord = coord0_[0]
130130
box = box0_[0]
131-
box += rand_pert * np.random.random(box.shape)
131+
box += rand_pert * dp_random.random(box.shape)
132132
atype = type0_[0]
133133
nframes = 1
134134
natoms = coord.size // 3

source/tests/test_descrpt_se_a_type.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
GLOBAL_NP_FLOAT_PRECISION = np.float64
1818

1919
class TestModel(tf.test.TestCase):
20-
def setUp(self) :
21-
gen_data()
20+
def setUp(self):
21+
gen_data(nframes=2)
2222

2323
def test_descriptor_two_sides(self):
2424
jfile = 'water_se_a_type.json'
@@ -28,7 +28,7 @@ def test_descriptor_two_sides(self):
2828
set_pfx = j_must_have(jdata, 'set_prefix')
2929
batch_size = j_must_have(jdata, 'batch_size')
3030
test_size = j_must_have(jdata, 'numb_test')
31-
batch_size = 1
31+
batch_size = 2
3232
test_size = 1
3333
stop_batch = j_must_have(jdata, 'stop_batch')
3434
rcut = j_must_have (jdata['model']['descriptor'], 'rcut')

0 commit comments

Comments
 (0)