Skip to content

Commit b9d4aac

Browse files
xrd: Re-introduced label tensorization
1 parent f0d98a7 commit b9d4aac

File tree

4 files changed

+104
-83
lines changed

4 files changed

+104
-83
lines changed

tests/t_labels.py

Lines changed: 52 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,89 +1,66 @@
11
from holytools.devtools import Unittest
2-
from xrdpattern.xrd import PowderExperiment
2+
from pymatgen.core import Lattice
3+
4+
from xrdpattern.crystal import CrystalStructure, CrystalBasis
5+
from xrdpattern.xrd import PowderExperiment, XrayInfo
6+
from xrdpattern.xrd.experiment import ExperimentTensor
37

48

59
class TestPowderExperiment(Unittest):
6-
def test_empty(self):
10+
def test_is_empty(self):
11+
12+
# empty is empty
713
empty_experiment = PowderExperiment.make_empty()
814
self.assertTrue(not empty_experiment.is_nonempty())
15+
16+
#nonempty is nonempty
17+
918

19+
def test_has_label(self):
20+
pass
21+
1022

11-
class TestTensorRegions(Unittest):
23+
class TestTensorization(Unittest):
1224
def setUp(self):
13-
raise ValueError(f'Tensor regions currently broken! Restore tensorizaation + tests')
14-
# self.label : PowderExperiment = self.make_example_label()
15-
# self.label_tensor : LabelTensor = self.label.to_tensor()
16-
# self.crystal_structure : CrystalPhase = self.label.powder.phases
25+
self.label : PowderExperiment = self.make_example_label()
26+
self.crystal_structure: CrystalStructure = self.label.phases[0]
27+
self.experiment_tensor : ExperimentTensor = self.label.to_tensordict()
28+
29+
def test_lattice_params(self):
30+
expected = (*self.crystal_structure.lengths, *self.crystal_structure.angles)
31+
actual = self.experiment_tensor.get_lattice_params().tolist()
32+
print(f'Tensor lattice params are {actual}')
33+
for x,y in zip(expected, actual):
34+
self.assertEqual(x, y)
35+
36+
def test_spacegroups(self):
37+
expected = [1.0 if j == self.crystal_structure.spacegroup else 0.0 for j in range(1,231)]
38+
actual = self.experiment_tensor.get_spg_probabilities().tolist()
39+
print(f'Tensor spacegroups probabilities are {actual}; \nSpacegroups tensor length = {len(actual)}')
40+
self.assertEqual(actual, expected, f'Expected: {expected}, Actual: {actual}')
41+
self.assertEqual(len(actual), 230)
42+
43+
def test_wavelength(self):
44+
expected = self.label.xray_info.primary_wavelength
45+
actual = self.experiment_tensor.get_primary_wavelength()
46+
print(f'Tensor primary wavelength is {actual}')
47+
self.assertAlmostEqual(actual, expected)
1748

49+
expected_secondary = self.label.xray_info.secondary_wavelength
50+
actual_secondary = self.experiment_tensor.get_secondary_wavelength()
51+
print(f'Tensor secondary wavelength is {actual_secondary}')
52+
self.assertAlmostEqual(actual_secondary, expected_secondary)
1853

