11from holytools .devtools import Unittest
2- from xrdpattern .xrd import PowderExperiment
2+ from pymatgen .core import Lattice
3+
4+ from xrdpattern .crystal import CrystalStructure , CrystalBasis
5+ from xrdpattern .xrd import PowderExperiment , XrayInfo
6+ from xrdpattern .xrd .experiment import ExperimentTensor
37
48
59class TestPowderExperiment (Unittest ):
6- def test_empty (self ):
10+ def test_is_empty (self ):
11+
12+ # empty is empty
713 empty_experiment = PowderExperiment .make_empty ()
814 self .assertTrue (not empty_experiment .is_nonempty ())
15+
16+ #nonempty is nonempty
17+
918
19+ def test_has_label (self ):
20+ pass
21+
1022
11- class TestTensorRegions (Unittest ):
23+ class TestTensorization (Unittest ):
1224 def setUp (self ):
13- raise ValueError (f'Tensor regions currently broken! Restore tensorizaation + tests' )
14- # self.label : PowderExperiment = self.make_example_label()
15- # self.label_tensor : LabelTensor = self.label.to_tensor()
16- # self.crystal_structure : CrystalPhase = self.label.powder.phases
25+ self .label : PowderExperiment = self .make_example_label ()
26+ self .crystal_structure : CrystalStructure = self .label .phases [0 ]
27+ self .experiment_tensor : ExperimentTensor = self .label .to_tensordict ()
28+
29+ def test_lattice_params (self ):
30+ expected = (* self .crystal_structure .lengths , * self .crystal_structure .angles )
31+ actual = self .experiment_tensor .get_lattice_params ().tolist ()
32+ print (f'Tensor lattice params are { actual } ' )
33+ for x ,y in zip (expected , actual ):
34+ self .assertEqual (x , y )
35+
36+ def test_spacegroups (self ):
37+ expected = [1.0 if j == self .crystal_structure .spacegroup else 0.0 for j in range (1 ,231 )]
38+ actual = self .experiment_tensor .get_spg_probabilities ().tolist ()
39+ print (f'Tensor spacegroups probabilities are { actual } ; \n Spacegroups tensor length = { len (actual )} ' )
40+ self .assertEqual (actual , expected , f'Expected: { expected } , Actual: { actual } ' )
41+ self .assertEqual (len (actual ), 230 )
42+
43+ def test_wavelength (self ):
44+ expected = self .label .xray_info .primary_wavelength
45+ actual = self .experiment_tensor .get_primary_wavelength ()
46+ print (f'Tensor primary wavelength is { actual } ' )
47+ self .assertAlmostEqual (actual , expected )
1748
49+ expected_secondary = self .label .xray_info .secondary_wavelength
50+ actual_secondary = self .experiment_tensor .get_secondary_wavelength ()
51+ print (f'Tensor secondary wavelength is { actual_secondary } ' )
52+ self .assertAlmostEqual (actual_secondary , expected_secondary )
1853
19- #
20- # def test_lattice_params(self):
21- # expected = (*self.crystal_structure.lengths, *self.crystal_structure.angles)
22- # actual = self.label_tensor.get_lattice_params().tolist()
23- # print(f'Tensor lattice params are {actual}')
24- # for x,y in zip(expected, actual):
25- # self.assertEqual(x, y)
26- #
27- #
28- # def test_atomic_sites(self):
29- # base = self.crystal_structure.base
30- # for i, region in enumerate(self.label.atomic_site_regions):
31- # expected_site = base[i] if i < len(base) else AtomicSite.make_void()
32- # expected = expected_site.as_list()
33- # actual = self.label_tensor.get_atomic_site(i).tolist()
34- # if i == 0:
35- # print(f'Tensor atomic site is {actual}\n Atomic site length = {len(actual)}')
36- # # 3 coordinates, 8 scattering params, 1 occupancy
37- # self.assertEqual(len(actual), 3+8+1)
38- # for x,y, in zip(expected, actual):
39- # if x is None:
40- # self.assertTrue(is_nan(y))
41- # else:
42- # self.assertEqual(x,y)
43- #
44- # def test_spacegroups(self):
45- # expected = [1.0 if j == self.crystal_structure.spacegroup else 0.0 for j in range(1,231)]
46- # actual = self.label_tensor.get_spg_probabilities().tolist()
47- # print(f'Tensor spacegroups are {actual}; \nSpacegroups tensor length = {len(actual)}')
48- # self.assertEqual(actual, expected, f'Expected: {expected}, Actual: {actual}')
49- # self.assertEqual(len(actual), 230)
50- #
51- #
52- # def test_artifacts(self):
53- # expected = [round(num, 2) for num in self.label.artifacts.as_list()]
54- # actual = [round(num, 2) for num in self.label_tensor.get_artifacts().tolist()]
55- # print(f'Tensor artifacts are {actual}')
56- # self.assertEqual(actual, expected)
57- #
58- # def test_total_length(self):
59- # expected = len(self.label.list_repr)
60- # # noinspection PyTypeChecker
61- # actual = len(self.label_tensor)
62- # print(f'Tensor length is {actual}')
63- # self.assertEqual(actual, expected)
64- #
65- #
66- # def test_returns_powder_tensors(self):
67- # tensors = [self.label_tensor.get_atomic_site(0), self.label_tensor.get_lattice_params(), self.label_tensor.get_lattice_params()]
68- # for t in tensors:
69- # print(f'Type of {t} is {type(t)}')
70- # self.assertEqual(t.__class__, LabelTensor)
71- #
72- # @staticmethod
73- # def make_example_label() -> PowderExperiment:
74- # primitives = Lengths(a=3, b=3, c=3)
75- # angles = Angles(alpha=90, beta=90, gamma=90)
76- # base = CrystalBase([AtomicSite.make_void()])
77- # crystal_structure = CrystalPhase(lengths=primitives, angles=angles, base=base)
78- # crystal_structure.spacegroup = 120
79- #
80- # powder = PowderSample(phases=[crystal_structure], crystallite_size=10.0)
81- # artifacts = XRayInfo(primary_wavelength=1.54, secondary_wavelength=1.54)
82- # return PowderExperiment(powder, artifacts, is_simulated=True)
83- #
54+ @staticmethod
55+ def make_example_label () -> PowderExperiment :
56+ lattice = Lattice .from_parameters (3 ,3 ,3 ,90 ,90 ,90 )
57+ basis = CrystalBasis .empty ()
58+ crystal_structure = CrystalStructure (lattice = lattice , basis = basis )
59+ crystal_structure .spacegroup = 120
8460
85- def is_nan ( value ):
86- return value != value
61+ xray_info = XrayInfo . copper_xray ()
62+ return PowderExperiment ( phases = [ crystal_structure ], xray_info = xray_info , crystallite_size_nm = 10 , temp_K = 300 )
8763
8864if __name__ == "__main__" :
89- TestPowderExperiment .execute_all ()
65+ # TestPowderExperiment.execute_all()
66+ TestTensorization .execute_all ()
0 commit comments