Skip to content

Commit c0bb798

Browse files
authored
add HybridDriver (#292)
1 parent 31292fd commit c0bb798

File tree

2 files changed

+99
-1
lines changed

2 files changed

+99
-1
lines changed

dpdata/driver.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Driver plugin system."""
2-
from typing import Callable
2+
from typing import Callable, List, Union
33
from .plugin import Plugin
44
from abc import ABC, abstractmethod
55

@@ -78,3 +78,64 @@ def label(self, data: dict) -> dict:
7878
labeled data with energies and forces
7979
"""
8080
return NotImplemented
81+
82+
83+
@Driver.register("hybrid")
84+
class HybridDriver(Driver):
85+
"""Hybrid driver, with mixed drivers.
86+
87+
Parameters
88+
----------
89+
drivers : list[dict, Driver]
90+
list of drivers or drivers dict. For a dict, it should
91+
contain `type` as the name of the driver, and others
92+
are arguments of the driver.
93+
94+
Raises
95+
------
96+
TypeError
97+
The value of `drivers` is not a dict or `Driver`.
98+
99+
Examples
100+
--------
101+
>>> driver = HybridDriver([
102+
... {"type": "sqm", "qm_theory": "DFTB3"},
103+
... {"type": "dp", "dp": "frozen_model.pb"},
104+
... ])
105+
This driver is the hybrid of SQM and DP.
106+
"""
107+
def __init__(self, drivers: List[Union[dict, Driver]]) -> None:
108+
self.drivers = []
109+
for driver in drivers:
110+
if isinstance(driver, Driver):
111+
self.drivers.append(driver)
112+
elif isinstance(driver, dict):
113+
type = driver["type"]
114+
del driver["type"]
115+
self.drivers.append(Driver.get_driver(type)(**driver))
116+
else:
117+
raise TypeError("driver should be Driver or dict")
118+
119+
def label(self, data: dict) -> dict:
120+
"""Label a system data.
121+
122+
Energies and forces are the sum of those of each driver.
123+
124+
Parameters
125+
----------
126+
data : dict
127+
data with coordinates and atom types
128+
129+
Returns
130+
-------
131+
dict
132+
labeled data with energies and forces
133+
"""
134+
for ii, driver in enumerate(self.drivers):
135+
lb_data = driver.label(data.copy())
136+
if ii == 0:
137+
labeled_data = lb_data.copy()
138+
else:
139+
labeled_data['energies'] += lb_data ['energies']
140+
labeled_data['forces'] += lb_data ['forces']
141+
return labeled_data

tests/test_predict.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,17 @@ def label(self, data):
1616
return data
1717

1818

19+
@dpdata.driver.Driver.register("one")
20+
class ZeroDriver(dpdata.driver.Driver):
21+
def label(self, data):
22+
nframes = data['coords'].shape[0]
23+
natoms = data['coords'].shape[1]
24+
data['energies'] = np.ones((nframes,))
25+
data['forces'] = np.ones((nframes, natoms, 3))
26+
data['virials'] = np.ones((nframes, 3, 3))
27+
return data
28+
29+
1930
class TestPredict(unittest.TestCase, CompLabeledSys):
2031
def setUp (self) :
2132
ori_sys = dpdata.LabeledSystem('poscars/deepmd.h2o.md',
@@ -32,3 +43,29 @@ def setUp (self) :
3243
self.e_places = 6
3344
self.f_places = 6
3445
self.v_places = 6
46+
47+
48+
class TestHybridDriver(unittest.TestCase, CompLabeledSys):
49+
"""Test HybridDriver."""
50+
def setUp(self) :
51+
ori_sys = dpdata.LabeledSystem('poscars/deepmd.h2o.md',
52+
fmt = 'deepmd/raw',
53+
type_map = ['O', 'H'])
54+
self.system_1 = ori_sys.predict([
55+
{"type": "one"},
56+
{"type": "one"},
57+
{"type": "one"},
58+
{"type": "zero"},
59+
],
60+
driver="hybrid")
61+
# sum is 3
62+
self.system_2 = dpdata.LabeledSystem('poscars/deepmd.h2o.md',
63+
fmt = 'deepmd/raw',
64+
type_map = ['O', 'H'])
65+
for pp in ('energies', 'forces'):
66+
self.system_2.data[pp][:] = 3.
67+
68+
self.places = 6
69+
self.e_places = 6
70+
self.f_places = 6
71+
self.v_places = 6

0 commit comments

Comments
 (0)