19-
#
20-
# def test_lattice_params(self):
21-
# expected = (*self.crystal_structure.lengths, *self.crystal_structure.angles)
22-
# actual = self.label_tensor.get_lattice_params().tolist()
23-
# print(f'Tensor lattice params are {actual}')
24-
# for x,y in zip(expected, actual):
25-
# self.assertEqual(x, y)
26-
#
27-
#
28-
# def test_atomic_sites(self):
29-
# base = self.crystal_structure.base
30-
# for i, region in enumerate(self.label.atomic_site_regions):
31-
# expected_site = base[i] if i < len(base) else AtomicSite.make_void()
32-
# expected = expected_site.as_list()
33-
# actual = self.label_tensor.get_atomic_site(i).tolist()
34-
# if i == 0:
35-
# print(f'Tensor atomic site is {actual}\n Atomic site length = {len(actual)}')
36-
# # 3 coordinates, 8 scattering params, 1 occupancy
37-
# self.assertEqual(len(actual), 3+8+1)
38-
# for x,y, in zip(expected, actual):
39-
# if x is None:
40-
# self.assertTrue(is_nan(y))
41-
# else:
42-
# self.assertEqual(x,y)
43-
#
44-
# def test_spacegroups(self):
45-
# expected = [1.0 if j == self.crystal_structure.spacegroup else 0.0 for j in range(1,231)]
46-
# actual = self.label_tensor.get_spg_probabilities().tolist()
47-
# print(f'Tensor spacegroups are {actual}; \nSpacegroups tensor length = {len(actual)}')
48-
# self.assertEqual(actual, expected, f'Expected: {expected}, Actual: {actual}')
49-
# self.assertEqual(len(actual), 230)
50-
#
51-
#
52-
# def test_artifacts(self):
53-
# expected = [round(num, 2) for num in self.label.artifacts.as_list()]
54-
# actual = [round(num, 2) for num in self.label_tensor.get_artifacts().tolist()]
55-
# print(f'Tensor artifacts are {actual}')
56-
# self.assertEqual(actual, expected)
57-
#
58-
# def test_total_length(self):
59-
# expected = len(self.label.list_repr)
60-
# # noinspection PyTypeChecker
61-
# actual = len(self.label_tensor)
62-
# print(f'Tensor length is {actual}')
63-
# self.assertEqual(actual, expected)
64-
#
65-
#
66-
# def test_returns_powder_tensors(self):
67-
# tensors = [self.label_tensor.get_atomic_site(0), self.label_tensor.get_lattice_params(), self.label_tensor.get_lattice_params()]
68-
# for t in tensors:
69-
# print(f'Type of {t} is {type(t)}')
70-
# self.assertEqual(t.__class__, LabelTensor)
71-
#
72-
# @staticmethod
73-
# def make_example_label() -> PowderExperiment:
74-
# primitives = Lengths(a=3, b=3, c=3)
75-
# angles = Angles(alpha=90, beta=90, gamma=90)
76-
# base = CrystalBase([AtomicSite.make_void()])
77-
# crystal_structure = CrystalPhase(lengths=primitives, angles=angles, base=base)
78-
# crystal_structure.spacegroup = 120
79-
#
80-
# powder = PowderSample(phases=[crystal_structure], crystallite_size=10.0)
81-
# artifacts = XRayInfo(primary_wavelength=1.54, secondary_wavelength=1.54)
82-
# return PowderExperiment(powder, artifacts, is_simulated=True)
83-
#
54+
@staticmethod
55+
def make_example_label() -> PowderExperiment:
56+
lattice = Lattice.from_parameters(3,3,3,90,90,90)
57+
basis = CrystalBasis.empty()
58+
crystal_structure = CrystalStructure(lattice=lattice, basis=basis)
59+
crystal_structure.spacegroup = 120
8460

85-
def is_nan(value):
86-
return value != value
61+
xray_info = XrayInfo.copper_xray()
62+
return PowderExperiment(phases=[crystal_structure], xray_info=xray_info, crystallite_size_nm=10, temp_K=300)
8763

8864
if __name__ == "__main__":
89-
TestPowderExperiment.execute_all()
65+
# TestPowderExperiment.execute_all()
66+
TestTensorization.execute_all()

xrdpattern/parsing/master.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def get_float(key: str) -> Optional[float]:
122122

123123
experiment = PowderExperiment.make_empty()
124124
experiment.xray_info = XrayInfo(primary_wavelength=get_float('ALPHA1'), secondary_wavelength=get_float('ALPHA2'))
125-
experiment.temp_in_celcius = get_float('TEMP_CELCIUS')
125+
experiment.temp_K = get_float('TEMP_CELCIUS')
126126

127127
return experiment
128128

xrdpattern/pattern/pattern.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def get_info_as_str(self) -> str:
148148
f'- Crystal: {crystal_data} \n'
149149
f'- Experiment Parameters:\n'
150150
f' - Primary wavelength: {self.powder_experiment.xray_info.primary_wavelength}\n'
151-
f' - Temperature : {self.powder_experiment.temp_in_celcius}\n'
151+
f' - Temperature : {self.powder_experiment.temp_K}\n'
152152
f'- Origin Metadata:\n'
153153
f' - Contributor: {self.metadata.contributor_name}\n'
154154
f' - Institution: {self.metadata.institution}\n'

xrdpattern/xrd/experiment.py

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
from importlib.metadata import version
66
from typing import Optional
77

