|
| 1 | +import re |
| 2 | +import numpy as np |
1 | 3 | import iklayout # type: ignore |
2 | 4 | import matplotlib.pyplot as plt # type: ignore |
3 | 5 | from ipywidgets import interactive, IntSlider # type: ignore |
@@ -301,3 +303,53 @@ def print_statements(statements: StatementDictionary, validation: Optional[State |
301 | 303 | print("Statement:", unf_stmt.text) |
302 | 304 | print("Formalization: UNFORMALIZABLE") |
303 | 305 | print("\n-----------------------------------\n") |
| 306 | + |
| 307 | + |
| 308 | +def _str_units_to_float(str_units: str) -> float: |
| 309 | + unit_conversions = { |
| 310 | + "nm": 1e-3, |
| 311 | + "um": 1, |
| 312 | + "mm": 1e3, |
| 313 | + "m": 1e6, |
| 314 | + } |
| 315 | + match = re.match(r"([\d\.]+)\s*([a-zA-Z]+)", str_units) |
| 316 | + numeric_value = float(match.group(1)) |
| 317 | + unit = match.group(2) |
| 318 | + return float(numeric_value * unit_conversions[unit]) |
| 319 | + |
| 320 | + |
| 321 | +def get_wavelengths_to_plot( |
| 322 | + statements: StatementDictionary, num_samples: int = 100 |
| 323 | +) -> tuple[list[float], list[float]]: |
| 324 | + """ |
| 325 | + Get the wavelengths to plot based on the statements. |
| 326 | +
|
| 327 | + Returns a list of wavelengths to plot the spectra and a list of vertical lines to plot on top the spectra. |
| 328 | + """ |
| 329 | + |
| 330 | + min_wl = np.inf |
| 331 | + max_wl = -np.inf |
| 332 | + vlines = set() |
| 333 | + for stmt in statements.parameter_constraints + statements.cost_functions: |
| 334 | + if stmt.formalization is not None and stmt.formalization.mapping is not None: |
| 335 | + for comp in stmt.formalization.mapping.values(): |
| 336 | + if "wavelengths" in comp.arguments: |
| 337 | + vlines = vlines | { |
| 338 | + _str_units_to_float(wl) for wl in comp.arguments["wavelengths"] |
| 339 | + } |
| 340 | + if "wavelength_range" in comp.arguments: |
| 341 | + min_wl = min(min_wl, _str_units_to_float(comp.arguments["wavelength_range"][0])) |
| 342 | + max_wl = max(max_wl, _str_units_to_float(comp.arguments["wavelength_range"][1])) |
| 343 | + if vlines: |
| 344 | + min_wl = min(min_wl, min(vlines)) |
| 345 | + max_wl = max(max_wl, max(vlines)) |
| 346 | + if min_wl >= max_wl: |
| 347 | + avg_wl = sum(vlines) / len(vlines) if vlines else 1550 |
| 348 | + min_wl, max_wl = avg_wl - 0.1, avg_wl + 0.1 |
| 349 | + else: |
| 350 | + range_size = max_wl - min_wl |
| 351 | + min_wl -= 0.2 * range_size |
| 352 | + max_wl += 0.2 * range_size |
| 353 | + |
| 354 | + wls = np.linspace(min_wl, max_wl, num_samples) |
| 355 | + return wls.tolist(), list(vlines) |
0 commit comments