|
2 | 2 |
|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | | -from typing import TYPE_CHECKING |
6 | | - |
7 | | -import numpy as np |
8 | 5 | from pydantic import ( |
9 | | - BaseModel, |
10 | 6 | Field, |
11 | 7 | field_validator, |
12 | | - model_serializer, |
13 | 8 | ) |
14 | 9 | from pymatgen.analysis.elasticity import ElasticTensor |
15 | | -from pymatgen.core import Element, Structure |
16 | 10 |
|
17 | 11 | from emmet.core.elasticity import ( |
18 | 12 | BulkModulus, |
19 | 13 | ElasticityDoc, |
20 | 14 | ElasticTensorDoc, |
21 | 15 | ShearModulus, |
22 | 16 | ) |
23 | | -from emmet.core.math import Matrix3D, Vector3D, Vector6D, matrix_3x3_to_voigt |
24 | | -from emmet.core.structure import StructureMetadata |
25 | | -from emmet.core.tasks import TaskDoc |
26 | | -from emmet.core.types.pymatgen_types.composition_adapter import CompositionType |
27 | 17 | from emmet.core.types.pymatgen_types.structure_adapter import StructureType |
28 | | -from emmet.core.types.typing import IdentifierType |
29 | | -from emmet.core.utils import jsanitize |
30 | | -from emmet.core.vasp.calc_types import RunType as VaspRunType |
31 | | - |
32 | | -if TYPE_CHECKING: |
33 | | - from collections.abc import Sequence |
34 | | - |
35 | | - from typing_extensions import Self |
36 | 18 |
|
37 | 19 |
|
38 | 20 | class MLDoc(ElasticityDoc): |
@@ -135,229 +117,3 @@ def shear_vrh_no_suffix(cls, new_key, values): |
135 | 117 | """Map field shear_modulus_vrh to shear_modulus.""" |
136 | 118 | val = values.get("shear_modulus_vrh", new_key) |
137 | 119 | return ShearModulus(vrh=val) |
138 | | - |
139 | | - |
140 | | -class MLTrainDoc(StructureMetadata): |
141 | | - """Generic schema for ML training data.""" |
142 | | - |
143 | | - structure: StructureType | None = Field( |
144 | | - None, description="Structure for this entry." |
145 | | - ) |
146 | | - |
147 | | - energy: float | None = Field( |
148 | | - None, description="The total energy associated with this structure." |
149 | | - ) |
150 | | - |
151 | | - forces: list[Vector3D] | None = Field( |
152 | | - None, |
153 | | - description="The interatomic forces corresponding to each site in the structure.", |
154 | | - ) |
155 | | - |
156 | | - abs_forces: list[float] | None = Field( |
157 | | - None, description="The magnitude of the interatomic force on each site." |
158 | | - ) |
159 | | - |
160 | | - stress: Vector6D | None = Field( |
161 | | - None, |
162 | | - description="The components of the symmetric stress tensor in Voigt notation (xx, yy, zz, yz, xz, xy).", |
163 | | - ) |
164 | | - |
165 | | - stress_matrix: Matrix3D | None = Field( |
166 | | - None, |
167 | | - description="The 3x3 stress tensor. Use this if the tensor is unphysically non-symmetric.", |
168 | | - ) |
169 | | - |
170 | | - bandgap: float | None = Field(None, description="The final DFT bandgap.") |
171 | | - |
172 | | - elements: list[Element] | None = Field( |
173 | | - None, |
174 | | - description="List of unique elements in the material sorted alphabetically.", |
175 | | - ) |
176 | | - |
177 | | - composition: CompositionType | None = Field( |
178 | | - None, description="Full composition for the material." |
179 | | - ) |
180 | | - |
181 | | - composition_reduced: CompositionType | None = Field( |
182 | | - None, |
183 | | - title="Reduced Composition", |
184 | | - description="Simplified representation of the composition.", |
185 | | - ) |
186 | | - |
187 | | - functional: VaspRunType | None = Field( |
188 | | - None, description="The approximate functional used to generate this entry." |
189 | | - ) |
190 | | - |
191 | | - bader_charges: list[float] | None = Field( |
192 | | - None, description="Bader charges on each site of the structure." |
193 | | - ) |
194 | | - bader_magmoms: list[float] | None = Field( |
195 | | - None, |
196 | | - description="Bader on-site magnetic moments for each site of the structure.", |
197 | | - ) |
198 | | - |
199 | | - @model_serializer |
200 | | - def deseralize(self): |
201 | | - """Ensure output is JSON compliant.""" |
202 | | - return jsanitize( |
203 | | - {k: getattr(self, k, None) for k in self.__class__.model_fields} |
204 | | - ) |
205 | | - |
206 | | - @classmethod |
207 | | - def from_structure( |
208 | | - cls, |
209 | | - meta_structure: Structure, |
210 | | - fields: list[str] | None = None, |
211 | | - **kwargs, |
212 | | - ) -> Self: |
213 | | - """ |
214 | | - Create an ML training document from a structure and fields. |
215 | | -
|
216 | | - This method mostly exists to ensure that the structure field is |
217 | | - set because `meta_structure` does not populate it automatically. |
218 | | -
|
219 | | - Parameters |
220 | | - ----------- |
221 | | - meta_structure : Structure |
222 | | - fields : list of str or None |
223 | | - Additional fields in the document to populate |
224 | | - **kwargs |
225 | | - Any other fields / constructor kwargs |
226 | | - """ |
227 | | - if (forces := kwargs.get("forces")) is not None and kwargs.get( |
228 | | - "abs_forces" |
229 | | - ) is None: |
230 | | - kwargs["abs_forces"] = [np.linalg.norm(f) for f in forces] |
231 | | - |
232 | | - return super().from_structure( |
233 | | - meta_structure=meta_structure, |
234 | | - fields=fields, |
235 | | - structure=meta_structure, |
236 | | - **kwargs, |
237 | | - ) |
238 | | - |
239 | | - @classmethod |
240 | | - def from_task_doc( |
241 | | - cls, |
242 | | - task_doc: TaskDoc, |
243 | | - **kwargs, |
244 | | - ) -> list[Self]: |
245 | | - """Create a list of ML training documents from the ionic steps in a TaskDoc. |
246 | | -
|
247 | | - Parameters |
248 | | - ----------- |
249 | | - task_doc : TaskDoc |
250 | | - **kwargs |
251 | | - Any kwargs to pass to `from_structure`. |
252 | | - """ |
253 | | - entries = [] |
254 | | - |
255 | | - for cr in task_doc.calcs_reversed[::-1]: |
256 | | - nion = len(cr.output.ionic_steps) |
257 | | - |
258 | | - for iion, ionic_step in enumerate(cr.output.ionic_steps): |
259 | | - structure = Structure.from_dict(ionic_step.structure.as_dict()) |
260 | | - # these are fields that should only be set on the final frame of a calculation |
261 | | - # also patch in magmoms because of how Calculation works |
262 | | - last_step_kwargs = {} |
263 | | - if iion == nion - 1: |
264 | | - if magmom := cr.output.structure.site_properties.get("magmom"): |
265 | | - structure.add_site_property("magmom", magmom) |
266 | | - last_step_kwargs["bandgap"] = cr.output.bandgap |
267 | | - if bader_analysis := cr.bader: |
268 | | - for bk in ( |
269 | | - "charge", |
270 | | - "magmom", |
271 | | - ): |
272 | | - last_step_kwargs[f"bader_{bk}s"] = bader_analysis[bk] |
273 | | - |
274 | | - if (_st := ionic_step.stress) is not None: |
275 | | - st = np.array(_st) |
276 | | - if np.allclose(st, st.T, rtol=1e-8): |
277 | | - # Stress tensor is symmetric |
278 | | - last_step_kwargs["stress"] = matrix_3x3_to_voigt(_st) |
279 | | - else: |
280 | | - # Stress tensor is non-symmetric |
281 | | - last_step_kwargs["stress_matrix"] = _st |
282 | | - |
283 | | - entries.append( |
284 | | - cls.from_structure( |
285 | | - meta_structure=structure, |
286 | | - energy=ionic_step.e_0_energy, |
287 | | - forces=ionic_step.forces, |
288 | | - functional=cr.run_type, |
289 | | - **last_step_kwargs, |
290 | | - **kwargs, |
291 | | - ) |
292 | | - ) |
293 | | - return entries |
294 | | - |
295 | | - |
296 | | -class MatPESProvenanceDoc(BaseModel): |
297 | | - """Information regarding the origins of a MatPES structure.""" |
298 | | - |
299 | | - original_mp_id: IdentifierType | None = Field( |
300 | | - None, |
301 | | - description="MP identifier corresponding to the Materials Project structure from which this entry was sourced from.", |
302 | | - ) |
303 | | - materials_project_version: str | None = Field( |
304 | | - None, |
305 | | - description="The version of the Materials Project from which the struture was sourced.", |
306 | | - ) |
307 | | - md_ensemble: str | None = Field( |
308 | | - None, |
309 | | - description="The molecular dynamics ensemble used to generate this structure.", |
310 | | - ) |
311 | | - md_temperature: float | None = Field( |
312 | | - None, |
313 | | - description="If a float, the temperature in Kelvin at which MLMD was performed.", |
314 | | - ) |
315 | | - md_pressure: float | None = Field( |
316 | | - None, |
317 | | - description="If a float, the pressure in atmosphere at which MLMD was performed.", |
318 | | - ) |
319 | | - md_step: int | None = Field( |
320 | | - None, |
321 | | - description="The step in the MD simulation from which the structure was sampled.", |
322 | | - ) |
323 | | - mlip_name: str | None = Field( |
324 | | - None, description="The name of the ML potential used to perform MLMD." |
325 | | - ) |
326 | | - |
327 | | - |
328 | | -class MatPESTrainDoc(MLTrainDoc): |
329 | | - """ |
330 | | - Schema for VASP data in the Materials Potential Energy Surface (MatPES) effort. |
331 | | -
|
332 | | - This schema is used in the data entries for MatPES v2025.1, |
333 | | - which can be downloaded either: |
334 | | - - On [MPContribs](https://materialsproject-contribs.s3.amazonaws.com/index.html#MatPES_2025_1/) |
335 | | - - or on [the site] |
336 | | - """ |
337 | | - |
338 | | - matpes_id: str | None = Field(None, description="MatPES identifier.") |
339 | | - |
340 | | - formation_energy_per_atom: float | None = Field( |
341 | | - None, |
342 | | - description="The uncorrected formation enthalpy per atom at zero pressure and temperature.", |
343 | | - ) |
344 | | - cohesive_energy_per_atom: float | None = Field( |
345 | | - None, description="The uncorrected cohesive energy per atom." |
346 | | - ) |
347 | | - |
348 | | - provenance: MatPESProvenanceDoc | None = Field( |
349 | | - None, description="Information about the provenance of the structure." |
350 | | - ) |
351 | | - |
352 | | - @property |
353 | | - def pressure(self) -> float | None: |
354 | | - """Return the pressure from the DFT stress tensor.""" |
355 | | - return sum(self.stress[:3]) / 3.0 if self.stress else None |
356 | | - |
357 | | - @property |
358 | | - def magmoms(self) -> Sequence[float] | None: |
359 | | - """Retrieve on-site magnetic moments from the structure if they exist.""" |
360 | | - magmom = ( |
361 | | - self.structure.site_properties.get("magmom") if self.structure else None |
362 | | - ) |
363 | | - return magmom |
0 commit comments