Skip to content

Commit 28e9ade

Browse files
crystal: Re-introduced phase fractions attribute
1 parent 1c17f49 commit 28e9ade

File tree

8 files changed

+132
-125
lines changed

8 files changed

+132
-125
lines changed

exports.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
@classmethod
2+
def make_empty(cls, is_simulated: bool = False, num_phases: int = 1) -> PowderExperiment:
3+
phases = []
4+
for j in range(num_phases):
5+
lengths = (float('nan'), float('nan'), float('nan'))
6+
angles = (float('nan'), float('nan'), float('nan'))
7+
base = CrystalBasis.empty()
8+
9+
p = CrystalStructure(lengths=lengths, angles=angles, basis=base)
10+
phases.append(p)
11+
12+
xray_info = XrayInfo.mk_empty()
13+
return cls(phases=phases, crystallite_size=None, temp_in_celcius=None, xray_info=xray_info,
14+
is_simulated=is_simulated)
15+
16+
17+
def get_list_repr(self) -> list:
18+
list_repr = []
19+
structure = self.phases[0]
20+
21+
a, b, c = structure.lengths
22+
alpha, beta, gamma = structure.angles
23+
lattice_params = [a, b, c, alpha, beta, gamma]
24+
list_repr += lattice_params
25+
26+
base = structure.basis
27+
padded_base = self.get_padded_base(base=base, nan_padding=base.is_empty())
28+
for atomic_site in padded_base:
29+
list_repr += atomic_site.as_list()
30+
31+
if structure.spacegroup is None:
32+
spg_logits_list = [float('nan') for _ in range(NUM_SPACEGROUPS)]
33+
else:
34+
spg_logits_list = [1000 if j + 1 == structure.spacegroup else 0 for j in range(NUM_SPACEGROUPS)]
35+
list_repr += spg_logits_list
36+
37+
list_repr += self.xray_info.as_list()
38+
list_repr += [self.is_simulated]
39+
40+
return list_repr
41+
42+
@staticmethod
43+
def get_padded_base(base: CrystalBasis, nan_padding : bool) -> CrystalBasis:
44+
def make_padding_site():
45+
if nan_padding:
46+
site = AtomSite.make_placeholder()
47+
else:
48+
site = AtomSite.make_void()
49+
return site
50+
51+
delta = MAX_ATOMIC_SITES - len(base)
52+
if delta < 0:
53+
raise ValueError(f'Base is too large! Size = {len(base)} exceeds MAX_ATOMIC_SITES = {MAX_ATOMIC_SITES}')
54+
55+
padded_base = base + [make_padding_site() for _ in range(delta)]
56+
return padded_base

tests/t_crystal/t_properties.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from pymatgen.core import Lattice
2+
13
import tests.t_crystal.base_crystal as BaseTest
24
from xrdpattern.crystal import CrystalStructure, CrystalBasis
35

@@ -8,15 +10,15 @@ class TestCifParsing(BaseTest.CrystalTest):
810
def test_lattice_parameters(self):
911
expected_lengths = [(5.801, 11.272, 5.57), (4.0809,4.0809,4.0809)]
1012
for crystal, (a_exp, b_exp, c_exp) in zip(self.custom_structures, expected_lengths):
11-
a, b, c = crystal.lengths
13+
a, b, c = crystal.lattice.lengths
1214
self.assertAlmostEqual(a, a_exp, places=3)
1315
self.assertAlmostEqual(b, b_exp, places=4)
1416
self.assertAlmostEqual(c, c_exp, places=3)
1517

1618
def test_angles(self):
1719
expected_angles = [(90, 90, 90), (89.676, 89.676, 89.676)]
1820
for crystal, (alpha_exp, beta_exp, gamma_exp) in zip(self.custom_structures, expected_angles):
19-
alpha, beta, gamma = crystal.angles
21+
alpha, beta, gamma = crystal.lattice.angles
2022
self.assertEqual(alpha, alpha_exp)
2123
self.assertAlmostEqual(beta, beta_exp, places=3)
2224
self.assertEqual(gamma, gamma_exp)
@@ -34,7 +36,8 @@ def test_to_cif(self):
3436
print(f'CIF = \n{cif}')
3537

