|
| 1 | +import re |
| 2 | +import numpy as np # type: ignore |
1 | 3 | import iklayout # type: ignore |
2 | 4 | import matplotlib.pyplot as plt # type: ignore |
3 | 5 | from ipywidgets import interactive, IntSlider # type: ignore |
4 | 6 | from typing import List, Optional |
5 | 7 |
|
6 | | -from . import Parameter, StatementDictionary, StatementValidationDictionary, StatementValidation |
| 8 | +from . import Parameter, StatementDictionary, StatementValidationDictionary, StatementValidation, Computation |
7 | 9 |
|
8 | 10 |
|
9 | 11 | def plot_circuit(component): |
@@ -301,3 +303,65 @@ def print_statements(statements: StatementDictionary, validation: Optional[State |
301 | 303 | print("Statement:", unf_stmt.text) |
302 | 304 | print("Formalization: UNFORMALIZABLE") |
303 | 305 | print("\n-----------------------------------\n") |
| 306 | + |
| 307 | + |
| 308 | +def _str_units_to_float(str_units: str) -> float: |
| 309 | + unit_conversions = { |
| 310 | + "nm": 1e-3, |
| 311 | + "um": 1, |
| 312 | + "mm": 1e3, |
| 313 | + "m": 1e6, |
| 314 | + } |
| 315 | + match = re.match(r"([\d\.]+)\s*([a-zA-Z]+)", str_units) |
| 316 | + numeric_value = float(match.group(1) if match else 1.55) |
| 317 | + unit = match.group(2) if match else "um" |
| 318 | + return float(numeric_value * unit_conversions[unit]) |
| 319 | + |
| 320 | + |
| 321 | +def get_wavelengths_to_plot( |
| 322 | + statements: StatementDictionary, num_samples: int = 100 |
| 323 | +) -> tuple[list[float], list[float]]: |
| 324 | + """ |
| 325 | + Get the wavelengths to plot based on the statements. |
| 326 | +
|
| 327 | + Returns a list of wavelengths to plot the spectra and a list of vertical lines to plot on top the spectra. |
| 328 | + """ |
| 329 | + |
| 330 | + min_wl = np.inf |
| 331 | + max_wl = -np.inf |
| 332 | + vlines: set = set() |
| 333 | + |
| 334 | + def update_wavelengths(mapping: dict[str, Optional[Computation]]): |
| 335 | + for comp in mapping.values(): |
| 336 | + if comp is None: |
| 337 | + continue |
| 338 | + if "wavelengths" in comp.arguments: |
| 339 | + vlines = vlines | { |
| 340 | + _str_units_to_float(wl) for wl in (comp.arguments["wavelengths"] if isinstance(comp.arguments["wavelengths"], list) else []) if isinstance(wl, str) |
| 341 | + } |
| 342 | + if "wavelength_range" in comp.arguments: |
| 343 | + if isinstance(comp.arguments["wavelength_range"], list) and len(comp.arguments["wavelength_range"]) == 2 and all(isinstance(wl, str) for wl in comp.arguments["wavelength_range"]): |
| 344 | + min_wl = min(min_wl, _str_units_to_float(comp.arguments["wavelength_range"][0])) |
| 345 | + max_wl = max(max_wl, _str_units_to_float(comp.arguments["wavelength_range"][1])) |
| 346 | + |
| 347 | + for cost_stmt in statements.cost_functions or []: |
| 348 | + if cost_stmt.formalization is not None and cost_stmt.formalization.mapping is not None: |
| 349 | + update_wavelengths(cost_stmt.formalization.mapping) |
| 350 | + |
| 351 | + for param_stmt in statements.parameter_constraints or []: |
| 352 | + if param_stmt.formalization is not None and param_stmt.formalization.mapping is not None: |
| 353 | + update_wavelengths(param_stmt.formalization.mapping) |
| 354 | + |
| 355 | + if vlines: |
| 356 | + min_wl = min(min_wl, min(vlines)) |
| 357 | + max_wl = max(max_wl, max(vlines)) |
| 358 | + if min_wl >= max_wl: |
| 359 | + avg_wl = sum(vlines) / len(vlines) if vlines else 1550 |
| 360 | + min_wl, max_wl = avg_wl - 0.1, avg_wl + 0.1 |
| 361 | + else: |
| 362 | + range_size = max_wl - min_wl |
| 363 | + min_wl -= 0.2 * range_size |
| 364 | + max_wl += 0.2 * range_size |
| 365 | + |
| 366 | + wls = np.linspace(min_wl, max_wl, num_samples) |
| 367 | + return wls.tolist(), list(vlines) |
0 commit comments