Skip to content

Commit bea4b79

Browse files
xrd: Re-introduced has_label/is_labeled in PowderExperiment
1 parent b9d4aac commit bea4b79

File tree

6 files changed

+107
-122
lines changed

6 files changed

+107
-122
lines changed

tests/t_labels.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,28 @@
33

44
from xrdpattern.crystal import CrystalStructure, CrystalBasis
55
from xrdpattern.xrd import PowderExperiment, XrayInfo
6-
from xrdpattern.xrd.experiment import ExperimentTensor
7-
6+
from xrdpattern.xrd.experiment import ExperimentTensor, LabelType
7+
from xrdpattern.crystal.examples import CrystalExamples
88

99
class TestPowderExperiment(Unittest):
10+
def setUp(self):
11+
cif_content = CrystalExamples.get_cif_content(num=1)
12+
self.empty_experiment : PowderExperiment = PowderExperiment.make_empty()
13+
self.full_experiment : PowderExperiment = PowderExperiment.from_cif(cif_content=cif_content)
14+
self.full_experiment.primary_phase.calculate_properties()
15+
1016
def test_is_empty(self):
11-
12-
# empty is empty
1317
empty_experiment = PowderExperiment.make_empty()
14-
self.assertTrue(not empty_experiment.is_nonempty())
15-
16-
#nonempty is nonempty
17-
18+
self.assertTrue(not empty_experiment.is_labeled())
19+
self.assertTrue(self.full_experiment.is_labeled())
1820

1921
def test_has_label(self):
20-
pass
21-
22+
self.assertTrue(not self.empty_experiment.has_label(label_type=LabelType.lattice))
23+
24+
self.assertTrue(self.full_experiment.has_label(label_type=LabelType.lattice))
25+
self.assertTrue(self.full_experiment.has_label(label_type=LabelType.atom_coords))
26+
self.assertTrue(self.full_experiment.has_label(label_type=LabelType.spg))
27+
2228

2329
class TestTensorization(Unittest):
2430
def setUp(self):
@@ -62,5 +68,5 @@ def make_example_label() -> PowderExperiment:
6268
return PowderExperiment(phases=[crystal_structure], xray_info=xray_info, crystallite_size_nm=10, temp_K=300)
6369

6470
if __name__ == "__main__":
65-
# TestPowderExperiment.execute_all()
66-
TestTensorization.execute_all()
71+
TestPowderExperiment.execute_all()
72+
# TestTensorization.execute_all()

xrdpattern/xrd/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from .experiment import XrayInfo, Metadata, PowderExperiment
2-
from .data import XrdData, LabelType
1+
from .experiment import XrayInfo, PowderExperiment, LabelType
2+
from .data import XrdData
33
from .xray import XrdAnode, XrayInfo

xrdpattern/xrd/data.py

Lines changed: 3 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
from __future__ import annotations
22

33
import json
4-
from dataclasses import dataclass, fields, field
5-
from enum import Enum
4+
from dataclasses import dataclass, field
5+
from dataclasses import fields
66

77
import numpy as np
88
from numpy.typing import NDArray
99
from orjson import orjson
1010

1111
from xrdpattern.crystal import CrystalStructure
1212
from xrdpattern.serialization import Serializable
13-
from xrdpattern.xrd import Metadata
1413
from xrdpattern.xrd.experiment import PowderExperiment
14+
from xrdpattern.xrd.metadata import Metadata
1515

1616

1717
# -------------------------------------------
@@ -82,25 +82,6 @@ def get_name(self) -> str:
8282
def get_phase(self, phase_num : int) -> CrystalStructure:
8383
return self.powder_experiment.phases[phase_num]
8484