3638
def test_standardize(self):
37-
phase = CrystalStructure(lengths=(5.801, 11.272, 5.57), angles=(90, 90, 90), basis=CrystalBasis.empty())
39+
lattice = Lattice.from_parameters(5.801, 11.272, 5.57, 90, 90, 90)
40+
phase = CrystalStructure(lattice=lattice, basis=CrystalBasis.empty())
3841
new_phase = phase.get_standardized()
3942

4043
self.assertTrue(len(new_phase.basis) == 0)

xrdpattern/crystal/components/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from __future__ import annotations
22

3-
import math
43
import json
5-
from typing import Optional, Iterable
6-
from xrdpattern.serialization import Serializable
4+
import math
5+
from typing import Iterable
76

7+
from xrdpattern.serialization import Serializable
88
from .atomic_site import AtomSite
99

1010

xrdpattern/crystal/components/crystal.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,11 @@
88
from distlib.util import cached_property
99
from pymatgen.core import Structure, Lattice, Species, Element
1010
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
11-
from pymatgen.symmetry.groups import SpaceGroup
1211

1312
from xrdpattern.serialization import JsonDataclass
1413
from .atomic_site import AtomSite
1514
from .base import CrystalBasis
1615

17-
1816
logger = logging.getLogger(__name__)
1917
CrystalSystem = Literal["cubic", "hexagonal", "monoclinic", "orthorhombic", "tetragonal", "triclinic", "trigonal"]
2018
# ---------------------------------------------------------
@@ -26,6 +24,11 @@ class CrystalStructure(JsonDataclass):
2624
spacegroup : Optional[int] = None
2725
chemical_composition : Optional[str] = None
2826
wyckoff_symbols : Optional[list[str]] = None
27+
phase_fraction: Optional[float] = None
28+
29+
def __post_init__(self):
30+
if not 0 <= self.phase_fraction <= 1:
31+
raise ValueError(f'Phase fraction must be between 0 and 1. Got {self.phase_fraction}')
2932

3033
@classmethod
3134
def from_cif(cls, cif_content : str) -> CrystalStructure:
@@ -137,4 +140,12 @@ def _spg_to_crystal_system(spg: int) -> CrystalSystem:
137140
return "trigonal"
138141
if spg <= 194:
139142
return "hexagonal"
140-
return "cubic"
143+
return "cubic"
144+
145+
@property
146+
def angles(self) -> tuple[float, float, float]:
147+
return self.lattice.angles
148+
149+
@property
150+
def lengths(self) -> tuple[float, float, float]:
151+
return self.lattice.lengths

xrdpattern/pattern/pattern.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,7 @@ def __eq__(self, other : XrdPattern):
103103
if isinstance(v1, np.ndarray):
104104
is_ok = np.array_equal(v1, v2)
105105
elif isinstance(v1, PowderExperiment):
106-
objs_equal = [str(x)==str(y) for x,y in zip(v1.get_list_repr(), v2.get_list_repr())]
107-
is_ok = all(objs_equal)
106+
is_ok = v1
108107
else:
109108
is_ok = v1 == v2
110109

xrdpattern/xrd/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def has_label(self, label_type: LabelType) -> bool:
8686
if label_type == LabelType.composition:
8787
return self.primary_phase.chemical_composition is not None
8888
if label_type == LabelType.lattice:
89-
return all(not math.isnan(x) for x in self.primary_phase.lengths) and all(not math.isnan(x) for x in self.primary_phase.angles)
89+
return
9090
if label_type == LabelType.atom_coords:
9191
return len(self.primary_phase.basis) > 0
9292
if label_type == LabelType.spg:

xrdpattern/xrd/experiment.py

Lines changed: 9 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,12 @@
22

33
import math
44
from dataclasses import dataclass, field
5-
from enum import Enum
65
from importlib.metadata import version
76
from typing import Optional
87

