Skip to content

Commit 31552db

Browse files
authored
add data type; refactor methods (#299)
* add data type; refactor methods * add DataType to manage types of data * refactor data checking to check type and shape of data ** fix incorrect data type in some codes * refactor methods including sub_system, append, sort_atom_types, shuffle, pick_atom_idx, and merge methods in LabeledSystem into those in System * small fix * fix the shape of formal_charges. Looks like the docstring is wrong * remove wrong import * fix system class in sub_system
1 parent 5c4008d commit 31552db

File tree

8 files changed

+214
-172
lines changed

8 files changed

+214
-172
lines changed

dpdata/bond_order_system.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
#%%
22
# Bond Order System
3-
from dpdata.system import System, LabeledSystem, check_System, load_format
3+
import numpy as np
4+
from dpdata.system import System, LabeledSystem, load_format, DataType, Axis
45
import dpdata.rdkit.utils
56
from dpdata.rdkit.sanitize import Sanitizer, SanitizeError
67
from copy import deepcopy
78
from rdkit.Chem import Conformer
89
# import dpdata.rdkit.mol2
910

10-
def check_BondOrderSystem(data):
11-
check_System(data)
12-
assert ('bonds' in data.keys())
1311

1412
class BondOrderSystem(System):
1513
'''
@@ -23,6 +21,11 @@ class BondOrderSystem(System):
2321
1 - single bond, 2 - double bond, 3 - triple bond, 1.5 - aromatic bond
2422
- `d_example['formal_charges']` : a numpy array of size 5 x 1
2523
'''
24+
DTYPES = System.DTYPES + (
25+
DataType("bonds", np.ndarray, (Axis.NBONDS, 3)),
26+
DataType("formal_charges", np.ndarray, (Axis.NATOMS,)),
27+
)
28+
2629
def __init__(self,
2730
file_name = None,
2831
fmt = 'auto',
@@ -86,6 +89,7 @@ def __init__(self,
8689

8790
if type_map:
8891
self.apply_type_map(type_map)
92+
self.check_data()
8993

9094
def from_fmt_obj(self, fmtobj, file_name, **kwargs):
9195
mol = fmtobj.from_bond_order_system(file_name, **kwargs)
@@ -104,9 +108,6 @@ def to_fmt_obj(self, fmtobj, *args, **kwargs):
104108
self.rdkit_mol.AddConformer(conf, assignId=True)
105109
return fmtobj.to_bond_order_system(self.data, self.rdkit_mol, *args, **kwargs)
106110

107-
def __repr__(self):
108-
return self.__str__()
109-
110111
def __str__(self):
111112
'''
112113
A brief summary of the system

dpdata/cp2k/output.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def handle_single_xyz_frame(self, lines):
268268
#info_dict['atom_types'] = np.asarray(atom_types_list)
269269
info_dict['coords'] = np.asarray([coords_list]).astype('float32')
270270
info_dict['energies'] = np.array([energy]).astype('float32')
271-
info_dict['orig']=[0,0,0]
271+
info_dict['orig'] = np.zeros(3)
272272
return info_dict
273273

274274
#%%

dpdata/plugins/ase.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def from_system(self, atoms: "ase.Atoms", **kwargs) -> dict:
4343
'atom_types': atom_types,
4444
'cells': np.array([cells]).astype('float32'),
4545
'coords': np.array([coords]).astype('float32'),
46-
'orig': [0,0,0],
46+
'orig': np.zeros(3),
4747
}
4848
return info_dict
4949

dpdata/pwmat/atomconfig.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def _to_system_data_lower(lines) :
1111
for kk in range(idx+1,idx+1+3):
1212
vector=[float(jj) for jj in lines[kk].split()[0:3]]
1313
cell.append(vector)
14-
system['cells'] = [np.array(cell)]
14+
system['cells'] = np.array([cell])
1515
coord = []
1616
atomic_number = []
1717
atom_numbs = []
@@ -32,7 +32,7 @@ def _to_system_data_lower(lines) :
3232
for ii in np.unique(sorted(atomic_number)) :
3333
atom_numbs.append(atomic_number.count(ii))
3434
system['atom_numbs'] = [int(ii) for ii in atom_numbs]
35-
system['coords'] = [np.array(coord)]
35+
system['coords'] = np.array([coord])
3636
system['orig'] = np.zeros(3)
3737
atom_types = []
3838
for idx,ii in enumerate(system['atom_numbs']) :

dpdata/pymatgen/molecule.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,6 @@ def to_system_data(file_name, protect_layer = 9) :
2424
# center = [c - h_cell_size for c in mol.center_of_mass]
2525
system['orig'] = np.array([0, 0, 0])
2626

27-
system['coords'] = [tmpcoord]
28-
system['cells'] = [10.0 * np.eye(3)]
27+
system['coords'] = np.array([tmpcoord])
28+
system['cells'] = np.array([10.0 * np.eye(3)])
2929
return system

0 commit comments

Comments
 (0)