Skip to content

Commit 52182fd

Browse files
committed
- Clean eigenvalues and band structure base sections: the latter is the same as the former but with a k-points level
- Add various generic utilities: `quicksearch_first_value`, `inner_copy`, `check_not_none`
1 parent f66bb44 commit 52182fd

File tree

3 files changed

+255
-32
lines changed

3 files changed

+255
-32
lines changed

src/nomad_simulations/schema_packages/properties/band_structure.py

Lines changed: 144 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,14 @@
99

1010
if TYPE_CHECKING:
1111
from nomad.datamodel.datamodel import EntryArchive
12-
from nomad.metainfo import Context, Section
1312
from structlog.stdlib import BoundLogger
1413

1514
from nomad_simulations.schema_packages.atoms_state import AtomsState, OrbitalsState
1615
from nomad_simulations.schema_packages.data_types import unit_float
1716
from nomad_simulations.schema_packages.numerical_settings import KSpace
1817
from nomad_simulations.schema_packages.physical_property import PhysicalProperty
18+
from nomad_simulations.schema_packages.utils.utils import check_not_none
19+
from nomad_simulations.schema_packages.utils.electronic import quicksearch_first_value, inner_copy
1920

2021
configuration = config.get_plugin_entry_point(
2122
'nomad_simulations.schema_packages:nomad_simulations_plugin'
@@ -27,60 +28,56 @@ class BaseElectronicEigenvalues(PhysicalProperty):
2728
A base section used to define basic quantities for the `ElectronicEigenvalues` and `ElectronicBandStructure` properties.
2829
"""
2930

30-
iri = ''
31-
3231
n_bands = Quantity(
3332
type=np.int32,
3433
description="""
3534
Number of bands / eigenvalues.
3635
""",
3736
) # TODO: remove
3837

39-
value = Quantity(
40-
type=np.float64,
41-
unit='joule',
42-
shape=['*', '*'],
43-
description="""
44-
Value of the electronic eigenvalues.
45-
""",
46-
)
47-
48-
4938
class ElectronicEigenvalues(BaseElectronicEigenvalues):
5039
""" """
5140

5241
iri = 'http://fairmat-nfdi.eu/taxonomy/ElectronicEigenvalues'
5342

54-
spin_channel = Quantity(
55-
type=np.int32,
43+
# TODO: add spin annotation from @EBB2675
44+
45+
value = Quantity(
46+
type=np.float64,
47+
unit='joule',
48+
shape=['level', 'spin'],
5649
description="""
57-
Spin channel of the corresponding electronic eigenvalues. It can take values of 0 or 1.
50+
Value of the electronic eigenvalues.
51+
Rows correspond to the energy levels, and columns correspond to the spin channels.
5852
""",
5953
)
6054

6155
occupation = Quantity(
6256
type=unit_float(),
63-
shape=['*'],
57+
shape=['level', 'spin'],
6458
description="""
65-
Occupation of the electronic eigenvalues.
59+
Occupation of the electronic eigenvalues, ranging from 0 to 1.
60+
Rows correspond to the energy levels, and columns correspond to the spin channels.
6661
""",
6762
) # restructure spin for plotting?
6863

6964
highest_occupied = Quantity(
7065
type=np.float64,
66+
shape=['spin'],
7167
unit='joule',
7268
description="""
73-
Highest occupied electronic eigenvalue. Together with `lowest_unoccupied`, it defines the
74-
electronic band gap.
69+
Highest occupied electronic eigenvalue for each spin channel. Together with `lowest_unoccupied`, it defines the
70+
electronic band gap. Automatically resolved using binary search on sorted eigenvalues.
7571
""",
7672
)
7773

7874
lowest_unoccupied = Quantity(
7975
type=np.float64,
76+
shape=['spin'],
8077
unit='joule',
8178
description="""
82-
Lowest unoccupied electronic eigenvalue. Together with `highest_occupied`, it defines the
83-
electronic band gap.
79+
Lowest unoccupied electronic eigenvalue for each spin channel. Together with `highest_occupied`, it defines the
80+
electronic band gap. Automatically resolved using binary search on sorted eigenvalues.
8481
""",
8582
)
8683

@@ -98,17 +95,135 @@ class ElectronicEigenvalues(BaseElectronicEigenvalues):
9895
""",
9996
)
10097

