Skip to content

Commit ea32d45

Browse files
authored
Merge pull request #168 from njzjz/refactor
refactor and add format plugin system
2 parents 59dfe7d + 4ca3706 commit ea32d45

30 files changed

+1105
-731
lines changed

README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,3 +252,14 @@ print(syst.get_charge()) # return the total charge of the system
252252
```
253253

254254
If a valence of 3 is detected on carbon, the formal charge will be assigned as -1, because in most cases (alkynyl anion, isonitrile, cyclopentadienyl anion) the formal charge on a 3-valence carbon is -1, which is also consistent with the 8-electron rule.
255+
256+
# Plugins
257+
258+
One can follow [a simple example](plugin_example/) to add their own format by creating and installing plugins. It's critical to add the [Format](dpdata/format.py) class to `entry_points['dpdata.plugins']` in `setup.py`:
259+
```py
260+
entry_points={
261+
'dpdata.plugins': [
262+
'random=dpdata_random:RandomFormat'
263+
]
264+
},
265+
```

dpdata/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212

1313
# BondOrder System has dependency on rdkit
1414
try:
15-
import rdkit
15+
# prevent conflict with dpdata.rdkit
16+
import rdkit as _
1617
USE_RDKIT = True
1718
except ModuleNotFoundError:
1819
USE_RDKIT = False

dpdata/bond_order_system.py

Lines changed: 11 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#%%
22
# Bond Order System
3-
from dpdata.system import Register, System, LabeledSystem, check_System
4-
import rdkit.Chem
3+
from dpdata.system import System, LabeledSystem, check_System, load_format
54
import dpdata.rdkit.utils
65
from dpdata.rdkit.sanitize import Sanitizer, SanitizeError
76
from copy import deepcopy
@@ -87,8 +86,16 @@ def __init__(self,
8786
if type_map:
8887
self.apply_type_map(type_map)
8988

90-
register_from_funcs = Register()
91-
register_to_funcs = System.register_to_funcs + Register()
89+
def from_fmt_obj(self, fmtobj, file_name, **kwargs):
90+
mol = fmtobj.from_bond_order_system(file_name, **kwargs)
91+
self.from_rdkit_mol(mol)
92+
if hasattr(fmtobj.from_bond_order_system, 'post_func'):
93+
for post_f in fmtobj.from_bond_order_system.post_func:
94+
self.post_funcs.get_plugin(post_f)(self)
95+
return self
96+
97+
def to_fmt_obj(self, fmtobj, *args, **kwargs):
98+
return fmtobj.to_bond_order_system(self.data, self.rdkit_mol, *args, **kwargs)
9299

93100
def __repr__(self):
94101
return self.__str__()
@@ -164,36 +171,3 @@ def from_rdkit_mol(self, rdkit_mol):
164171
self.data = dpdata.rdkit.utils.mol_to_system_data(rdkit_mol)
165172
self.data['bond_dict'] = dict([(f'{int(bond[0])}-{int(bond[1])}', bond[2]) for bond in self.data['bonds']])
166173
self.rdkit_mol = rdkit_mol
167-
168-
@register_from_funcs.register_funcs('mol')
169-
def from_mol_file(self, file_name):
170-
mol = rdkit.Chem.MolFromMolFile(file_name, sanitize=False, removeHs=False)
171-
self.from_rdkit_mol(mol)
172-
173-
@register_to_funcs.register_funcs("mol")
174-
def to_mol_file(self, file_name, frame_idx=0):
175-
assert (frame_idx < self.get_nframes())
176-
rdkit.Chem.MolToMolFile(self.rdkit_mol, file_name, confId=frame_idx)
177-
178-
@register_from_funcs.register_funcs("sdf")
179-
def from_sdf_file(self, file_name):
180-
'''
181-
Note that it requires all molecules in .sdf file must be of the same topology
182-
'''
183-
mols = [m for m in rdkit.Chem.SDMolSupplier(file_name, sanitize=False, removeHs=False)]
184-
if len(mols) > 1:
185-
mol = dpdata.rdkit.utils.combine_molecules(mols)
186-
else:
187-
mol = mols[0]
188-
self.from_rdkit_mol(mol)
189-
190-
@register_to_funcs.register_funcs("sdf")
191-
def to_sdf_file(self, file_name, frame_idx=-1):
192-
sdf_writer = rdkit.Chem.SDWriter(file_name)
193-
if frame_idx == -1:
194-
for ii in self.get_nframes():
195-
sdf_writer.write(self.rdkit_mol, confId=ii)
196-
else:
197-
assert (frame_idx < self.get_nframes())
198-
sdf_writer.write(self.rdkit_mol, confId=frame_idx)
199-
sdf_writer.close()

dpdata/cp2k/output.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -278,20 +278,20 @@ def get_frames (fname) :
278278

279279
# convert to float array and add an extra dimension for nframes
280280
cell = np.array(cell)
281-
cell = cell.astype(np.float)
281+
cell = cell.astype(float)
282282
cell = cell[np.newaxis, :, :]
283283
coord = np.array(coord)
284-
coord = coord.astype(np.float)
284+
coord = coord.astype(float)
285285
coord = coord[np.newaxis, :, :]
286286
atom_symbol_list = np.array(atom_symbol_list)
287287
force = np.array(force)
288-
force = force.astype(np.float)
288+
force = force.astype(float)
289289
force = force[np.newaxis, :, :]
290290

291291
# virial is not necessary
292292
if stress:
293293
stress = np.array(stress)
294-
stress = stress.astype(np.float)
294+
stress = stress.astype(float)
295295
stress = stress[np.newaxis, :, :]
296296
# stress to virial conversion, default unit in cp2k is GPa
297297
# note the stress is virial = stress * volume

dpdata/format.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
"""Implement the format plugin system."""
2+
import os
3+
from collections import abc
4+
from abc import ABC
5+
6+
from .plugin import Plugin
7+
8+
9+
class Format(ABC):
    """Base class of all format plugins.

    Subclasses override the ``from_*`` / ``to_*`` hooks they support; the
    defaults below either raise ``NotImplementedError`` or fall back to
    ``to_system``. Three class-level registries record formats and
    from/to methods by key.
    """
    __FormatPlugin = Plugin()
    __FromPlugin = Plugin()
    __ToPlugin = Plugin()

    @staticmethod
    def register(key):
        """Return a decorator that records a format under ``key``."""
        return Format.__FormatPlugin.register(key)

    @staticmethod
    def register_from(key):
        """Return a decorator that records a "from" method under ``key``."""
        return Format.__FromPlugin.register(key)

    @staticmethod
    def register_to(key):
        """Return a decorator that records a "to" method under ``key``."""
        return Format.__ToPlugin.register(key)

    @staticmethod
    def get_formats():
        """Return the mapping of registered format keys to formats."""
        return Format.__FormatPlugin.plugins

    @staticmethod
    def get_from_methods():
        """Return the mapping of registered "from" keys to methods."""
        return Format.__FromPlugin.plugins

    @staticmethod
    def get_to_methods():
        """Return the mapping of registered "to" keys to methods."""
        return Format.__ToPlugin.plugins

    @staticmethod
    def post(func_name):
        """Return a decorator that tags an object with post-function names.

        Parameters
        ----------
        func_name : str or list or tuple or set
            A single name, or a collection of names. A single name is
            wrapped in a one-element tuple; a collection is stored as given.

        Returns
        -------
        callable
            Decorator that sets ``post_func`` on its argument and returns it.
        """
        def decorator(obj):
            if isinstance(func_name, (list, tuple, set)):
                obj.post_func = func_name
            else:
                obj.post_func = (func_name,)
            return obj
        return decorator

    def from_system(self, file_name, **kwargs):
        """System.from

        Parameters
        ----------
        file_name: str
            file name

        Returns
        -------
        data: dict
            system data
        """
        cls_name = self.__class__.__name__
        raise NotImplementedError("%s doesn't support System.from" % (cls_name))

    def to_system(self, data, *args, **kwargs):
        """System.to

        Parameters
        ----------
        data: dict
            system data
        """
        cls_name = self.__class__.__name__
        raise NotImplementedError("%s doesn't support System.to" % (cls_name))

    def from_labeled_system(self, file_name, **kwargs):
        """LabeledSystem.from; not supported unless overridden."""
        cls_name = self.__class__.__name__
        raise NotImplementedError("%s doesn't support LabeledSystem.from" % (cls_name))

    def to_labeled_system(self, data, *args, **kwargs):
        """LabeledSystem.to; defaults to ``to_system``."""
        return self.to_system(data, *args, **kwargs)

    def from_bond_order_system(self, file_name, **kwargs):
        """BondOrderSystem.from; not supported unless overridden."""
        cls_name = self.__class__.__name__
        raise NotImplementedError("%s doesn't support BondOrderSystem.from" % (cls_name))

    def to_bond_order_system(self, data, rdkit_mol, *args, **kwargs):
        """BondOrderSystem.to; defaults to ``to_system`` and drops ``rdkit_mol``."""
        return self.to_system(data, *args, **kwargs)

    class MultiModes:
        """File mode for MultiSystems
        0 (default): not implemented
        1: every directory under the top-level directory is a system
        """
        NotImplemented = 0
        Directory = 1

    MultiMode = MultiModes.NotImplemented

    def from_multi_systems(self, directory, **kwargs):
        """MultiSystems.from

        Parameters
        ----------
        directory: str
            directory of system

        Returns
        -------
        filenames: list[str]
            list of filenames
        """
        if self.MultiMode != self.MultiModes.Directory:
            raise NotImplementedError("%s doesn't support MultiSystems.from" % (self.__class__.__name__))
        # Directory mode: each immediate sub-directory is one system.
        return [
            entry for entry in os.listdir(directory)
            if os.path.isdir(os.path.join(directory, entry))
        ]

    def to_multi_systems(self, formulas, directory, **kwargs):
        """MultiSystems.to

        In ``Directory`` mode, return one path per formula by joining it
        onto ``directory``; otherwise raise ``NotImplementedError``.
        """
        if self.MultiMode != self.MultiModes.Directory:
            raise NotImplementedError("%s doesn't support MultiSystems.to" % (self.__class__.__name__))
        return [os.path.join(directory, formula) for formula in formulas]
116+

dpdata/plugin.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
"""Base of plugin systems."""
2+
3+
4+
class Plugin:
    """A registry that maps keys to plugin objects.

    Examples
    --------
    >>> plugin = Plugin()
    >>> @plugin.register("xx")
    ... def xxx():
    ...     pass
    >>> plugin.plugins["xx"] is xxx
    True
    """

    def __init__(self):
        # key -> registered object (class or function)
        self.plugins = {}

    def register(self, key):
        """Return a decorator that registers an object under ``key``.

        Parameters
        ----------
        key : str
            Key of the plugin.

        Returns
        -------
        callable
            Decorator that stores its argument in ``self.plugins[key]`` and
            returns the argument unchanged, so registrations can be stacked.
        """
        def decorator(obj):
            self.plugins[key] = obj
            return obj
        return decorator

    def get_plugin(self, key):
        """Return the object registered under ``key``.

        Raises
        ------
        KeyError
            If nothing is registered under ``key``.
        """
        return self.plugins[key]

    def __add__(self, other):
        """Return a new registry containing the plugins of both operands.

        Neither operand is modified. On a key collision the entry from
        ``other`` wins, matching ``dict.update`` order.
        """
        merged = Plugin()
        merged.plugins.update(self.plugins)
        merged.plugins.update(other.plugins)
        return merged

dpdata/plugins/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import importlib
2+
from pathlib import Path
3+
try:
4+
from importlib import metadata
5+
except ImportError: # for Python<3.8
6+
import importlib_metadata as metadata
7+
8+
PACKAGE_BASE = "dpdata.plugins"
NOT_LOADABLE = ("__init__.py",)

# Import every sibling module of this package so that its module-level
# registration decorators are executed.
for module_file in Path(__file__).parent.glob("*.py"):
    if module_file.name not in NOT_LOADABLE:
        module_name = f".{module_file.stem}"
        importlib.import_module(module_name, PACKAGE_BASE)

# Load externally installed plugins advertised in the 'dpdata.plugins'
# entry point group.
# https://setuptools.readthedocs.io/en/latest/userguide/entry_point.html
try:
    # Selection API: entry_points() accepts a ``group`` keyword.
    eps = metadata.entry_points(group='dpdata.plugins')
except TypeError:
    # Older API: entry_points() takes no arguments and returns a dict
    # keyed by group name.
    eps = metadata.entry_points().get('dpdata.plugins', [])
for ep in eps:
    plugin = ep.load()

dpdata/plugins/abacus.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import dpdata.abacus.scf
2+
from dpdata.format import Format
3+
4+
5+
@Format.register("abacus/scf")
@Format.register("abacus/pw/scf")
class AbacusSCFFormat(Format):
    """Format plugin registered under ``abacus/scf`` and ``abacus/pw/scf``."""

    @Format.post("rot_lower_triangular")
    def from_labeled_system(self, file_name, **kwargs):
        """Read labeled-system data from ``file_name``.

        Delegates to ``dpdata.abacus.scf.get_frame``. Extra keyword
        arguments are accepted but not forwarded.
        """
        frame = dpdata.abacus.scf.get_frame(file_name)
        return frame

dpdata/plugins/amber.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import dpdata.amber.md
2+
import dpdata.amber.sqm
3+
from dpdata.format import Format
4+
5+
6+
@Format.register("amber/md")
class AmberMDFormat(Format):
    """Format plugin registered under ``amber/md``."""

    def from_system(self, file_name=None, parm7_file=None, nc_file=None, use_element_symbols=None, **kwargs):
        """Read an unlabeled system via ``dpdata.amber.md.read_amber_traj``.

        Parameters
        ----------
        file_name : str, optional
            Common prefix used to derive any file name that is not given.
        parm7_file : str, optional
            Defaults to ``file_name + ".parm7"``.
        nc_file : str, optional
            Defaults to ``file_name + ".nc"``.
        use_element_symbols : optional
            Forwarded unchanged to ``read_amber_traj``.
        **kwargs
            Accepted for signature compatibility with
            ``from_labeled_system``; not forwarded.
        """
        # assume the prefix is the same if the specific name is not given
        if parm7_file is None:
            parm7_file = file_name + ".parm7"
        if nc_file is None:
            nc_file = file_name + ".nc"
        return dpdata.amber.md.read_amber_traj(parm7_file=parm7_file, nc_file=nc_file, use_element_symbols=use_element_symbols, labeled=False)

    def from_labeled_system(self, file_name=None, parm7_file=None, nc_file=None, mdfrc_file=None, mden_file=None, mdout_file=None, use_element_symbols=None, **kwargs):
        """Read a labeled system via ``dpdata.amber.md.read_amber_traj``.

        Any file name left as ``None`` is derived from ``file_name`` by
        appending the matching suffix (``.parm7``, ``.nc``, ``.mdfrc``,
        ``.mden``, ``.mdout``). Extra keyword arguments are not forwarded.
        """
        # assume the prefix is the same if the specific name is not given
        if parm7_file is None:
            parm7_file = file_name + ".parm7"
        if nc_file is None:
            nc_file = file_name + ".nc"
        if mdfrc_file is None:
            mdfrc_file = file_name + ".mdfrc"
        if mden_file is None:
            mden_file = file_name + ".mden"
        if mdout_file is None:
            mdout_file = file_name + ".mdout"
        return dpdata.amber.md.read_amber_traj(parm7_file, nc_file, mdfrc_file, mden_file, mdout_file, use_element_symbols)
29+
30+
31+
@Format.register("sqm/out")
class SQMOutFormat(Format):
    """Format plugin registered under ``sqm/out``."""

    def from_system(self, fname, **kwargs):
        """Read from ambertools sqm.out.

        Delegates to ``dpdata.amber.sqm.to_system_data``. Extra keyword
        arguments are accepted but not forwarded.
        """
        system_data = dpdata.amber.sqm.to_system_data(fname)
        return system_data

dpdata/plugins/ase.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from dpdata.format import Format
2+
3+
4+
@Format.register("ase/structure")
class ASEStructureFormat(Format):
    """Format plugin registered under ``ase/structure``."""

    def to_system(self, data, **kwargs):
        """Convert system data to a list of ASE ``Atoms`` objects, one per frame.

        Reads ``atom_names``, ``atom_types``, ``coords``, ``cells`` and the
        optional ``nopbc`` flag from ``data``.
        """
        from ase import Atoms

        # Per-atom element symbol, looked up through the type index.
        species = [data['atom_names'][tt] for tt in data['atom_types']]
        nframes = data['coords'].shape[0]

        return [
            Atoms(
                symbols=species,
                positions=data['coords'][frame],
                pbc=not data.get('nopbc', False),
                cell=data['cells'][frame],
            )
            for frame in range(nframes)
        ]

    def to_labeled_system(self, data, *args, **kwargs):
        """Convert labeled system data to ASE ``Atoms`` objects with results.

        Each frame gets a ``SinglePointCalculator`` holding its energy and
        forces, plus a stress entry when ``data`` contains ``virials``.
        """
        from ase import Atoms
        from ase.calculators.singlepoint import SinglePointCalculator

        species = [data['atom_names'][tt] for tt in data['atom_types']]
        periodic = not data.get('nopbc', False)

        structures = []
        for frame in range(data['coords'].shape[0]):
            atoms = Atoms(
                symbols=species,
                positions=data['coords'][frame],
                pbc=periodic,
                cell=data['cells'][frame],
            )

            results = {
                'energy': data["energies"][frame],
                'forces': data["forces"][frame],
            }
            if "virials" in data:
                # Original comment: "convert to GPa as this is ase convention".
                # NOTE(review): unit and sign of the ASE stress convention
                # cannot be confirmed from this file — verify against ASE docs.
                v_pref = 1 * 1e4 / 1.602176621e6
                vol = atoms.get_volume()
                results['stress'] = data["virials"][frame] / (v_pref * vol)

            atoms.calc = SinglePointCalculator(atoms, **results)
            structures.append(atoms)

        return structures

0 commit comments

Comments
 (0)