85-
86-
def has_label(self, label_type: LabelType) -> bool:
87-
raise NotImplementedError
88-
# if label_type == LabelType.composition:
89-
# return self.primary_phase.chemical_composition is not None
90-
# if label_type == LabelType.lattice:
91-
# return
92-
# if label_type == LabelType.atom_coords:
93-
# return len(self.primary_phase.basis) > 0
94-
# if label_type == LabelType.spg:
95-
# spg_explicit = self.primary_phase.spacegroup is not None
96-
# spg_implicit = self.has_label(label_type=LabelType.lattice) and self.has_label(
97-
# label_type=LabelType.atom_coords)
98-
# return spg_explicit or spg_implicit
99-
# return False
100-
101-
def is_labeled(self) -> bool:
102-
return any(self.has_label(label_type=lt) for lt in LabelType)
103-
10485
@property
10586
def num_entries(self) -> int:
10687
return len(self.two_theta_values)
@@ -124,18 +105,3 @@ def angular_resolution(self):
124105
@property
125106
def is_simulated(self) -> bool:
126107
return self.powder_experiment.is_simulated
127-
128-
@property
129-
def composition(self) -> str:
130-
comp = ''
131-
for phase in self.powder_experiment.phases:
132-
comp += phase.chemical_composition
133-
return comp
134-
135-
136-
137-
class LabelType(Enum):
138-
spg = "spg"
139-
lattice = "lattice"
140-
atom_coords = "atom_coords"
141-
composition = "composition"

xrdpattern/xrd/experiment.py

Lines changed: 55 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from __future__ import annotations
22

33
import math
4-
from dataclasses import dataclass, field
5-
from importlib.metadata import version
4+
from dataclasses import dataclass
5+
from enum import Enum
66
from typing import Optional
77

88
import torch
@@ -69,19 +69,36 @@ def from_cif(cls, cif_content : str) -> PowderExperiment:
6969
# ---------------------------------------------------------
7070
# properties
7171

72-
def is_nonempty(self) -> bool:
73-
xray_info_nonemtpy = not self.xray_info.primary_wavelength is None or not self.xray_info.secondary_wavelength is None
74-
if len(self.phases) == 0:
75-
return xray_info_nonemtpy
72+
@property
73+
def primary_phase(self):
74+
return self.phases[0]
7675

77-
primary_phase = self.phases[0]
78-
composition_nonempty = primary_phase.chemical_composition
76+
def has_label(self, label_type: LabelType) -> bool:
77+
if label_type == LabelType.primary_wavelength:
78+
return not self.xray_info.primary_wavelength is None
79+
if label_type == LabelType.secondary_wavelength:
80+
return not self.xray_info.secondary_wavelength is None
7981

80-
a,b,c = primary_phase.lengths
81-
alpha, beta, gamma = primary_phase.angles
82-
lattice_params_nonempty = not all(math.isnan(x) for x in [a, b, c, alpha, beta, gamma])
83-
crystal_basis_nonempty = len(primary_phase.basis) > 0
84-
return xray_info_nonemtpy or composition_nonempty or lattice_params_nonempty or crystal_basis_nonempty
82+
if len(self.phases) == 0:
83+
return False
84+
85+
if label_type == LabelType.lattice:
86+
return True
87+
elif label_type == LabelType.spg:
88+
return self.primary_phase.spacegroup is not None
89+
elif label_type == LabelType.composition:
90+
return self.primary_phase.chemical_composition is not None
91+
elif label_type == LabelType.temperature:
92+
return self.temp_K is not None
93+
elif label_type == LabelType.crystallite_size:
94+
return self.crystallite_size_nm is not None
95+
elif label_type == LabelType.atom_coords:
96+
return len(self.primary_phase.basis.atom_sites) > 0
97+
else:
98+
raise ValueError(f'Label type {label_type} is not supported.')
99+
100+
def is_labeled(self) -> bool:
101+
return any(self.has_label(label_type=lt) for lt in LabelType)
85102

86103
def __eq__(self, other : PowderExperiment):
87104
return self.to_str() == other.to_str()
@@ -97,56 +114,49 @@ def to_tensor(data):
97114

