Skip to content

Commit 12ceeb1

Browse files
authored
Fix: Support 3rd and higher order elastic tensor in schema
1 parent 34dcf1c commit 12ceeb1

File tree

3 files changed

+76
-38
lines changed

3 files changed

+76
-38
lines changed

src/atomate2/common/schemas/elastic.py

Lines changed: 39 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from copy import deepcopy
44
from enum import Enum
5-
from typing import Optional
65

76
import numpy as np
87
from emmet.core.math import Matrix3D, MatrixVoigt
@@ -21,66 +20,67 @@
2120
from typing_extensions import Self
2221

2322
from atomate2 import SETTINGS
23+
from atomate2.common.utils import _recursive_to_list
2424

2525

2626
class DerivedProperties(BaseModel):
2727
"""Properties derived from an elastic tensor."""
2828

29-
k_voigt: Optional[float] = Field(
29+
k_voigt: float | None = Field(
3030
None, description="Voigt average of the bulk modulus."
3131
)
32-
k_reuss: Optional[float] = Field(
32+
k_reuss: float | None = Field(
3333
None, description="Reuss average of the bulk modulus."
3434
)
35-
k_vrh: Optional[float] = Field(
35+
k_vrh: float | None = Field(
3636
None, description="Voigt-Reuss-Hill average of the bulk modulus."
3737
)
38-
g_voigt: Optional[float] = Field(
38+
g_voigt: float | None = Field(
3939
None, description="Voigt average of the shear modulus."
4040
)
41-
g_reuss: Optional[float] = Field(
41+
g_reuss: float | None = Field(
4242
None, description="Reuss average of the shear modulus."
4343
)
44-
g_vrh: Optional[float] = Field(
44+
g_vrh: float | None = Field(
4545
None, description="Voigt-Reuss-Hill average of the shear modulus."
4646
)
47-
universal_anisotropy: Optional[float] = Field(
47+
universal_anisotropy: float | None = Field(
4848
None, description="Universal elastic anisotropy."
4949
)
50-
homogeneous_poisson: Optional[float] = Field(
50+
homogeneous_poisson: float | None = Field(
5151
None, description="Homogeneous poisson ratio."
5252
)
53-
y_mod: Optional[float] = Field(
53+
y_mod: float | None = Field(
5454
None,
5555
description="Young's modulus (SI units) from the Voight-Reuss-Hill averages of "
5656
"the bulk and shear moduli.",
5757
)
58-
trans_v: Optional[float] = Field(
58+
trans_v: float | None = Field(
5959
None,
6060
description="Transverse sound velocity (SI units) obtained from the "
6161
"Voigt-Reuss-Hill average bulk modulus.",
6262
)
63-
long_v: Optional[float] = Field(
63+
long_v: float | None = Field(
6464
None,
6565
description="Longitudinal sound velocity (SI units) obtained from the "
6666
"Voigt-Reuss-Hill average bulk modulus.",
6767
)
68-
snyder_ac: Optional[float] = Field(
68+
snyder_ac: float | None = Field(
6969
None, description="Synder's acoustic sound velocity (SI units)."
7070
)
71-
snyder_opt: Optional[float] = Field(
71+
snyder_opt: float | None = Field(
7272
None, description="Synder's optical sound velocity (SI units)."
7373
)
74-
snyder_total: Optional[float] = Field(
74+
snyder_total: float | None = Field(
7575
None, description="Synder's total sound velocity (SI units)."
7676
)
77-
clark_thermalcond: Optional[float] = Field(
77+
clark_thermalcond: float | None = Field(
7878
None, description="Clarke's thermal conductivity (SI units)."
7979
)
80-
cahill_thermalcond: Optional[float] = Field(
80+
cahill_thermalcond: float | None = Field(
8181
None, description="Cahill's thermal conductivity (SI units)."
8282
)
83-
debye_temperature: Optional[float] = Field(
83+
debye_temperature: float | None = Field(
8484
None,
8585
description="Debye temperature from longitudinal and transverse sound "
8686
"velocities (SI units).",
@@ -90,34 +90,34 @@ class DerivedProperties(BaseModel):
9090
class FittingData(BaseModel):
9191
"""Data used to fit elastic tensors."""
9292

93-
cauchy_stresses: Optional[list[Matrix3D]] = Field(
93+
cauchy_stresses: list[Matrix3D] | None = Field(
9494
None, description="The Cauchy stresses used to fit the elastic tensor."
9595
)
96-
strains: Optional[list[Matrix3D]] = Field(
96+
strains: list[Matrix3D] | None = Field(
9797
None, description="The strains used to fit the elastic tensor."
9898
)
99-
pk_stresses: Optional[list[Matrix3D]] = Field(
99+
pk_stresses: list[Matrix3D] | None = Field(
100100
None, description="The Piola-Kirchoff stresses used to fit the elastic tensor."
101101
)
102-
deformations: Optional[list[Matrix3D]] = Field(
102+
deformations: list[Matrix3D] | None = Field(
103103
None, description="The deformations corresponding to each strain state."
104104
)
105-
uuids: Optional[list[str]] = Field(
105+
uuids: list[str] | None = Field(
106106
None, description="The uuids of the deformation jobs."
107107
)
108-
job_dirs: Optional[list[Optional[str]]] = Field(
108+
job_dirs: list[str | None] | None = Field(
109109
None, description="The directories where the deformation jobs were run."
110110
)
111-
failed_uuids: Optional[list[str]] = Field(
111+
failed_uuids: list[str] | None = Field(
112112
None, description="The uuids of perturbations that were not completed"
113113
)
114114

115115

116116
class ElasticTensorDocument(BaseModel):
117117
"""Raw and standardized elastic tensors."""
118118

119-
raw: Optional[MatrixVoigt] = Field(None, description="Raw elastic tensor.")
120-
ieee_format: Optional[MatrixVoigt] = Field(
119+
raw: MatrixVoigt | list | None = Field(None, description="Raw elastic tensor.")
120+
ieee_format: MatrixVoigt | list | None = Field(
121121
None, description="Elastic tensor in IEEE format."
122122
)
123123

@@ -131,28 +131,28 @@ class ElasticWarnings(Enum):
131131
class ElasticDocument(StructureMetadata):
132132
"""Document containing elastic tensor information and related properties."""
133133

134-
structure: Optional[Structure] = Field(
134+
structure: Structure | None = Field(
135135
None, description="The structure for which the elastic data is calculated."
136136
)
137-
elastic_tensor: Optional[ElasticTensorDocument] = Field(
137+
elastic_tensor: ElasticTensorDocument | None = Field(
138138
None, description="Fitted elastic tensor."
139139
)
140-
eq_stress: Optional[Matrix3D] = Field(
140+
eq_stress: Matrix3D | None = Field(
141141
None, description="The equilibrium stress of the structure."
142142
)
143-
derived_properties: Optional[DerivedProperties] = Field(
143+
derived_properties: DerivedProperties | None = Field(
144144
None, description="Properties derived from the elastic tensor."
145145
)
146-
fitting_data: Optional[FittingData] = Field(
146+
fitting_data: FittingData | None = Field(
147147
None, description="Data used to fit the elastic tensor."
148148
)
149-
fitting_method: Optional[str] = Field(
149+
fitting_method: str | None = Field(
150150
None, description="Method used to fit the elastic tensor."
151151
)
152-
order: Optional[int] = Field(
152+
order: int | None = Field(
153153
None, description="Order of the expansion of the elastic tensor."
154154
)
155-
warnings: Optional[list[str]] = Field(None, description="Warnings.")
155+
warnings: list[str] | None = Field(None, description="Warnings.")
156156

157157
@classmethod
158158
def from_stresses(
@@ -163,8 +163,8 @@ def from_stresses(
163163
uuids: list[str],
164164
job_dirs: list[str],
165165
fitting_method: str = SETTINGS.ELASTIC_FITTING_METHOD,
166-
order: Optional[int] = None,
167-
equilibrium_stress: Optional[Matrix3D] = None,
166+
order: int | None = None,
167+
equilibrium_stress: Matrix3D | None = None,
168168
symprec: float = SETTINGS.SYMPREC,
169169
allow_elastically_unstable_structs: bool = True,
170170
failed_uuids: list[str] = None,
@@ -264,7 +264,8 @@ def from_stresses(
264264
fitting_method=fitting_method,
265265
order=order,
266266
elastic_tensor=ElasticTensorDocument(
267-
raw=result.voigt.tolist(), ieee_format=ieee.voigt.tolist()
267+
raw=_recursive_to_list(result.voigt),
268+
ieee_format=_recursive_to_list(ieee.voigt),
268269
),
269270
fitting_data=FittingData(
270271
cauchy_stresses=[s.tolist() for s in stresses],

src/atomate2/common/utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,3 +209,15 @@ def _recursive_get_dir_names(jobs: list, dir_names: list) -> None:
209209
_recursive_get_dir_names(sub_jobs, dir_names)
210210
else:
211211
dir_names.append(a_job.output.dir_name)
212+
213+
214+
def _recursive_to_list(voigt_data: Any) -> Any:
215+
"""Recursively convert tensor-like data to nested lists.
216+
217+
Useful for converting numpy or torch arrays to lists.
218+
"""
219+
if isinstance(voigt_data, list):
220+
return [_recursive_to_list(item) for item in voigt_data]
221+
if hasattr(voigt_data, "tolist"):
222+
return voigt_data.tolist()
223+
return voigt_data

tests/common/test_utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
"""Test utility functions.
2+
3+
TODO: there should be tests for the other utility functions.
4+
Not high priority but good for long term.
5+
"""
6+
7+
import numpy as np
8+
9+
from atomate2.common.utils import _recursive_to_list
10+
11+
try:
12+
import torch
13+
except ImportError:
14+
torch = None
15+
16+
17+
def test_to_list():
18+
as_list = [[1, 2, 3, 4, 5], [10, 9, 8, 7, 6], [3, 5, 7, 9, 11]]
19+
arr = np.array(as_list)
20+
21+
assert _recursive_to_list(arr) == as_list
22+
if torch is not None:
23+
assert _recursive_to_list(torch.from_numpy(arr)) == as_list
24+
for obj in (None, 1, 1.5):
25+
assert _recursive_to_list(obj) == obj

0 commit comments

Comments
 (0)