8+
import torch
9+
from tensordict import TensorDict
10+
811
from xrdpattern.crystal import CrystalStructure
912
from xrdpattern.serialization import JsonDataclass
1013
from xrdpattern.xrd.xray import XrayInfo
@@ -19,8 +22,8 @@ class PowderExperiment(JsonDataclass):
1922
phases: list[CrystalStructure]
2023
xray_info : XrayInfo
2124
is_simulated : bool = False
22-
crystallite_size: Optional[float] = None
23-
temp_in_celcius: Optional[float] = None
25+
crystallite_size_nm: Optional[float] = None
26+
temp_K: Optional[float] = None
2427

2528
def __post_init__(self):
2629
if len(self.phases) == 1:
@@ -33,12 +36,12 @@ def make_empty(cls, is_simulated: bool = False) -> PowderExperiment:
3336

3437
@classmethod
3538
def from_multi_phase(cls, phases : list[CrystalStructure]):
36-
return cls(phases=phases, crystallite_size=None, xray_info=XrayInfo.mk_empty(), is_simulated=False)
39+
return cls(phases=phases, crystallite_size_nm=None, xray_info=XrayInfo.mk_empty(), is_simulated=False)
3740

3841
@classmethod
3942
def from_single_phase(cls, phase : CrystalStructure, crystallite_size : Optional[float] = None, is_simulated : bool = False):
4043
artifacts = XrayInfo.mk_empty()
41-
return cls(phases=[phase], crystallite_size=crystallite_size, xray_info=artifacts, is_simulated=is_simulated)
44+
return cls(phases=[phase], crystallite_size_nm=crystallite_size, xray_info=artifacts, is_simulated=is_simulated)
4245

4346
@classmethod
4447
def from_cif(cls, cif_content : str) -> PowderExperiment:
@@ -83,8 +86,28 @@ def is_nonempty(self) -> bool:
8386
def __eq__(self, other : PowderExperiment):
8487
return self.to_str() == other.to_str()
8588

86-
# def to_tensor(self, dtype : torch.dtype = torch.get_default_dtype(), device : torch.device = torch.get_default_device()) -> LabelTensor:
87-
# return LabelTensor(tensor)
89+
def to_tensordict(self, dtype : torch.dtype = torch.get_default_dtype(),
90+
device : torch.device = torch.get_default_device()) -> ExperimentTensor:
91+
92+
def to_tensor(data):
93+
return torch.tensor(data=data, dtype=dtype, device=device)
94+
95+
if len(self.phases) == 0:
96+
raise ValueError('No phases in the experiment. Cannot convert to TensorDict.')
97+
98+
spg_list = [self.phases[0].spacegroup == j for j in range(1,NUM_SPACEGROUPS+1)]
99+
feature_dict = {
100+
'lengths': to_tensor(self.phases[0].lengths),
101+
'angles' : to_tensor(self.phases[0].angles),
102+
'spg_probabilities' : to_tensor(spg_list),
103+
'crystallite_size' : to_tensor(self.crystallite_size_nm) if self.crystallite_size_nm else None,
104+
'temperature' : to_tensor(self.temp_K) if self.temp_K else None,
105+
'primary_wavelength' : to_tensor(self.xray_info.primary_wavelength) if self.xray_info.primary_wavelength else None,
106+
'secondary_wavelength' : to_tensor(self.xray_info.secondary_wavelength) if self.xray_info.secondary_wavelength else None,
107+
}
108+
td = ExperimentTensor(feature_dict)
109+
return td
110+
88111

89112
@dataclass
90113
class Metadata(JsonDataclass):
@@ -104,5 +127,26 @@ def remove_filename(self):
104127
self.filename = None
105128

106129

130+
class ExperimentTensor(TensorDict):
131+
def get_lattice_params(self):
132+
lengths, angles = self.get('lengths'), self.get('angles')
133+
return torch.cat([lengths, angles], dim=0)
134+
135+
def get_spg_probabilities(self):
136+
return self.get('spg_probabilities')
137+
138+
def get_crystallite_size(self):
139+
return self.get('crystallite_size')
140+
141+
def get_ambient_temperature(self):
142+
return self.get('temperature')
143+
144+
def get_primary_wavelength(self):
145+
return self.get('primary_wavelength')
146+
147+
def get_secondary_wavelength(self):
148+
return self.get('secondary_wavelength')
149+
150+
107151
def get_library_version(library_name : str):
108152
return version(library_name)

0 commit comments

Comments
 (0)