98+
@check_not_none('self.value', 'self.occupation')
99+
def resolve_homo_lumo(self) -> None:
100+
"""
101+
Resolve HOMO and LUMO eigenvalues using binary search on sorted eigenvalues.
102+
"""
103+
def process_spin_channel(spin_data):
104+
"""Process a single spin channel to find HOMO/LUMO."""
105+
spin_values, spin_occupations = spin_data
106+
lumo_idx = quicksearch_first_value(
107+
spin_occupations, 0.0, tolerance=1e-6
108+
)
109+
110+
return [
111+
spin_values[lumo_idx] if lumo_idx is not None and lumo_idx >= 0 else None,
112+
spin_values[lumo_idx + 1] if lumo_idx is not None and lumo_idx > 0 else None
113+
]
114+
115+
# Stack value and occupation arrays along last axis for apply_along_axis
116+
combined_data = np.stack([self.value, self.occupation.magnitude], axis=-1)
117+
results = np.apply_along_axis(process_spin_channel, axis=0, arr=combined_data.T)
118+
119+
self.highest_occupied = results[:, 0] * self.value.u
120+
self.lowest_unoccupied = results[:, 1] * self.value.u
121+
122+
def pad_out(self) -> None:
123+
"""
124+
Pad out the value and occupation arrays along the spin channel dimension.
125+
"""
126+
spin_index = 2
127+
if np.array(self.value).shape[spin_index] == 1: # TODO: add model_method spin_polarized
128+
self.value = inner_copy(self.value, 0) # TODO: dynamically set repetition
129+
if np.array(self.occupation).shape[spin_index] == 1: # TODO: add model_method spin_polarized
130+
self.occupation = inner_copy(self.occupation, 0) # TODO: dynamically set repetition
131+
132+
def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None:
133+
super().normalize(archive, logger)
134+
self.pad_out()
135+
self.resolve_homo_lumo()
136+
101137

102-
class ElectronicBandStructure(ElectronicEigenvalues):
138+
class ElectronicBandStructure(BaseElectronicEigenvalues):
103139
"""
104140
Accessible energies by the charges (electrons and holes) in the reciprocal space.
105141
"""
106142

107143
iri = 'http://fairmat-nfdi.eu/taxonomy/ElectronicBandStructure'
108144

109-
k_path = SubSection(sub_section=KMesh.m_def)
145+
kpoint = SubSection(sub_section=KMesh.m_def)
110146

