Skip to content

Commit b2163ac

Browse files
committed
Get wavelengths to plot
1 parent 12e8f1e commit b2163ac

File tree

1 file changed

+52
-0
lines changed

1 file changed

+52
-0
lines changed

src/axiomatic/pic_helpers.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import re
2+
import numpy as np
13
import iklayout # type: ignore
24
import matplotlib.pyplot as plt # type: ignore
35
from ipywidgets import interactive, IntSlider # type: ignore
@@ -301,3 +303,53 @@ def print_statements(statements: StatementDictionary, validation: Optional[State
301303
print("Statement:", unf_stmt.text)
302304
print("Formalization: UNFORMALIZABLE")
303305
print("\n-----------------------------------\n")
306+
307+
308+
def _str_units_to_float(str_units: str) -> float:
309+
unit_conversions = {
310+
"nm": 1e-3,
311+
"um": 1,
312+
"mm": 1e3,
313+
"m": 1e6,
314+
}
315+
match = re.match(r"([\d\.]+)\s*([a-zA-Z]+)", str_units)
316+
numeric_value = float(match.group(1))
317+
unit = match.group(2)
318+
return float(numeric_value * unit_conversions[unit])
319+
320+
321+
def get_wavelengths_to_plot(
322+
statements: StatementDictionary, num_samples: int = 100
323+
) -> tuple[list[float], list[float]]:
324+
"""
325+
Get the wavelengths to plot based on the statements.
326+
327+
Returns a list of wavelengths to plot the spectra and a list of vertical lines to plot on top the spectra.
328+
"""
329+
330+
min_wl = np.inf
331+
max_wl = -np.inf
332+
vlines = set()
333+
for stmt in statements.parameter_constraints + statements.cost_functions:
334+
if stmt.formalization is not None and stmt.formalization.mapping is not None:
335+
for comp in stmt.formalization.mapping.values():
336+
if "wavelengths" in comp.arguments:
337+
vlines = vlines | {
338+
_str_units_to_float(wl) for wl in comp.arguments["wavelengths"]
339+
}
340+
if "wavelength_range" in comp.arguments:
341+
min_wl = min(min_wl, _str_units_to_float(comp.arguments["wavelength_range"][0]))
342+
max_wl = max(max_wl, _str_units_to_float(comp.arguments["wavelength_range"][1]))
343+
if vlines:
344+
min_wl = min(min_wl, min(vlines))
345+
max_wl = max(max_wl, max(vlines))
346+
if min_wl >= max_wl:
347+
avg_wl = sum(vlines) / len(vlines) if vlines else 1550
348+
min_wl, max_wl = avg_wl - 0.1, avg_wl + 0.1
349+
else:
350+
range_size = max_wl - min_wl
351+
min_wl -= 0.2 * range_size
352+
max_wl += 0.2 * range_size
353+
354+
wls = np.linspace(min_wl, max_wl, num_samples)
355+
return wls.tolist(), list(vlines)

0 commit comments

Comments
 (0)