Skip to content

Commit 8d537a5

Browse files
committed
Fix band structure normalization
1 parent 52182fd commit 8d537a5

File tree

2 files changed

+56
-88
lines changed

2 files changed

+56
-88
lines changed

src/nomad_simulations/schema_packages/properties/band_structure.py

Lines changed: 54 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from nomad_simulations.schema_packages.numerical_settings import KSpace
1717
from nomad_simulations.schema_packages.physical_property import PhysicalProperty
1818
from nomad_simulations.schema_packages.utils.utils import check_not_none
19-
from nomad_simulations.schema_packages.utils.electronic import quicksearch_first_value, inner_copy
19+
from nomad_simulations.schema_packages.utils.electronic import inner_copy
2020

2121
configuration = config.get_plugin_entry_point(
2222
'nomad_simulations.schema_packages:nomad_simulations_plugin'
@@ -45,19 +45,19 @@ class ElectronicEigenvalues(BaseElectronicEigenvalues):
4545
value = Quantity(
4646
type=np.float64,
4747
unit='joule',
48-
shape=['level', 'spin'],
48+
shape=['spin', 'level'],
4949
description="""
5050
Value of the electronic eigenvalues.
51-
Rows correspond to the energy levels, and columns correspond to the spin channels.
51+
Dimensions: [spin channel, energy level].
5252
""",
5353
)
5454

5555
occupation = Quantity(
56-
type=unit_float(),
57-
shape=['level', 'spin'],
56+
type=unit_float(dtype=np.float64),
57+
shape=['spin', 'level'],
5858
description="""
5959
Occupation of the electronic eigenvalues, ranging from 0 to 1.
60-
Rows correspond to the energy levels, and columns correspond to the spin channels.
60+
Dimensions: [spin channel, energy level].
6161
""",
6262
) # restructure spin for plotting?
6363

@@ -100,20 +100,23 @@ def resolve_homo_lumo(self) -> None:
100100
"""
101101
Resolve HOMO and LUMO eigenvalues using binary search on sorted eigenvalues.
102102
"""
103-
def process_spin_channel(spin_data):
103+
def process_spin_channel(data: np.ndarray) -> list:
104104
"""Process a single spin channel to find HOMO/LUMO."""
105-
spin_values, spin_occupations = spin_data
106-
lumo_idx = quicksearch_first_value(
107-
spin_occupations, 0.0, tolerance=1e-6
108-
)
109-
110-
return [
111-
spin_values[lumo_idx] if lumo_idx is not None and lumo_idx >= 0 else None,
112-
spin_values[lumo_idx + 1] if lumo_idx is not None and lumo_idx > 0 else None
113-
]
105+
mid = int(len(data) / 2) # ? extra check
106+
values, occupations = data[:mid], data[mid:]
107+
lumo_region = np.where(occupations <= 1e-6)
108+
109+
if lumo_region[0].size > 0:
110+
lumo_idx = np.min(lumo_region)
111+
if lumo_idx > 0:
112+
return [values[lumo_idx], values[lumo_idx - 1]]
113+
else:
114+
return [values[lumo_idx], np.nan]
115+
else:
116+
return [np.nan, np.nan]
114117

115118
# Stack value and occupation arrays along last axis for apply_along_axis
116-
combined_data = np.stack([self.value, self.occupation.magnitude], axis=-1)
119+
combined_data = np.stack([self.value.magnitude, self.occupation], axis=-1)
117120
results = np.apply_along_axis(process_spin_channel, axis=0, arr=combined_data.T)
118121

119122
self.highest_occupied = results[:, 0] * self.value.u
@@ -123,7 +126,7 @@ def pad_out(self) -> None:
123126
"""
124127
Pad out the value and occupation arrays along the spin channel dimension.
125128
"""
126-
spin_index = 2
129+
spin_index = 0 # Spin is now first dimension
127130
if np.array(self.value).shape[spin_index] == 1: # TODO: add model_method spin_polarized
128131
self.value = inner_copy(self.value, 0) # TODO: dynamically set repetition
129132
if np.array(self.occupation).shape[spin_index] == 1: # TODO: add model_method spin_polarized
@@ -147,25 +150,25 @@ class ElectronicBandStructure(BaseElectronicEigenvalues):
147150
value = Quantity(
148151
type=np.float64,
149152
unit='joule',
150-
shape=['level', 'kpoint', 'spin'],
153+
shape=['spin', 'kpoint', 'level'],
151154
description="""
152155
Value of the electronic eigenvalues in the reciprocal space.
153-
Dimensions: [energy level, k-point, spin channel].
156+
Dimensions: [spin channel, k-point, energy level].
154157
""",
155158
)
156159

157160
occupation = Quantity(
158-
type=unit_float(),
159-
shape=['level', 'kpoint', 'spin'],
161+
type=unit_float(dtype=np.float64),
162+
shape=['spin', 'kpoint', 'level'],
160163
description="""
161164
Occupation of the electronic eigenvalues, ranging from 0 to 1.
162-
Dimensions: [energy level, k-point, spin channel].
165+
Dimensions: [spin channel, k-point, energy level].
163166
""",
164167
)
165168

166169
highest_occupied = Quantity(
167170
type=np.float64,
168-
shape=['kpoint', 'spin'],
171+
shape=['spin', 'kpoint'],
169172
unit='joule',
170173
description="""
171174
Highest occupied electronic eigenvalue for each k-point and spin channel. Together with `lowest_unoccupied`, it defines the
@@ -175,7 +178,7 @@ class ElectronicBandStructure(BaseElectronicEigenvalues):
175178

176179
lowest_unoccupied = Quantity(
177180
type=np.float64,
178-
shape=['kpoint', 'spin'],
181+
shape=['spin', 'kpoint'],
179182
unit='joule',
180183
description="""
181184
Lowest unoccupied electronic eigenvalue for each k-point and spin channel. Together with `highest_occupied`, it defines the
@@ -188,39 +191,42 @@ def resolve_homo_lumo(self) -> None:
188191
"""
189192
Resolve HOMO and LUMO eigenvalues using binary search on sorted eigenvalues for band structure.
190193
"""
191-
def process_kpoint_spin(kpoint_spin_data):
194+
def process_spin_kpoint(data: np.ndarray) -> list:
192195
"""Process a single k-point and spin channel to find HOMO/LUMO."""
193-
spin_values, spin_occupations = kpoint_spin_data
194-
lumo_idx = quicksearch_first_value(
195-
spin_occupations, 0.0, tolerance=1e-6
196-
)
197-
198-
return [
199-
spin_values[lumo_idx] if lumo_idx is not None and lumo_idx >= 0 else None,
200-
spin_values[lumo_idx + 1] if lumo_idx is not None and lumo_idx > 0 else None
201-
]
202-
203-
# Stack value and occupation arrays - shape: [level, kpoint, spin, 2]
204-
# Apply along level axis (axis=0) for each k-point and spin combination
205-
# Reshape to combine kpoint and spin dimensions for processing
206-
combined_data = np.stack([self.value, self.occupation.magnitude], axis=-1)
207-
n_levels, n_kpoints, n_spins, _ = combined_data.shape
208-
reshaped_data = combined_data.transpose(1, 2, 0, 3).reshape(n_kpoints * n_spins, n_levels, 2)
209-
results = np.apply_along_axis(process_kpoint_spin, axis=1, arr=reshaped_data)
196+
mid = int(len(data) / 2) # ? extra check
197+
values, occupations = data[:mid], data[mid:]
198+
lumo_region = np.where(occupations <= 1e-6)
199+
200+
if lumo_region[0].size > 0:
201+
lumo_idx = np.min(lumo_region)
202+
if lumo_idx > 0:
203+
return [values[lumo_idx], values[lumo_idx - 1]]
204+
else:
205+
return [values[lumo_idx], np.nan]
206+
else:
207+
return [np.nan, np.nan]
208+
209+
210+
n_spins, n_kpoints, n_levels = self.value.shape
210211

211-
# Reshape back to [kpoint, spin, 2] then extract homo/lumo
212-
results = results.reshape(n_kpoints, n_spins, 2)
212+
# Stack along last axis to get [n_spins, n_kpoints, n_levels, 2]
213+
combined_data = np.stack([self.value.magnitude, self.occupation], axis=2)
214+
reshaped_data = combined_data.reshape(n_spins * n_kpoints, n_levels * 2)
215+
216+
results = np.apply_along_axis(process_spin_kpoint, axis=1, arr=reshaped_data)
217+
results = results.reshape(n_spins, n_kpoints, 2)
218+
213219
self.highest_occupied = results[:, :, 0] * self.value.u
214220
self.lowest_unoccupied = results[:, :, 1] * self.value.u
215221

216222
def pad_out(self) -> None:
217223
"""
218224
Pad out the value and occupation arrays along the spin channel dimension.
219225
"""
220-
spin_index = 2
221-
if np.array(self.value).shape[spin_index] == 1: # TODO: add model_method spin_polarized
226+
spin_index = 0
227+
if self.value.shape[spin_index] == 1: # TODO: add model_method spin_polarized
222228
self.value = inner_copy(self.value, 0) # TODO: dynamically set repetition
223-
if np.array(self.occupation).shape[spin_index] == 1: # TODO: add model_method spin_polarized
229+
if self.occupation.shape[spin_index] == 1: # TODO: add model_method spin_polarized
224230
self.occupation = inner_copy(self.occupation, 0) # TODO: dynamically set repetition
225231

226232
def resolve_reciprocal_cell(self) -> pint.Quantity | None: # ? remove

src/nomad_simulations/schema_packages/utils/electronic.py

Lines changed: 2 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -4,46 +4,8 @@
44

55
import numpy as np
66

7-
8-
def quicksearch_first_value(
9-
arr: np.ndarray, target: float, tolerance: float = 1e-10
10-
) -> int | None:
11-
"""
12-
Find the index of the first occurrence of a target value in a sorted array using binary search.
13-
14-
This function assumes the array is sorted and uses binary search to efficiently
15-
locate the first element that equals the target (within tolerance).
16-
17-
Args:
18-
arr: Sorted numpy array to search
19-
target: Target value to search for
20-
tolerance: Tolerance for considering values equal
21-
22-
Returns:
23-
Index of first occurrence of target value, or None if not found
24-
"""
25-
if len(arr) == 0:
26-
return None
27-
28-
left, right = 0, len(arr) - 1
29-
result = None
30-
31-
while left <= right:
32-
mid = (left + right) // 2
33-
34-
if abs(arr[mid] - target) <= tolerance:
35-
result = mid
36-
right = mid - 1
37-
elif arr[mid] < target - tolerance:
38-
left = mid + 1
39-
else:
40-
right = mid - 1
41-
42-
return result
43-
44-
457
def inner_copy(
46-
tensor: np.ndarray, rank_selection: int | tuple[int] | slice, repeat: int = 1
8+
tensor: np.ndarray, rank_selection: int | tuple[int] | slice, repeat: int = 0
479
) -> np.ndarray:
4810
"""
4911
Take a chunk of a high-ranked array and extend it with exact copies of the selection.
@@ -54,7 +16,7 @@ def inner_copy(
5416
Args:
5517
tensor: Input `numpy` array to copy from
5618
rank_selection: `int`, `tuple`, `slice` specifying which elements to select
57-
repeat: Number of times to repeat the selection (default: 1)
19+
repeat: Number of times to repeat the selection. Counting starts from 0 (default: 0)
5820
5921
Example:
6022
>>> arr = np.array([[1, 2], [3, 4], [5, 6]])

0 commit comments

Comments
 (0)