11import re
2- import numpy as np # type: ignore
2+ from functools import reduce
3+ from operator import getitem
4+ from typing import Dict , List , Optional , Set , Tuple
5+
36import iklayout # type: ignore
47import matplotlib .pyplot as plt # type: ignore
5- import plotly . graph_objects as go # type: ignore
6- from typing import List , Optional , Tuple , Dict , Set
8+ import numpy as np # type: ignore
9+ import plotly . graph_objects as go # type: ignore
710
8- from . import Parameter , StatementDictionary , StatementValidationDictionary , StatementValidation , Computation
11+ from . import Computation , Parameter , StatementDictionary , StatementValidation , StatementValidationDictionary
912
1013
1114def plot_circuit (component ):
@@ -56,9 +59,7 @@ def plot_constraints(
5659 labels: List of labels for each constraint value.
5760 """
5861
59- constraints_labels = constraints_labels or [
60- f"Constraint { i } " for i in range (len (constraints [0 ]))
61- ]
62+ constraints_labels = constraints_labels or [f"Constraint { i } " for i in range (len (constraints [0 ]))]
6263 iterations = iterations or list (range (len (constraints [0 ])))
6364
6465 plt .clf ()
@@ -92,13 +93,9 @@ def plot_single_spectrum(
9293 plt .ylabel ("Losses" )
9394 plt .plot (wavelengths , spectrum )
9495 for x_val in vlines :
95- plt .axvline (
96- x = x_val , color = "red" , linestyle = "--" , label = f"Wavelength (x={ x_val } )"
97- ) # Add vertical line
96+ plt .axvline (x = x_val , color = "red" , linestyle = "--" , label = f"Wavelength (x={ x_val } )" ) # Add vertical line
9897 for y_val in hlines :
99- plt .axhline (
100- y = y_val , color = "red" , linestyle = "--" , label = f"Transmission (y={ y_val } )"
101- ) # Add vertical line
98+ plt .axhline (y = y_val , color = "red" , linestyle = "--" , label = f"Transmission (y={ y_val } )" ) # Add vertical line
10299 return plt .gcf ()
103100
104101
@@ -109,7 +106,7 @@ def plot_interactive_spectra(
109106 vlines : Optional [List [float ]] = None ,
110107 hlines : Optional [List [float ]] = None ,
111108):
112- """"
109+ """ "
113110 Creates an interactive plot of spectra with a slider to select different indices.
114111 Parameters:
115112 -----------
@@ -131,7 +128,7 @@ def plot_interactive_spectra(
131128 vlines = []
132129 if hlines is None :
133130 hlines = []
134-
131+
135132 # Adjust y-axis range
136133 all_vals = [val for spec in spectra for iteration in spec for val in iteration ]
137134 y_min = min (all_vals )
@@ -143,49 +140,28 @@ def plot_interactive_spectra(
143140 # Create hlines and vlines
144141 shapes = []
145142 for xv in vlines :
146- shapes .append (dict (
147- type = "line" ,
148- xref = "x" , x0 = xv , x1 = xv ,
149- yref = "paper" , y0 = 0 , y1 = 1 ,
150- line = dict (color = "red" , dash = "dash" )
151- ))
143+ shapes .append (
144+ dict (type = "line" , xref = "x" , x0 = xv , x1 = xv , yref = "paper" , y0 = 0 , y1 = 1 , line = dict (color = "red" , dash = "dash" ))
145+ )
152146 for yh in hlines :
153- shapes .append (dict (
154- type = "line" ,
155- xref = "paper" , x0 = 0 , x1 = 1 ,
156- yref = "y" , y0 = yh , y1 = yh ,
157- line = dict (color = "red" , dash = "dash" )
158- ))
159-
160-
147+ shapes .append (
148+ dict (type = "line" , xref = "paper" , x0 = 0 , x1 = 1 , yref = "y" , y0 = yh , y1 = yh , line = dict (color = "red" , dash = "dash" ))
149+ )
150+
161151 # Create frames for each index
162152 slider_index = list (range (len (spectra [0 ])))
163153 fig = go .Figure ()
164154
165155 # Build initial figure for immediate display
166156 init_idx = slider_index [0 ]
167157 for i , spec in enumerate (spectra ):
168- fig .add_trace (
169- go .Scatter (
170- x = wavelengths ,
171- y = spec [init_idx ],
172- mode = "lines" ,
173- name = spectrum_labels [i ]
174- )
175- )
158+ fig .add_trace (go .Scatter (x = wavelengths , y = spec [init_idx ], mode = "lines" , name = spectrum_labels [i ]))
176159 # Build frames for animation
177160 frames = []
178161 for idx in slider_index :
179162 frame_data = []
180163 for i , spec in enumerate (spectra ):
181- frame_data .append (
182- go .Scatter (
183- x = wavelengths ,
184- y = spec [idx ],
185- mode = "lines" ,
186- name = spectrum_labels [i ]
187- )
188- )
164+ frame_data .append (go .Scatter (x = wavelengths , y = spec [idx ], mode = "lines" , name = spectrum_labels [i ]))
189165 frames .append (
190166 go .Frame (
191167 data = frame_data ,
@@ -195,30 +171,22 @@ def plot_interactive_spectra(
195171
196172 fig .frames = frames
197173
198-
199174 # Create transition steps
200175 steps = []
201176 for idx in slider_index :
202- steps .append (dict (
203- method = "animate" ,
204- args = [
205- [str (idx )],
206- {
207- "mode" : "immediate" ,
208- "frame" : {"duration" : 0 , "redraw" : True },
209- "transition" : {"duration" : 0 }
210- }
211- ],
212- label = str (idx ),
213- ))
177+ steps .append (
178+ dict (
179+ method = "animate" ,
180+ args = [
181+ [str (idx )],
182+ {"mode" : "immediate" , "frame" : {"duration" : 0 , "redraw" : True }, "transition" : {"duration" : 0 }},
183+ ],
184+ label = str (idx ),
185+ )
186+ )
214187
215188 # Create the slider
216- sliders = [dict (
217- active = 0 ,
218- currentvalue = {"prefix" : "Index: " },
219- pad = {"t" : 50 },
220- steps = steps
221- )]
189+ sliders = [dict (active = 0 , currentvalue = {"prefix" : "Index: " }, pad = {"t" : 50 }, steps = steps )]
222190
223191 # Create the layout
224192 fig .update_layout (
@@ -253,25 +221,28 @@ def plot_parameter_history(parameters: List[Parameter], parameter_history: List[
253221 plt .xlabel ("Iterations" )
254222 plt .ylabel (param .path )
255223 split_param = param .path .split ("," )
256- plt .plot (
257- [
258- parameter_history [i ][split_param [0 ]][split_param [1 ]]
259- for i in range (len (parameter_history ))
260- ]
261- )
224+ plt .plot ([reduce (getitem , split_param , parameter_history [i ]) for i in range (len (parameter_history ))])
262225 plt .show ()
263226
264227
265- def print_statements (statements : StatementDictionary , validation : Optional [StatementValidationDictionary ] = None , only_formalized : bool = False ):
228+ def print_statements (
229+ statements : StatementDictionary ,
230+ validation : Optional [StatementValidationDictionary ] = None ,
231+ only_formalized : bool = False ,
232+ ):
266233 """
267234 Print a list of statements in nice readable format.
268235 """
269236
270237 validation = StatementValidationDictionary (
271- cost_functions = (validation .cost_functions if validation is not None else None ) or [StatementValidation ()]* len (statements .cost_functions or []),
272- parameter_constraints = (validation .parameter_constraints if validation is not None else None ) or [StatementValidation ()]* len (statements .parameter_constraints or []),
273- structure_constraints = (validation .structure_constraints if validation is not None else None ) or [StatementValidation ()]* len (statements .structure_constraints or []),
274- unformalizable_statements = (validation .unformalizable_statements if validation is not None else None ) or [StatementValidation ()]* len (statements .unformalizable_statements or [])
238+ cost_functions = (validation .cost_functions if validation is not None else None )
239+ or [StatementValidation ()] * len (statements .cost_functions or []),
240+ parameter_constraints = (validation .parameter_constraints if validation is not None else None )
241+ or [StatementValidation ()] * len (statements .parameter_constraints or []),
242+ structure_constraints = (validation .structure_constraints if validation is not None else None )
243+ or [StatementValidation ()] * len (statements .structure_constraints or []),
244+ unformalizable_statements = (validation .unformalizable_statements if validation is not None else None )
245+ or [StatementValidation ()] * len (statements .unformalizable_statements or []),
275246 )
276247
277248 if len (validation .cost_functions or []) != len (statements .cost_functions or []):
@@ -299,8 +270,7 @@ def print_statements(statements: StatementDictionary, validation: Optional[State
299270 if computation is not None :
300271 args_str = ", " .join (
301272 [
302- f"{ argname } ="
303- + (f"'{ argvalue } '" if isinstance (argvalue , str ) else str (argvalue ))
273+ f"{ argname } =" + (f"'{ argvalue } '" if isinstance (argvalue , str ) else str (argvalue ))
304274 for argname , argvalue in computation .arguments .items ()
305275 ]
306276 )
@@ -326,8 +296,7 @@ def print_statements(statements: StatementDictionary, validation: Optional[State
326296 if computation is not None :
327297 args_str = ", " .join (
328298 [
329- f"{ argname } ="
330- + (f"'{ argvalue } '" if isinstance (argvalue , str ) else str (argvalue ))
299+ f"{ argname } =" + (f"'{ argvalue } '" if isinstance (argvalue , str ) else str (argvalue ))
331300 for argname , argvalue in computation .arguments .items ()
332301 ]
333302 )
@@ -382,9 +351,7 @@ def _str_units_to_float(str_units: str) -> float:
382351 return float (numeric_value * unit_conversions [unit ])
383352
384353
385- def get_wavelengths_to_plot (
386- statements : StatementDictionary , num_samples : int = 100
387- ) -> Tuple [List [float ], List [float ]]:
354+ def get_wavelengths_to_plot (statements : StatementDictionary , num_samples : int = 100 ) -> Tuple [List [float ], List [float ]]:
388355 """
389356 Get the wavelengths to plot based on the statements.
390357
@@ -401,10 +368,16 @@ def update_wavelengths(mapping: Dict[str, Optional[Computation]], min_wl: float,
401368 continue
402369 if "wavelengths" in comp .arguments :
403370 vlines = vlines | {
404- _str_units_to_float (wl ) for wl in (comp .arguments ["wavelengths" ] if isinstance (comp .arguments ["wavelengths" ], list ) else []) if isinstance (wl , str )
371+ _str_units_to_float (wl )
372+ for wl in (comp .arguments ["wavelengths" ] if isinstance (comp .arguments ["wavelengths" ], list ) else [])
373+ if isinstance (wl , str )
405374 }
406375 if "wavelength_range" in comp .arguments :
407- if isinstance (comp .arguments ["wavelength_range" ], list ) and len (comp .arguments ["wavelength_range" ]) == 2 and all (isinstance (wl , str ) for wl in comp .arguments ["wavelength_range" ]):
376+ if (
377+ isinstance (comp .arguments ["wavelength_range" ], list )
378+ and len (comp .arguments ["wavelength_range" ]) == 2
379+ and all (isinstance (wl , str ) for wl in comp .arguments ["wavelength_range" ])
380+ ):
408381 min_wl = min (min_wl , _str_units_to_float (comp .arguments ["wavelength_range" ][0 ]))
409382 max_wl = max (max_wl , _str_units_to_float (comp .arguments ["wavelength_range" ][1 ]))
410383 return min_wl , max_wl , vlines
0 commit comments