9-
import torch
10-
11-
from xrdpattern.crystal import CrystalStructure, CrystalBasis, AtomSite
8+
from xrdpattern.crystal import CrystalStructure
129
from xrdpattern.serialization import JsonDataclass
13-
from xrdpattern.xrd.tensorization import LabelTensor
10+
from xrdpattern.xrd.xray import XrayInfo
1411

1512
NUM_SPACEGROUPS = 230
1613
MAX_ATOMIC_SITES = 100
@@ -26,25 +23,13 @@ class PowderExperiment(JsonDataclass):
2623
temp_in_celcius: Optional[float] = None
2724

2825
def __post_init__(self):
29-
if len(self.phases) == 0:
30-
raise ValueError(f'Material must have at least one phase! Got {len(self.phases)}')
31-
3226
if len(self.phases) == 1:
3327
self.phases[0].phase_fraction = 1
3428

3529
@classmethod
36-
def make_empty(cls, is_simulated : bool = False, num_phases : int = 1) -> PowderExperiment:
37-
phases = []
38-
for j in range(num_phases):
39-
lengths = (float('nan'),float('nan'), float('nan'))
40-
angles = (float('nan'),float('nan'), float('nan'))
41-
base = CrystalBasis.empty()
42-
43-
p = CrystalStructure(lengths=lengths, angles=angles, basis=base)
44-
phases.append(p)
45-
30+
def make_empty(cls, is_simulated: bool = False) -> PowderExperiment:
4631
xray_info = XrayInfo.mk_empty()
47-
return cls(phases=phases, crystallite_size=None, temp_in_celcius=None, xray_info=xray_info, is_simulated=is_simulated)
32+
return cls(phases=[], xray_info=xray_info, is_simulated=is_simulated)
4833

4934
@classmethod
5035
def from_multi_phase(cls, phases : list[CrystalStructure]):
@@ -93,101 +78,11 @@ def is_nonempty(self) -> bool:
9378
crystal_basis_nonempty = len(primary_phase.basis) > 0
9479
return xray_info_nonemtpy or composition_nonempty or lattice_params_nonempty or crystal_basis_nonempty
9580

