22
33import math
44from dataclasses import dataclass , field
5- from enum import Enum
65from importlib .metadata import version
76from typing import Optional
87
9- import torch
10-
11- from xrdpattern .crystal import CrystalStructure , CrystalBasis , AtomSite
8+ from xrdpattern .crystal import CrystalStructure
129from xrdpattern .serialization import JsonDataclass
13- from xrdpattern .xrd .tensorization import LabelTensor
10+ from xrdpattern .xrd .xray import XrayInfo
1411
1512NUM_SPACEGROUPS = 230
1613MAX_ATOMIC_SITES = 100
@@ -26,25 +23,13 @@ class PowderExperiment(JsonDataclass):
2623 temp_in_celcius : Optional [float ] = None
2724
2825 def __post_init__ (self ):
29- if len (self .phases ) == 0 :
30- raise ValueError (f'Material must have at least one phase! Got { len (self .phases )} ' )
31-
3226 if len (self .phases ) == 1 :
3327 self .phases [0 ].phase_fraction = 1
3428
3529 @classmethod
36- def make_empty (cls , is_simulated : bool = False , num_phases : int = 1 ) -> PowderExperiment :
37- phases = []
38- for j in range (num_phases ):
39- lengths = (float ('nan' ),float ('nan' ), float ('nan' ))
40- angles = (float ('nan' ),float ('nan' ), float ('nan' ))
41- base = CrystalBasis .empty ()
42-
43- p = CrystalStructure (lengths = lengths , angles = angles , basis = base )
44- phases .append (p )
45-
30+ def make_empty (cls , is_simulated : bool = False ) -> PowderExperiment :
4631 xray_info = XrayInfo .mk_empty ()
47- return cls (phases = phases , crystallite_size = None , temp_in_celcius = None , xray_info = xray_info , is_simulated = is_simulated )
32+ return cls (phases = [] , xray_info = xray_info , is_simulated = is_simulated )
4833
4934 @classmethod
5035 def from_multi_phase (cls , phases : list [CrystalStructure ]):
@@ -93,101 +78,11 @@ def is_nonempty(self) -> bool:
9378 crystal_basis_nonempty = len (primary_phase .basis ) > 0
9479 return xray_info_nonemtpy or composition_nonempty or lattice_params_nonempty or crystal_basis_nonempty
9580
96- @property
97- def primary_wavelength (self ) -> float :
98- return self .xray_info .primary_wavelength
99-
100- @property
101- def secondary_wavelength (self ) -> float :
102- return self .xray_info .secondary_wavelength
103-
104- def get_list_repr (self ) -> list :
105- list_repr = []
106- structure = self .phases [0 ]
107-
108- a , b , c = structure .lengths
109- alpha , beta , gamma = structure .angles
110- lattice_params = [a , b , c , alpha , beta , gamma ]
111- list_repr += lattice_params
112-
113- base = structure .basis
114- padded_base = self .get_padded_base (base = base , nan_padding = base .is_empty ())
115- for atomic_site in padded_base :
116- list_repr += atomic_site .as_list ()
117-
118- if structure .spacegroup is None :
119- spg_logits_list = [float ('nan' ) for _ in range (NUM_SPACEGROUPS )]
120- else :
121- spg_logits_list = [1000 if j + 1 == structure .spacegroup else 0 for j in range (NUM_SPACEGROUPS )]
122- list_repr += spg_logits_list
123-
124- list_repr += self .xray_info .as_list ()
125- list_repr += [self .is_simulated ]
126-
127- return list_repr
128-
129- @staticmethod
130- def get_padded_base (base : CrystalBasis , nan_padding : bool ) -> CrystalBasis :
131- def make_padding_site ():
132- if nan_padding :
133- site = AtomSite .make_placeholder ()
134- else :
135- site = AtomSite .make_void ()
136- return site
137-
138- delta = MAX_ATOMIC_SITES - len (base )
139- if delta < 0 :
140- raise ValueError (f'Base is too large! Size = { len (base )} exceeds MAX_ATOMIC_SITES = { MAX_ATOMIC_SITES } ' )
141-
142- padded_base = base + [make_padding_site () for _ in range (delta )]
143- return padded_base
144-
145-
146- def to_tensor (self , dtype : torch .dtype = torch .get_default_dtype (), device : torch .device = torch .get_default_device ()) -> LabelTensor :
147- tensor = torch .tensor (self .get_list_repr (), dtype = dtype , device = device )
148- return LabelTensor (tensor )
149-
150-
151- @dataclass
152- class XrayInfo (JsonDataclass ):
153- primary_wavelength : Optional [float ]
154- secondary_wavelength : Optional [float ]
155-
156- @classmethod
157- def mk_empty (cls ):
158- return cls (primary_wavelength = None , secondary_wavelength = None )
159-
160- def as_list (self ) -> list [float ]:
161- return [self .primary_wavelength , self .secondary_wavelength ]
162-
163- @staticmethod
164- def default_ratio () -> float :
165- return 0.5
166-
167-
168- class XrdAnode (Enum ):
169- Cu = "Cu"
170- Mo = "Mo"
171- Cr = "Cr"
172- Fe = "Fe"
173- Co = "Co"
174- Ag = "Ag"
175-
176- def get_wavelengths (self ) -> (float , float ):
177- MATERiAL_TO_WAVELENGTHS = {
178- "Cu" : (1.54439 , 1.54056 ),
179- "Mo" : (0.71359 , 0.70930 ),
180- "Cr" : (2.29361 , 2.28970 ),
181- "Fe" : (1.93998 , 1.93604 ),
182- "Co" : (1.79285 , 1.78896 ),
183- "Ag" : (0.563813 , 0.559421 ),
184- }
185- return MATERiAL_TO_WAVELENGTHS [self .value ]
186-
187- def get_xray_info (self ) -> XrayInfo :
188- primary , secondary = self .get_wavelengths ()
189- return XrayInfo (primary_wavelength = primary , secondary_wavelength = secondary )
81+ def __eq__ (self , other : PowderExperiment ):
82+ return self .to_str () == other .to_str ()
19083
84+ # def to_tensor(self, dtype : torch.dtype = torch.get_default_dtype(), device : torch.device = torch.get_default_device()) -> LabelTensor:
85+ # return LabelTensor(tensor)
19186
19287@dataclass
19388class Metadata (JsonDataclass ):
@@ -206,5 +101,6 @@ def __eq__(self, other : Metadata):
206101 def remove_filename (self ):
207102 self .filename = None
208103
104+
209105def get_library_version (library_name : str ):
210106 return version (library_name )
0 commit comments