98115
spg_list = [self.phases[0].spacegroup == j for j in range(1,NUM_SPACEGROUPS+1)]
99116
feature_dict = {
100-
'lengths': to_tensor(self.phases[0].lengths),
101-
'angles' : to_tensor(self.phases[0].angles),
102-
'spg_probabilities' : to_tensor(spg_list),
103-
'crystallite_size' : to_tensor(self.crystallite_size_nm) if self.crystallite_size_nm else None,
104-
'temperature' : to_tensor(self.temp_K) if self.temp_K else None,
105-
'primary_wavelength' : to_tensor(self.xray_info.primary_wavelength) if self.xray_info.primary_wavelength else None,
106-
'secondary_wavelength' : to_tensor(self.xray_info.secondary_wavelength) if self.xray_info.secondary_wavelength else None,
117+
LabelType.lengths.value: to_tensor(self.phases[0].lengths),
118+
LabelType.angles.value : to_tensor(self.phases[0].angles),
119+
LabelType.spg.value : to_tensor(spg_list),
120+
LabelType.crystallite_size.value : to_tensor(self.crystallite_size_nm) if self.crystallite_size_nm else None,
121+
LabelType.temperature.value : to_tensor(self.temp_K) if self.temp_K else None,
122+
LabelType.primary_wavelength.value : to_tensor(self.xray_info.primary_wavelength) if self.xray_info.primary_wavelength else None,
123+
LabelType.secondary_wavelength.value : to_tensor(self.xray_info.secondary_wavelength) if self.xray_info.secondary_wavelength else None,
107124
}
108125
td = ExperimentTensor(feature_dict)
109126
return td
110127

111128

112-
@dataclass
113-
class Metadata(JsonDataclass):
114-
filename: Optional[str] = None
115-
institution: Optional[str] = None
116-
contributor_name: Optional[str] = None
117-
original_file_format: Optional[str] = None
118-
measurement_date: Optional[str] = None
119-
tags: list[str] = field(default_factory=list)
120-
xrdpattern_version: str = field(default_factory=lambda: get_library_version('xrdpattern'))
121-
122-
def __eq__(self, other : Metadata):
123-
s1, s2 = self.to_str(), other.to_str()
124-
return s1 == s2
125-
126-
def remove_filename(self):
127-
self.filename = None
128-
129-
130129
class ExperimentTensor(TensorDict):
131130
def get_lattice_params(self):
132131
lengths, angles = self.get('lengths'), self.get('angles')
133132
return torch.cat([lengths, angles], dim=0)
134133

135134
def get_spg_probabilities(self):
136-
return self.get('spg_probabilities')
135+
return self.get(LabelType.spg.value)
137136

138137
def get_crystallite_size(self):
139-
return self.get('crystallite_size')
138+
return self.get(LabelType.crystallite_size.value)
140139

141140
def get_ambient_temperature(self):
142-
return self.get('temperature')
141+
return self.get(LabelType.temperature.value)
143142

144143
def get_primary_wavelength(self):
145-
return self.get('primary_wavelength')
144+
return self.get(LabelType.primary_wavelength.value)
146145

147146
def get_secondary_wavelength(self):
148-
return self.get('secondary_wavelength')
147+
return self.get(LabelType.secondary_wavelength.value)
148+
149+
150+
class LabelType(Enum):
151+
lattice = "lattice"
152+
lengths = "lengths"
153+
angles = "angles"
154+
atom_coords = "atom_coords"
155+
spg = "spg"
156+
crystallite_size = 'crystallite_size'
157+
temperature = 'temperature'
158+
primary_wavelength = 'primary_wavelength'
159+
secondary_wavelength = 'secondary_wavelength'
160+
composition = "composition"
149161

150162

151-
def get_library_version(library_name : str):
152-
return version(library_name)

xrdpattern/xrd/metadata.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass, field
4+
from importlib.metadata import version
5+
from typing import Optional
6+
7+
from xrdpattern.serialization import JsonDataclass
8+
9+
10+
@dataclass
11+
class Metadata(JsonDataclass):
12+
filename: Optional[str] = None
13+
institution: Optional[str] = None
14+
contributor_name: Optional[str] = None
15+
original_file_format: Optional[str] = None
16+
measurement_date: Optional[str] = None
17+
tags: list[str] = field(default_factory=list)
18+
xrdpattern_version: str = field(default_factory=lambda: get_library_version('xrdpattern'))
19+
20+
def __eq__(self, other : Metadata):
21+
s1, s2 = self.to_str(), other.to_str()
22+
return s1 == s2
23+
24+
def remove_filename(self):
25+
self.filename = None
26+
27+
28+
def get_library_version(library_name : str):
29+
return version(library_name)

xrdpattern/xrd/tensorization.py

Lines changed: 0 additions & 26 deletions
This file was deleted.

0 commit comments

Comments
 (0)