Skip to content

Commit d20a64a

Browse files
committed
hotfix plot
1 parent 43f00ad commit d20a64a

File tree

1 file changed

+57
-84
lines changed

1 file changed

+57
-84
lines changed

src/axiomatic/pic_helpers.py

Lines changed: 57 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import re
2-
import numpy as np # type: ignore
2+
from functools import reduce
3+
from operator import getitem
4+
from typing import Dict, List, Optional, Set, Tuple
5+
36
import iklayout # type: ignore
47
import matplotlib.pyplot as plt # type: ignore
5-
import plotly.graph_objects as go # type: ignore
6-
from typing import List, Optional, Tuple, Dict, Set
8+
import numpy as np # type: ignore
9+
import plotly.graph_objects as go # type: ignore
710

8-
from . import Parameter, StatementDictionary, StatementValidationDictionary, StatementValidation, Computation
11+
from . import Computation, Parameter, StatementDictionary, StatementValidation, StatementValidationDictionary
912

1013

1114
def plot_circuit(component):
@@ -56,9 +59,7 @@ def plot_constraints(
5659
labels: List of labels for each constraint value.
5760
"""
5861

59-
constraints_labels = constraints_labels or [
60-
f"Constraint {i}" for i in range(len(constraints[0]))
61-
]
62+
constraints_labels = constraints_labels or [f"Constraint {i}" for i in range(len(constraints[0]))]
6263
iterations = iterations or list(range(len(constraints[0])))
6364

6465
plt.clf()
@@ -92,13 +93,9 @@ def plot_single_spectrum(
9293
plt.ylabel("Losses")
9394
plt.plot(wavelengths, spectrum)
9495
for x_val in vlines:
95-
plt.axvline(
96-
x=x_val, color="red", linestyle="--", label=f"Wavelength (x={x_val})"
97-
) # Add vertical line
96+
plt.axvline(x=x_val, color="red", linestyle="--", label=f"Wavelength (x={x_val})") # Add vertical line
9897
for y_val in hlines:
99-
plt.axhline(
100-
y=y_val, color="red", linestyle="--", label=f"Transmission (y={y_val})"
101-
) # Add vertical line
98+
plt.axhline(y=y_val, color="red", linestyle="--", label=f"Transmission (y={y_val})") # Add vertical line
10299
return plt.gcf()
103100

104101

@@ -109,7 +106,7 @@ def plot_interactive_spectra(
109106
vlines: Optional[List[float]] = None,
110107
hlines: Optional[List[float]] = None,
111108
):
112-
""""
109+
""" "
113110
Creates an interactive plot of spectra with a slider to select different indices.
114111
Parameters:
115112
-----------
@@ -131,7 +128,7 @@ def plot_interactive_spectra(
131128
vlines = []
132129
if hlines is None:
133130
hlines = []
134-
131+
135132
# Adjust y-axis range
136133
all_vals = [val for spec in spectra for iteration in spec for val in iteration]
137134
y_min = min(all_vals)
@@ -143,49 +140,28 @@ def plot_interactive_spectra(
143140
# Create hlines and vlines
144141
shapes = []
145142
for xv in vlines:
146-
shapes.append(dict(
147-
type="line",
148-
xref="x", x0=xv, x1=xv,
149-
yref="paper", y0=0, y1=1,
150-
line=dict(color="red", dash="dash")
151-
))
143+
shapes.append(
144+
dict(type="line", xref="x", x0=xv, x1=xv, yref="paper", y0=0, y1=1, line=dict(color="red", dash="dash"))
145+
)
152146
for yh in hlines:
153-
shapes.append(dict(
154-
type="line",
155-
xref="paper", x0=0, x1=1,
156-
yref="y", y0=yh, y1=yh,
157-
line=dict(color="red", dash="dash")
158-
))
159-
160-
147+
shapes.append(
148+
dict(type="line", xref="paper", x0=0, x1=1, yref="y", y0=yh, y1=yh, line=dict(color="red", dash="dash"))
149+
)
150+
161151
# Create frames for each index
162152
slider_index = list(range(len(spectra[0])))
163153
fig = go.Figure()
164154

165155
# Build initial figure for immediate display
166156
init_idx = slider_index[0]
167157
for i, spec in enumerate(spectra):
168-
fig.add_trace(
169-
go.Scatter(
170-
x=wavelengths,
171-
y=spec[init_idx],
172-
mode="lines",
173-
name=spectrum_labels[i]
174-
)
175-
)
158+
fig.add_trace(go.Scatter(x=wavelengths, y=spec[init_idx], mode="lines", name=spectrum_labels[i]))
176159
# Build frames for animation
177160
frames = []
178161
for idx in slider_index:
179162
frame_data = []
180163
for i, spec in enumerate(spectra):
181-
frame_data.append(
182-
go.Scatter(
183-
x=wavelengths,
184-
y=spec[idx],
185-
mode="lines",
186-
name=spectrum_labels[i]
187-
)
188-
)
164+
frame_data.append(go.Scatter(x=wavelengths, y=spec[idx], mode="lines", name=spectrum_labels[i]))
189165
frames.append(
190166
go.Frame(
191167
data=frame_data,
@@ -195,30 +171,22 @@ def plot_interactive_spectra(
195171

196172
fig.frames = frames
197173

198-
199174
# Create transition steps
200175
steps = []
201176
for idx in slider_index:
202-
steps.append(dict(
203-
method="animate",
204-
args=[
205-
[str(idx)],
206-
{
207-
"mode": "immediate",
208-
"frame": {"duration": 0, "redraw": True},
209-
"transition": {"duration": 0}
210-
}
211-
],
212-
label=str(idx),
213-
))
177+
steps.append(
178+
dict(
179+
method="animate",
180+
args=[
181+
[str(idx)],
182+
{"mode": "immediate", "frame": {"duration": 0, "redraw": True}, "transition": {"duration": 0}},
183+
],
184+
label=str(idx),
185+
)
186+
)
214187

215188
# Create the slider
216-
sliders = [dict(
217-
active=0,
218-
currentvalue={"prefix": "Index: "},
219-
pad={"t": 50},
220-
steps=steps
221-
)]
189+
sliders = [dict(active=0, currentvalue={"prefix": "Index: "}, pad={"t": 50}, steps=steps)]
222190

223191
# Create the layout
224192
fig.update_layout(
@@ -253,25 +221,28 @@ def plot_parameter_history(parameters: List[Parameter], parameter_history: List[
253221
plt.xlabel("Iterations")
254222
plt.ylabel(param.path)
255223
split_param = param.path.split(",")
256-
plt.plot(
257-
[
258-
parameter_history[i][split_param[0]][split_param[1]]
259-
for i in range(len(parameter_history))
260-
]
261-
)
224+
plt.plot([reduce(getitem, split_param, parameter_history[i]) for i in range(len(parameter_history))])
262225
plt.show()
263226

264227

265-
def print_statements(statements: StatementDictionary, validation: Optional[StatementValidationDictionary] = None, only_formalized: bool = False):
228+
def print_statements(
229+
statements: StatementDictionary,
230+
validation: Optional[StatementValidationDictionary] = None,
231+
only_formalized: bool = False,
232+
):
266233
"""
267234
Print a list of statements in nice readable format.
268235
"""
269236

270237
validation = StatementValidationDictionary(
271-
cost_functions=(validation.cost_functions if validation is not None else None) or [StatementValidation()]*len(statements.cost_functions or []),
272-
parameter_constraints=(validation.parameter_constraints if validation is not None else None) or [StatementValidation()]*len(statements.parameter_constraints or []),
273-
structure_constraints=(validation.structure_constraints if validation is not None else None) or [StatementValidation()]*len(statements.structure_constraints or []),
274-
unformalizable_statements=(validation.unformalizable_statements if validation is not None else None) or [StatementValidation()]*len(statements.unformalizable_statements or [])
238+
cost_functions=(validation.cost_functions if validation is not None else None)
239+
or [StatementValidation()] * len(statements.cost_functions or []),
240+
parameter_constraints=(validation.parameter_constraints if validation is not None else None)
241+
or [StatementValidation()] * len(statements.parameter_constraints or []),
242+
structure_constraints=(validation.structure_constraints if validation is not None else None)
243+
or [StatementValidation()] * len(statements.structure_constraints or []),
244+
unformalizable_statements=(validation.unformalizable_statements if validation is not None else None)
245+
or [StatementValidation()] * len(statements.unformalizable_statements or []),
275246
)
276247

277248
if len(validation.cost_functions or []) != len(statements.cost_functions or []):
@@ -299,8 +270,7 @@ def print_statements(statements: StatementDictionary, validation: Optional[State
299270
if computation is not None:
300271
args_str = ", ".join(
301272
[
302-
f"{argname}="
303-
+ (f"'{argvalue}'" if isinstance(argvalue, str) else str(argvalue))
273+
f"{argname}=" + (f"'{argvalue}'" if isinstance(argvalue, str) else str(argvalue))
304274
for argname, argvalue in computation.arguments.items()
305275
]
306276
)
@@ -326,8 +296,7 @@ def print_statements(statements: StatementDictionary, validation: Optional[State
326296
if computation is not None:
327297
args_str = ", ".join(
328298
[
329-
f"{argname}="
330-
+ (f"'{argvalue}'" if isinstance(argvalue, str) else str(argvalue))
299+
f"{argname}=" + (f"'{argvalue}'" if isinstance(argvalue, str) else str(argvalue))
331300
for argname, argvalue in computation.arguments.items()
332301
]
333302
)
@@ -382,9 +351,7 @@ def _str_units_to_float(str_units: str) -> float:
382351
return float(numeric_value * unit_conversions[unit])
383352

384353

385-
def get_wavelengths_to_plot(
386-
statements: StatementDictionary, num_samples: int = 100
387-
) -> Tuple[List[float], List[float]]:
354+
def get_wavelengths_to_plot(statements: StatementDictionary, num_samples: int = 100) -> Tuple[List[float], List[float]]:
388355
"""
389356
Get the wavelengths to plot based on the statements.
390357
@@ -401,10 +368,16 @@ def update_wavelengths(mapping: Dict[str, Optional[Computation]], min_wl: float,
401368
continue
402369
if "wavelengths" in comp.arguments:
403370
vlines = vlines | {
404-
_str_units_to_float(wl) for wl in (comp.arguments["wavelengths"] if isinstance(comp.arguments["wavelengths"], list) else []) if isinstance(wl, str)
371+
_str_units_to_float(wl)
372+
for wl in (comp.arguments["wavelengths"] if isinstance(comp.arguments["wavelengths"], list) else [])
373+
if isinstance(wl, str)
405374
}
406375
if "wavelength_range" in comp.arguments:
407-
if isinstance(comp.arguments["wavelength_range"], list) and len(comp.arguments["wavelength_range"]) == 2 and all(isinstance(wl, str) for wl in comp.arguments["wavelength_range"]):
376+
if (
377+
isinstance(comp.arguments["wavelength_range"], list)
378+
and len(comp.arguments["wavelength_range"]) == 2
379+
and all(isinstance(wl, str) for wl in comp.arguments["wavelength_range"])
380+
):
408381
min_wl = min(min_wl, _str_units_to_float(comp.arguments["wavelength_range"][0]))
409382
max_wl = max(max_wl, _str_units_to_float(comp.arguments["wavelength_range"][1]))
410383
return min_wl, max_wl, vlines

0 commit comments

Comments
 (0)