Skip to content

Commit d928618

Browse files
author
Han Wang
committed
fix bug of getting ntypes from data
1 parent 8983e77 commit d928618

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

source/train/Data.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,9 @@ def get_test (self, ntests = -1) :
144144
self.modifier.modify_data(ret)
145145
return ret
146146

147+
def get_ntypes(self) :
148+
return len(self.type_map)
149+
147150
def get_type_map(self) :
148151
return self.type_map
149152

source/train/DataSystem.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def __init__ (self,
5252
# natoms, nbatches
5353
ntypes = []
5454
for ii in self.data_systems :
55-
ntypes.append(np.max(ii.get_atom_type()) + 1)
55+
ntypes.append(ii.get_ntypes())
5656
self.sys_ntypes = max(ntypes)
5757
self.natoms = []
5858
self.natoms_vec = []

0 commit comments

Comments
 (0)