96-
@property
97-
def primary_wavelength(self) -> float:
98-
return self.xray_info.primary_wavelength
99-
100-
@property
101-
def secondary_wavelength(self) -> float:
102-
return self.xray_info.secondary_wavelength
103-
104-
def get_list_repr(self) -> list:
105-
list_repr = []
106-
structure = self.phases[0]
107-
108-
a, b, c = structure.lengths
109-
alpha, beta, gamma = structure.angles
110-
lattice_params = [a, b, c, alpha, beta, gamma]
111-
list_repr += lattice_params
112-
113-
base = structure.basis
114-
padded_base = self.get_padded_base(base=base, nan_padding=base.is_empty())
115-
for atomic_site in padded_base:
116-
list_repr += atomic_site.as_list()
117-
118-
if structure.spacegroup is None:
119-
spg_logits_list = [float('nan') for _ in range(NUM_SPACEGROUPS)]
120-
else:
121-
spg_logits_list = [1000 if j + 1 == structure.spacegroup else 0 for j in range(NUM_SPACEGROUPS)]
122-
list_repr += spg_logits_list
123-
124-
list_repr += self.xray_info.as_list()
125-
list_repr += [self.is_simulated]
126-
127-
return list_repr
128-
129-
@staticmethod
130-
def get_padded_base(base: CrystalBasis, nan_padding : bool) -> CrystalBasis:
131-
def make_padding_site():
132-
if nan_padding:
133-
site = AtomSite.make_placeholder()
134-
else:
135-
site = AtomSite.make_void()
136-
return site
137-
138-
delta = MAX_ATOMIC_SITES - len(base)
139-
if delta < 0:
140-
raise ValueError(f'Base is too large! Size = {len(base)} exceeds MAX_ATOMIC_SITES = {MAX_ATOMIC_SITES}')
141-
142-
padded_base = base + [make_padding_site() for _ in range(delta)]
143-
return padded_base
144-
145-
146-
def to_tensor(self, dtype : torch.dtype = torch.get_default_dtype(), device : torch.device = torch.get_default_device()) -> LabelTensor:
147-
tensor = torch.tensor(self.get_list_repr(), dtype=dtype, device=device)
148-
return LabelTensor(tensor)
149-
150-
151-
@dataclass
152-
class XrayInfo(JsonDataclass):
153-
primary_wavelength: Optional[float]
154-
secondary_wavelength: Optional[float]
155-
156-
@classmethod
157-
def mk_empty(cls):
158-
return cls(primary_wavelength=None, secondary_wavelength=None)
159-
160-
def as_list(self) -> list[float]:
161-
return [self.primary_wavelength, self.secondary_wavelength]
162-
163-
@staticmethod
164-
def default_ratio() -> float:
165-
return 0.5
166-
167-
168-
class XrdAnode(Enum):
169-
Cu = "Cu"
170-
Mo = "Mo"
171-
Cr = "Cr"
172-
Fe = "Fe"
173-
Co = "Co"
174-
Ag = "Ag"
175-
176-
def get_wavelengths(self) -> (float, float):
177-
MATERiAL_TO_WAVELENGTHS = {
178-
"Cu": (1.54439, 1.54056),
179-
"Mo": (0.71359, 0.70930),
180-
"Cr": (2.29361, 2.28970),
181-
"Fe": (1.93998, 1.93604),
182-
"Co": (1.79285, 1.78896),
183-
"Ag": (0.563813, 0.559421),
184-
}
185-
return MATERiAL_TO_WAVELENGTHS[self.value]
186-
187-
def get_xray_info(self) -> XrayInfo:
188-
primary, secondary = self.get_wavelengths()
189-
return XrayInfo(primary_wavelength=primary, secondary_wavelength=secondary)
81+
def __eq__(self, other : PowderExperiment):
82+
return self.to_str() == other.to_str()
19083

84+
# def to_tensor(self, dtype : torch.dtype = torch.get_default_dtype(), device : torch.device = torch.get_default_device()) -> LabelTensor:
85+
# return LabelTensor(tensor)
19186

19287
@dataclass
19388
class Metadata(JsonDataclass):
@@ -206,5 +101,6 @@ def __eq__(self, other : Metadata):
206101
def remove_filename(self):
207102
self.filename = None
208103

104+
209105
def get_library_version(library_name : str):
210106
return version(library_name)

xrdpattern/xrd/xray.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from dataclasses import dataclass
2+
from typing import Optional
3+
4+
from xrdpattern.serialization import JsonDataclass
5+
6+
7+
@dataclass
8+
class XrayInfo(JsonDataclass):
9+
primary_wavelength: Optional[float]
10+
secondary_wavelength: Optional[float]
11+
12+
@classmethod
13+
def from_anode(cls, element : str):
14+
MATERIAL_TO_WAVELNGTHS = {
15+
"Cu": (1.54439, 1.54056),
16+
"Mo": (0.71359, 0.70930),
17+
"Cr": (2.29361, 2.28970),
18+
"Fe": (1.93998, 1.93604),
19+
"Co": (1.79285, 1.78896),
20+
"Ag": (0.563813, 0.559421),
21+
}
22+
return cls(*MATERIAL_TO_WAVELNGTHS[element])
23+
24+
@classmethod
25+
def mk_empty(cls):
26+
return cls(primary_wavelength=None, secondary_wavelength=None)
27+
28+
def as_list(self) -> list[float]:
29+
return [self.primary_wavelength, self.secondary_wavelength]
30+
31+
@staticmethod
32+
def default_ratio() -> float:
33+
return 0.5
34+
35+
36+
class XrdAnode:
37+
Cu = "Cu"
38+
Mo = "Mo"
39+
Cr = "Cr"
40+
Fe = "Fe"
41+
Co = "Co"
42+
Ag = "Ag"

0 commit comments

Comments
 (0)