11from __future__ import annotations
22
33import math
4- from dataclasses import dataclass , field
5- from importlib . metadata import version
4+ from dataclasses import dataclass
5+ from enum import Enum
66from typing import Optional
77
88import torch
@@ -69,19 +69,36 @@ def from_cif(cls, cif_content : str) -> PowderExperiment:
6969 # ---------------------------------------------------------
7070 # properties
7171
72- def is_nonempty (self ) -> bool :
73- xray_info_nonemtpy = not self .xray_info .primary_wavelength is None or not self .xray_info .secondary_wavelength is None
74- if len (self .phases ) == 0 :
75- return xray_info_nonemtpy
72+ @property
73+ def primary_phase (self ):
74+ return self .phases [0 ]
7675
77- primary_phase = self .phases [0 ]
78- composition_nonempty = primary_phase .chemical_composition
76+ def has_label (self , label_type : LabelType ) -> bool :
77+ if label_type == LabelType .primary_wavelength :
78+ return not self .xray_info .primary_wavelength is None
79+ if label_type == LabelType .secondary_wavelength :
80+ return not self .xray_info .secondary_wavelength is None
7981
80- a ,b ,c = primary_phase .lengths
81- alpha , beta , gamma = primary_phase .angles
82- lattice_params_nonempty = not all (math .isnan (x ) for x in [a , b , c , alpha , beta , gamma ])
83- crystal_basis_nonempty = len (primary_phase .basis ) > 0
84- return xray_info_nonemtpy or composition_nonempty or lattice_params_nonempty or crystal_basis_nonempty
82+ if len (self .phases ) == 0 :
83+ return False
84+
85+ if label_type == LabelType .lattice :
86+ return True
87+ elif label_type == LabelType .spg :
88+ return self .primary_phase .spacegroup is not None
89+ elif label_type == LabelType .composition :
90+ return self .primary_phase .chemical_composition is not None
91+ elif label_type == LabelType .temperature :
92+ return self .temp_K is not None
93+ elif label_type == LabelType .crystallite_size :
94+ return self .crystallite_size_nm is not None
95+ elif label_type == LabelType .atom_coords :
96+ return len (self .primary_phase .basis .atom_sites ) > 0
97+ else :
98+ raise ValueError (f'Label type { label_type } is not supported.' )
99+
100+ def is_labeled (self ) -> bool :
101+ return any (self .has_label (label_type = lt ) for lt in LabelType )
85102
86103 def __eq__ (self , other : PowderExperiment ):
87104 return self .to_str () == other .to_str ()
@@ -97,56 +114,49 @@ def to_tensor(data):
97114
98115 spg_list = [self .phases [0 ].spacegroup == j for j in range (1 ,NUM_SPACEGROUPS + 1 )]
99116 feature_dict = {
100- ' lengths' : to_tensor (self .phases [0 ].lengths ),
101- ' angles' : to_tensor (self .phases [0 ].angles ),
102- 'spg_probabilities' : to_tensor (spg_list ),
103- ' crystallite_size' : to_tensor (self .crystallite_size_nm ) if self .crystallite_size_nm else None ,
104- ' temperature' : to_tensor (self .temp_K ) if self .temp_K else None ,
105- ' primary_wavelength' : to_tensor (self .xray_info .primary_wavelength ) if self .xray_info .primary_wavelength else None ,
106- ' secondary_wavelength' : to_tensor (self .xray_info .secondary_wavelength ) if self .xray_info .secondary_wavelength else None ,
117+ LabelType . lengths . value : to_tensor (self .phases [0 ].lengths ),
118+ LabelType . angles . value : to_tensor (self .phases [0 ].angles ),
119+ LabelType . spg . value : to_tensor (spg_list ),
120+ LabelType . crystallite_size . value : to_tensor (self .crystallite_size_nm ) if self .crystallite_size_nm else None ,
121+ LabelType . temperature . value : to_tensor (self .temp_K ) if self .temp_K else None ,
122+ LabelType . primary_wavelength . value : to_tensor (self .xray_info .primary_wavelength ) if self .xray_info .primary_wavelength else None ,
123+ LabelType . secondary_wavelength . value : to_tensor (self .xray_info .secondary_wavelength ) if self .xray_info .secondary_wavelength else None ,
107124 }
108125 td = ExperimentTensor (feature_dict )
109126 return td
110127
111128
112- @dataclass
113- class Metadata (JsonDataclass ):
114- filename : Optional [str ] = None
115- institution : Optional [str ] = None
116- contributor_name : Optional [str ] = None
117- original_file_format : Optional [str ] = None
118- measurement_date : Optional [str ] = None
119- tags : list [str ] = field (default_factory = list )
120- xrdpattern_version : str = field (default_factory = lambda : get_library_version ('xrdpattern' ))
121-
122- def __eq__ (self , other : Metadata ):
123- s1 , s2 = self .to_str (), other .to_str ()
124- return s1 == s2
125-
126- def remove_filename (self ):
127- self .filename = None
128-
129-
130129class ExperimentTensor (TensorDict ):
131130 def get_lattice_params (self ):
132131 lengths , angles = self .get ('lengths' ), self .get ('angles' )
133132 return torch .cat ([lengths , angles ], dim = 0 )
134133
135134 def get_spg_probabilities (self ):
136- return self .get ('spg_probabilities' )
135+ return self .get (LabelType . spg . value )
137136
138137 def get_crystallite_size (self ):
139- return self .get (' crystallite_size' )
138+ return self .get (LabelType . crystallite_size . value )
140139
141140 def get_ambient_temperature (self ):
142- return self .get (' temperature' )
141+ return self .get (LabelType . temperature . value )
143142
144143 def get_primary_wavelength (self ):
145- return self .get (' primary_wavelength' )
144+ return self .get (LabelType . primary_wavelength . value )
146145
147146 def get_secondary_wavelength (self ):
148- return self .get ('secondary_wavelength' )
147+ return self .get (LabelType .secondary_wavelength .value )
148+
149+
150+ class LabelType (Enum ):
151+ lattice = "lattice"
152+ lengths = "lengths"
153+ angles = "angles"
154+ atom_coords = "atom_coords"
155+ spg = "spg"
156+ crystallite_size = 'crystallite_size'
157+ temperature = 'temperature'
158+ primary_wavelength = 'primary_wavelength'
159+ secondary_wavelength = 'secondary_wavelength'
160+ composition = "composition"
149161
150162
151- def get_library_version (library_name : str ):
152- return version (library_name )
0 commit comments