|
1 | 1 | # SPDX-License-Identifier: LGPL-3.0-or-later |
| 2 | +from deepmd.dpmodel.atomic_model.dp_atomic_model import ( |
| 3 | + DPAtomicModel, |
| 4 | +) |
| 5 | +from deepmd.dpmodel.atomic_model.pairtab_atomic_model import ( |
| 6 | + PairTabAtomicModel, |
| 7 | +) |
| 8 | +from deepmd.dpmodel.descriptor.base_descriptor import ( |
| 9 | + BaseDescriptor, |
| 10 | +) |
2 | 11 | from deepmd.dpmodel.descriptor.se_e2_a import ( |
3 | 12 | DescrptSeA, |
4 | 13 | ) |
|
8 | 17 | from deepmd.dpmodel.model.base_model import ( |
9 | 18 | BaseModel, |
10 | 19 | ) |
| 20 | +from deepmd.dpmodel.model.dp_zbl_model import ( |
| 21 | + DPZBLModel, |
| 22 | +) |
11 | 23 | from deepmd.dpmodel.model.ener_model import ( |
12 | 24 | EnergyModel, |
13 | 25 | ) |
@@ -55,6 +67,45 @@ def get_standard_model(data: dict) -> EnergyModel: |
55 | 67 | ) |
56 | 68 |
|
57 | 69 |
|
| 70 | +def get_zbl_model(data: dict) -> DPZBLModel: |
| 71 | + data["descriptor"]["ntypes"] = len(data["type_map"]) |
| 72 | + descriptor = BaseDescriptor(**data["descriptor"]) |
| 73 | + fitting_type = data["fitting_net"].pop("type") |
| 74 | + if fitting_type == "ener": |
| 75 | + fitting = EnergyFittingNet( |
| 76 | + ntypes=descriptor.get_ntypes(), |
| 77 | + dim_descrpt=descriptor.get_dim_out(), |
| 78 | + mixed_types=descriptor.mixed_types(), |
| 79 | + **data["fitting_net"], |
| 80 | + ) |
| 81 | + else: |
| 82 | + raise ValueError(f"Unknown fitting type {fitting_type}") |
| 83 | + |
| 84 | + dp_model = DPAtomicModel(descriptor, fitting, type_map=data["type_map"]) |
| 85 | + # pairtab |
| 86 | + filepath = data["use_srtab"] |
| 87 | + pt_model = PairTabAtomicModel( |
| 88 | + filepath, |
| 89 | + data["descriptor"]["rcut"], |
| 90 | + data["descriptor"]["sel"], |
| 91 | + type_map=data["type_map"], |
| 92 | + ) |
| 93 | + |
| 94 | + rmin = data["sw_rmin"] |
| 95 | + rmax = data["sw_rmax"] |
| 96 | + atom_exclude_types = data.get("atom_exclude_types", []) |
| 97 | + pair_exclude_types = data.get("pair_exclude_types", []) |
| 98 | + return DPZBLModel( |
| 99 | + dp_model, |
| 100 | + pt_model, |
| 101 | + rmin, |
| 102 | + rmax, |
| 103 | + type_map=data["type_map"], |
| 104 | + atom_exclude_types=atom_exclude_types, |
| 105 | + pair_exclude_types=pair_exclude_types, |
| 106 | + ) |
| 107 | + |
| 108 | + |
58 | 109 | def get_spin_model(data: dict) -> SpinModel: |
59 | 110 | """Get a spin model from a dictionary. |
60 | 111 |
|
@@ -100,6 +151,8 @@ def get_model(data: dict): |
100 | 151 | if model_type == "standard": |
101 | 152 | if "spin" in data: |
102 | 153 | return get_spin_model(data) |
| 154 | + elif "use_srtab" in data: |
| 155 | + return get_zbl_model(data) |
103 | 156 | else: |
104 | 157 | return get_standard_model(data) |
105 | 158 | else: |
|
0 commit comments