111-
def resolve_reciprocal_cell(self) -> Optional[pint.Quantity]:
147+
value = Quantity(
148+
type=np.float64,
149+
unit='joule',
150+
shape=['level', 'kpoint', 'spin'],
151+
description="""
152+
Value of the electronic eigenvalues in the reciprocal space.
153+
Dimensions: [energy level, k-point, spin channel].
154+
""",
155+
)
156+
157+
occupation = Quantity(
158+
type=unit_float(),
159+
shape=['level', 'kpoint', 'spin'],
160+
description="""
161+
Occupation of the electronic eigenvalues, ranging from 0 to 1.
162+
Dimensions: [energy level, k-point, spin channel].
163+
""",
164+
)
165+
166+
highest_occupied = Quantity(
167+
type=np.float64,
168+
shape=['kpoint', 'spin'],
169+
unit='joule',
170+
description="""
171+
Highest occupied electronic eigenvalue for each k-point and spin channel. Together with `lowest_unoccupied`, it defines the
172+
electronic band gap. Automatically resolved using binary search on sorted eigenvalues.
173+
""",
174+
)
175+
176+
lowest_unoccupied = Quantity(
177+
type=np.float64,
178+
shape=['kpoint', 'spin'],
179+
unit='joule',
180+
description="""
181+
Lowest unoccupied electronic eigenvalue for each k-point and spin channel. Together with `highest_occupied`, it defines the
182+
electronic band gap. Automatically resolved using binary search on sorted eigenvalues.
183+
""",
184+
)
185+
186+
@check_not_none('self.value', 'self.occupation')
187+
def resolve_homo_lumo(self) -> None:
188+
"""
189+
Resolve HOMO and LUMO eigenvalues using binary search on sorted eigenvalues for band structure.
190+
"""
191+
def process_kpoint_spin(kpoint_spin_data):
192+
"""Process a single k-point and spin channel to find HOMO/LUMO."""
193+
spin_values, spin_occupations = kpoint_spin_data
194+
lumo_idx = quicksearch_first_value(
195+
spin_occupations, 0.0, tolerance=1e-6
196+
)
197+
198+
return [
199+
spin_values[lumo_idx] if lumo_idx is not None and lumo_idx >= 0 else None,
200+
spin_values[lumo_idx + 1] if lumo_idx is not None and lumo_idx > 0 else None
201+
]
202+
203+
# Stack value and occupation arrays - shape: [level, kpoint, spin, 2]
204+
# Apply along level axis (axis=0) for each k-point and spin combination
205+
# Reshape to combine kpoint and spin dimensions for processing
206+
combined_data = np.stack([self.value, self.occupation.magnitude], axis=-1)
207+
n_levels, n_kpoints, n_spins, _ = combined_data.shape
208+
reshaped_data = combined_data.transpose(1, 2, 0, 3).reshape(n_kpoints * n_spins, n_levels, 2)
209+
results = np.apply_along_axis(process_kpoint_spin, axis=1, arr=reshaped_data)
210+
211+
# Reshape back to [kpoint, spin, 2] then extract homo/lumo
212+
results = results.reshape(n_kpoints, n_spins, 2)
213+
self.highest_occupied = results[:, :, 0] * self.value.u
214+
self.lowest_unoccupied = results[:, :, 1] * self.value.u
215+
216+
def pad_out(self) -> None:
217+
"""
218+
Pad out the value and occupation arrays along the spin channel dimension.
219+
"""
220+
spin_index = 2
221+
if np.array(self.value).shape[spin_index] == 1: # TODO: add model_method spin_polarized
222+
self.value = inner_copy(self.value, 0) # TODO: dynamically set repetition
223+
if np.array(self.occupation).shape[spin_index] == 1: # TODO: add model_method spin_polarized
224+
self.occupation = inner_copy(self.occupation, 0) # TODO: dynamically set repetition
225+
226+
def resolve_reciprocal_cell(self) -> pint.Quantity | None: # ? remove
112227
"""
113228
Resolve the reciprocal cell from the `KSpace` numerical settings section.
114229
@@ -120,18 +235,15 @@ def resolve_reciprocal_cell(self) -> Optional[pint.Quantity]:
120235
)
121236
if numerical_settings is None:
122237
return None
123-
k_space = None
238+
124239
for setting in numerical_settings:
125240
if isinstance(setting, KSpace):
126-
k_space = setting
127-
break
128-
if k_space is None:
129-
return None
130-
return k_space
241+
return setting
131242

132243
def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None:
133244
super().normalize(archive, logger)
134-
self.reciprocal_cell = self.resolve_reciprocal_cell()
245+
self.pad_out()
246+
self.resolve_homo_lumo()
135247

136248

137249
# defunct
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
"""
2+
Electronic structure utility functions.
3+
"""
4+
5+
import numpy as np
6+
7+
8+
def quicksearch_first_value(
9+
arr: np.ndarray, target: float, tolerance: float = 1e-10
10+
) -> int | None:
11+
"""
12+
Find the index of the first occurrence of a target value in a sorted array using binary search.
13+
14+
This function assumes the array is sorted and uses binary search to efficiently
15+
locate the first element that equals the target (within tolerance).
16+
17+
Args:
18+
arr: Sorted numpy array to search
19+
target: Target value to search for
20+
tolerance: Tolerance for considering values equal
21+
22+
Returns:
23+
Index of first occurrence of target value, or None if not found
24+
"""
25+
if len(arr) == 0:
26+
return None
27+
28+
left, right = 0, len(arr) - 1
29+
result = None
30+
31+
while left <= right:
32+
mid = (left + right) // 2
33+
34+
if abs(arr[mid] - target) <= tolerance:
35+
result = mid
36+
right = mid - 1
37+
elif arr[mid] < target - tolerance:
38+
left = mid + 1
39+
else:
40+
right = mid - 1
41+
42+
return result
43+
44+
45+
def inner_copy(
46+
tensor: np.ndarray, rank_selection: int | tuple[int] | slice, repeat: int = 1
47+
) -> np.ndarray:
48+
"""
49+
Take a chunk of a high-ranked array and extend it with exact copies of the selection.
50+
51+
This function selects a portion of a tensor along its first axis and repeats it
52+
the specified number of times, effectively extending the tensor.
53+
54+
Args:
55+
tensor: Input `numpy` array to copy from
56+
rank_selection: `int`, `tuple`, `slice` specifying which elements to select
57+
repeat: Number of times to repeat the selection (default: 1)
58+
59+
Example:
60+
>>> arr = np.array([[1, 2], [3, 4], [5, 6]])
61+
>>> inner_copy(arr, slice(0, None), repeat=2)
62+
array([[1, 2], [1, 2], [1, 2]])
63+
"""
64+
if tensor.size == 0:
65+
return tensor
66+
67+
selected_chunk = tensor[rank_selection]
68+
69+
# If selection results in 1D array, ensure it maintains proper shape
70+
if selected_chunk.ndim == tensor.ndim - 1:
71+
selected_chunk = np.expand_dims(selected_chunk, axis=0)
72+
73+
repeated_chunks = np.tile(selected_chunk, (repeat + 1, *([1] * (tensor.ndim - 1))))
74+
return np.concatenate([tensor, repeated_chunks], axis=0)

src/nomad_simulations/schema_packages/utils/utils.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,3 +147,40 @@ def wrapper(self, other) -> bool:
147147
return False
148148

149149
return wrapper
150+
151+
152+
def check_not_none(*attributes: str) -> 'Callable':
153+
"""
154+
Decorator that checks if specified object or class attributes are not `None`.
155+
Returns `None` if any of the specified attributes are `None`, otherwise executes the function.
156+
157+
Args:
158+
*attributes: Names of attributes to check for None values
159+
Use 'input.<attribute>' for input attributes
160+
Use 'self.<attribute>' for object attributes
161+
Use 'class.<attribute>' for class attributes
162+
Use '<attribute>' for global attributes
163+
"""
164+
def decorator(func: 'Callable') -> 'Callable':
165+
def wrapper(self, *args, **kwargs):
166+
for attr in attributes:
167+
# Remove prefix and set source
168+
if attr.startswith('input.'):
169+
attr_name = attr[6:]
170+
source = args[0]
171+
elif attr.startswith('self.'):
172+
attr_name = attr[5:]
173+
source = self
174+
elif attr.startswith('class.'):
175+
attr_name = attr[6:]
176+
source = self.__class__
177+
else:
178+
attr_name = attr
179+
source = globals()
180+
if not hasattr(source, attr_name) or getattr(source, attr_name) is None:
181+
return None
182+
183+
return func(self, *args, **kwargs)
184+
185+
return wrapper
186+
return decorator

0 commit comments

Comments
 (0)