|
1 | 1 | """Driver plugin system."""
|
2 |
| -from typing import Callable |
| 2 | +from typing import Callable, List, Union |
3 | 3 | from .plugin import Plugin
|
4 | 4 | from abc import ABC, abstractmethod
|
5 | 5 |
|
@@ -78,3 +78,64 @@ def label(self, data: dict) -> dict:
|
78 | 78 | labeled data with energies and forces
|
79 | 79 | """
|
80 | 80 | return NotImplemented
|
| 81 | + |
| 82 | + |
| 83 | +@Driver.register("hybrid") |
| 84 | +class HybridDriver(Driver): |
| 85 | + """Hybrid driver, with mixed drivers. |
| 86 | +
|
| 87 | + Parameters |
| 88 | + ---------- |
| 89 | + drivers : list[dict, Driver] |
| 90 | + list of drivers or drivers dict. For a dict, it should |
| 91 | + contain `type` as the name of the driver, and others |
| 92 | + are arguments of the driver. |
| 93 | +
|
| 94 | + Raises |
| 95 | + ------ |
| 96 | + TypeError |
| 97 | + The value of `drivers` is not a dict or `Driver`. |
| 98 | +
|
| 99 | + Examples |
| 100 | + -------- |
| 101 | + >>> driver = HybridDriver([ |
| 102 | + ... {"type": "sqm", "qm_theory": "DFTB3"}, |
| 103 | + ... {"type": "dp", "dp": "frozen_model.pb"}, |
| 104 | + ... ]) |
| 105 | + This driver is the hybrid of SQM and DP. |
| 106 | + """ |
| 107 | + def __init__(self, drivers: List[Union[dict, Driver]]) -> None: |
| 108 | + self.drivers = [] |
| 109 | + for driver in drivers: |
| 110 | + if isinstance(driver, Driver): |
| 111 | + self.drivers.append(driver) |
| 112 | + elif isinstance(driver, dict): |
| 113 | + type = driver["type"] |
| 114 | + del driver["type"] |
| 115 | + self.drivers.append(Driver.get_driver(type)(**driver)) |
| 116 | + else: |
| 117 | + raise TypeError("driver should be Driver or dict") |
| 118 | + |
| 119 | + def label(self, data: dict) -> dict: |
| 120 | + """Label a system data. |
| 121 | +
|
| 122 | + Energies and forces are the sum of those of each driver. |
| 123 | + |
| 124 | + Parameters |
| 125 | + ---------- |
| 126 | + data : dict |
| 127 | + data with coordinates and atom types |
| 128 | + |
| 129 | + Returns |
| 130 | + ------- |
| 131 | + dict |
| 132 | + labeled data with energies and forces |
| 133 | + """ |
| 134 | + for ii, driver in enumerate(self.drivers): |
| 135 | + lb_data = driver.label(data.copy()) |
| 136 | + if ii == 0: |
| 137 | + labeled_data = lb_data.copy() |
| 138 | + else: |
| 139 | + labeled_data['energies'] += lb_data ['energies'] |
| 140 | + labeled_data['forces'] += lb_data ['forces'] |
| 141 | + return labeled_data |
0 commit comments