Skip to content

Commit eb68dbf

Browse files
committed
Get wavelengths to plot
1 parent 12e8f1e commit eb68dbf

File tree

1 file changed

+65
-1
lines changed

1 file changed

+65
-1
lines changed

src/axiomatic/pic_helpers.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
import re
2+
import numpy as np # type: ignore
13
import iklayout # type: ignore
24
import matplotlib.pyplot as plt # type: ignore
35
from ipywidgets import interactive, IntSlider # type: ignore
46
from typing import List, Optional
57

6-
from . import Parameter, StatementDictionary, StatementValidationDictionary, StatementValidation
8+
from . import Parameter, StatementDictionary, StatementValidationDictionary, StatementValidation, Computation
79

810

911
def plot_circuit(component):
@@ -301,3 +303,65 @@ def print_statements(statements: StatementDictionary, validation: Optional[State
301303
print("Statement:", unf_stmt.text)
302304
print("Formalization: UNFORMALIZABLE")
303305
print("\n-----------------------------------\n")
306+
307+
308+
def _str_units_to_float(str_units: str) -> float:
309+
unit_conversions = {
310+
"nm": 1e-3,
311+
"um": 1,
312+
"mm": 1e3,
313+
"m": 1e6,
314+
}
315+
match = re.match(r"([\d\.]+)\s*([a-zA-Z]+)", str_units)
316+
numeric_value = float(match.group(1) if match else 1.55)
317+
unit = match.group(2) if match else "um"
318+
return float(numeric_value * unit_conversions[unit])
319+
320+
321+
def get_wavelengths_to_plot(
322+
statements: StatementDictionary, num_samples: int = 100
323+
) -> tuple[list[float], list[float]]:
324+
"""
325+
Get the wavelengths to plot based on the statements.
326+
327+
Returns a list of wavelengths to plot the spectra and a list of vertical lines to plot on top the spectra.
328+
"""
329+
330+
min_wl = np.inf
331+
max_wl = -np.inf
332+
vlines: set = set()
333+
334+
def update_wavelengths(mapping: dict[str, Optional[Computation]]):
335+
for comp in mapping.values():
336+
if comp is None:
337+
continue
338+
if "wavelengths" in comp.arguments:
339+
vlines = vlines | {
340+
_str_units_to_float(wl) for wl in (comp.arguments["wavelengths"] if isinstance(comp.arguments["wavelengths"], list) else []) if isinstance(wl, str)
341+
}
342+
if "wavelength_range" in comp.arguments:
343+
if isinstance(comp.arguments["wavelength_range"], list) and len(comp.arguments["wavelength_range"]) == 2 and all(isinstance(wl, str) for wl in comp.arguments["wavelength_range"]):
344+
min_wl = min(min_wl, _str_units_to_float(comp.arguments["wavelength_range"][0]))
345+
max_wl = max(max_wl, _str_units_to_float(comp.arguments["wavelength_range"][1]))
346+
347+
for cost_stmt in statements.cost_functions or []:
348+
if cost_stmt.formalization is not None and cost_stmt.formalization.mapping is not None:
349+
update_wavelengths(cost_stmt.formalization.mapping)
350+
351+
for param_stmt in statements.parameter_constraints or []:
352+
if param_stmt.formalization is not None and param_stmt.formalization.mapping is not None:
353+
update_wavelengths(param_stmt.formalization.mapping)
354+
355+
if vlines:
356+
min_wl = min(min_wl, min(vlines))
357+
max_wl = max(max_wl, max(vlines))
358+
if min_wl >= max_wl:
359+
avg_wl = sum(vlines) / len(vlines) if vlines else 1550
360+
min_wl, max_wl = avg_wl - 0.1, avg_wl + 0.1
361+
else:
362+
range_size = max_wl - min_wl
363+
min_wl -= 0.2 * range_size
364+
max_wl += 0.2 * range_size
365+
366+
wls = np.linspace(min_wl, max_wl, num_samples)
367+
return wls.tolist(), list(vlines)

0 commit comments

Comments
 (0)