Skip to content

Commit 78282ce

Browse files
General: Fixed import errors
1 parent 28e9ade commit 78282ce

File tree

6 files changed

+22
-79
lines changed

6 files changed

+22
-79
lines changed

exports.py

Lines changed: 0 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +0,0 @@
1-
@classmethod
2-
def make_empty(cls, is_simulated: bool = False, num_phases: int = 1) -> PowderExperiment:
3-
phases = []
4-
for j in range(num_phases):
5-
lengths = (float('nan'), float('nan'), float('nan'))
6-
angles = (float('nan'), float('nan'), float('nan'))
7-
base = CrystalBasis.empty()
8-
9-
p = CrystalStructure(lengths=lengths, angles=angles, basis=base)
10-
phases.append(p)
11-
12-
xray_info = XrayInfo.mk_empty()
13-
return cls(phases=phases, crystallite_size=None, temp_in_celcius=None, xray_info=xray_info,
14-
is_simulated=is_simulated)
15-
16-
17-
def get_list_repr(self) -> list:
18-
list_repr = []
19-
structure = self.phases[0]
20-
21-
a, b, c = structure.lengths
22-
alpha, beta, gamma = structure.angles
23-
lattice_params = [a, b, c, alpha, beta, gamma]
24-
list_repr += lattice_params
25-
26-
base = structure.basis
27-
padded_base = self.get_padded_base(base=base, nan_padding=base.is_empty())
28-
for atomic_site in padded_base:
29-
list_repr += atomic_site.as_list()
30-
31-
if structure.spacegroup is None:
32-
spg_logits_list = [float('nan') for _ in range(NUM_SPACEGROUPS)]
33-
else:
34-
spg_logits_list = [1000 if j + 1 == structure.spacegroup else 0 for j in range(NUM_SPACEGROUPS)]
35-
list_repr += spg_logits_list
36-
37-
list_repr += self.xray_info.as_list()
38-
list_repr += [self.is_simulated]
39-
40-
return list_repr
41-
42-
@staticmethod
43-
def get_padded_base(base: CrystalBasis, nan_padding : bool) -> CrystalBasis:
44-
def make_padding_site():
45-
if nan_padding:
46-
site = AtomSite.make_placeholder()
47-
else:
48-
site = AtomSite.make_void()
49-
return site
50-
51-
delta = MAX_ATOMIC_SITES - len(base)
52-
if delta < 0:
53-
raise ValueError(f'Base is too large! Size = {len(base)} exceeds MAX_ATOMIC_SITES = {MAX_ATOMIC_SITES}')
54-
55-
padded_base = base + [make_padding_site() for _ in range(delta)]
56-
return padded_base

tests/t_crystal/t_basis.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from holytools.devtools import Unittest
22

3-
from xrdpattern.crystal import CrystalBasis, AtomSite, Atom, CrystalExamples
3+
from xrdpattern.crystal import CrystalBasis, AtomSite, CrystalExamples
44

55

66
# ---------------------------------------------------------
@@ -9,8 +9,6 @@ class TestCrystalBase(Unittest):
99
def test_scattering_params(self):
1010
mock_base = CrystalBasis([
1111
AtomSite(x=0.5, y=0.5, z=0.5, occupancy=1.0, species_str="Si0"),
12-
AtomSite(x=0.1, y=0.1, z=0.1, occupancy=1.0, species_str=Atom.placeholder_symbol),
13-
AtomSite(x=0.9, y=0.9, z=0.9, occupancy=1.0, species_str=Atom.void_symbol)
1412
])
1513
real_base = CrystalExamples.get_base()
1614

tests/t_crystal/t_calculation.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,6 @@ def to_clustered_pymatgen(crystal : CrystalStructure) -> Structure:
8181
alpha, beta, gamma = crystal.angles
8282
lattice = Lattice.from_parameters(a, b, c, alpha, beta, gamma)
8383

84-
non_void_sites = crystal.basis.get_non_void_sites()
85-
8684
EPSILON = 0.001
8785
clusters: list[list[AtomSite]] = []
8886

@@ -92,7 +90,7 @@ def matching_cluster(the_site):
9290
return cl
9391
return None
9492

95-
for site in non_void_sites:
93+
for site in crystal.basis.atom_sites:
9694
match = matching_cluster(site)
9795
if match:
9896
match.append(site)

tests/t_parsing/t_parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ def test_obj_ok(self):
2121

2222
def test_metadata_ok(self):
2323
metadata = self.pattern.powder_experiment
24-
primary_wavelength = metadata.primary_wavelength
25-
secondary_wavelength = metadata.secondary_wavelength
24+
primary_wavelength = metadata.xray_info.primary_wavelength
25+
secondary_wavelength = metadata.xray_info.secondary_wavelength
2626

2727
for prop in [primary_wavelength, secondary_wavelength]:
2828
self.assertIsNotNone(obj=prop)

xrdpattern/xrd/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
1-
from .experiment import XrayInfo, XrdAnode, Metadata, PowderExperiment
1+
from .experiment import XrayInfo, Metadata, PowderExperiment
22
from .data import XrdData, LabelType
3+
from .xray import XrdAnode, XrayInfo

xrdpattern/xrd/data.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import json
4-
import math
54
from dataclasses import dataclass, fields, field
65
from enum import Enum
76

@@ -11,8 +10,9 @@
1110

1211
from xrdpattern.crystal import CrystalStructure
1312
from xrdpattern.serialization import Serializable
14-
from xrdpattern.xrd.experiment import PowderExperiment
1513
from xrdpattern.xrd import Metadata
14+
from xrdpattern.xrd.experiment import PowderExperiment
15+
1616

1717
# -------------------------------------------
1818

@@ -82,19 +82,21 @@ def get_name(self) -> str:
8282
def get_phase(self, phase_num : int) -> CrystalStructure:
8383
return self.powder_experiment.phases[phase_num]
8484

85+
8586
def has_label(self, label_type: LabelType) -> bool:
86-
if label_type == LabelType.composition:
87-
return self.primary_phase.chemical_composition is not None
88-
if label_type == LabelType.lattice:
89-
return
90-
if label_type == LabelType.atom_coords:
91-
return len(self.primary_phase.basis) > 0
92-
if label_type == LabelType.spg:
93-
spg_explicit = self.primary_phase.spacegroup is not None
94-
spg_implicit = self.has_label(label_type=LabelType.lattice) and self.has_label(
95-
label_type=LabelType.atom_coords)
96-
return spg_explicit or spg_implicit
97-
return False
87+
raise NotImplementedError
88+
# if label_type == LabelType.composition:
89+
# return self.primary_phase.chemical_composition is not None
90+
# if label_type == LabelType.lattice:
91+
# return
92+
# if label_type == LabelType.atom_coords:
93+
# return len(self.primary_phase.basis) > 0
94+
# if label_type == LabelType.spg:
95+
# spg_explicit = self.primary_phase.spacegroup is not None
96+
# spg_implicit = self.has_label(label_type=LabelType.lattice) and self.has_label(
97+
# label_type=LabelType.atom_coords)
98+
# return spg_explicit or spg_implicit
99+
# return False
98100

99101
def is_labeled(self) -> bool:
100102
return any(self.has_label(label_type=lt) for lt in LabelType)

0 commit comments

Comments
 (0)