Skip to content

Commit b588f0f

Browse files
committed
update optimization plots
1 parent 68dced0 commit b588f0f

File tree

1 file changed

+43
-11
lines changed

1 file changed

+43
-11
lines changed

src/axiomatic/pic_helpers.py

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from ipywidgets import interactive, IntSlider # type: ignore
44
from typing import List, Optional
55

6+
from . import Parameter
7+
68

79
def plot_circuit(component):
810
"""
@@ -99,7 +101,7 @@ def plot_single_spectrum(
99101

100102

101103
def plot_interactive_spectra(
102-
spectrums: List[List[List[float]]],
104+
spectra: List[List[List[float]]],
103105
wavelengths: List[float],
104106
spectrum_labels: Optional[List[str]] = None,
105107
slider_index: Optional[List[int]] = None,
@@ -110,13 +112,13 @@ def plot_interactive_spectra(
110112
Creates an interactive plot of spectra with a slider to select different indices.
111113
Parameters:
112114
-----------
113-
spectrums : list of list of float
114-
A list of spectrums, where each spectrum is a list of lists of float values, each
115+
spectra : list of list of float
116+
A list of spectra, where each spectrum is a list of lists of float values, each
115117
corresponding to the transmission of a single wavelength.
116118
wavelengths : list of float
117119
A list of wavelength values corresponding to the x-axis of the plot.
118120
slider_index : list of int, optional
119-
A list of indices for the slider. Defaults to range(len(spectrums[0])).
121+
A list of indices for the slider. Defaults to range(len(spectra[0])).
120122
vlines : list of float, optional
121123
A list of x-values where vertical lines should be drawn. Defaults to an empty list.
122124
hlines : list of float, optional
@@ -130,24 +132,24 @@ def plot_interactive_spectra(
130132
- The function uses matplotlib for plotting and ipywidgets for creating the interactive
131133
slider.
132134
- The y-axis limits are fixed based on the global minimum and maximum values across all
133-
spectrums.
135+
spectra.
134136
- Vertical and horizontal lines can be added to the plot using the `vlines` and `hlines`
135137
parameters.
136138
"""
137139
# Calculate global y-limits across all arrays
138-
y_min = min(min(min(arr2) for arr2 in arr1) for arr1 in spectrums)
139-
y_max = max(max(max(arr2) for arr2 in arr1) for arr1 in spectrums)
140+
y_min = min(min(min(arr2) for arr2 in arr1) for arr1 in spectra)
141+
y_max = max(max(max(arr2) for arr2 in arr1) for arr1 in spectra)
140142

141-
slider_index = slider_index or list(range(len(spectrums[0])))
142-
spectrum_labels = spectrum_labels or [f"Spectrum {i}" for i in range(len(spectrums))]
143+
slider_index = slider_index or list(range(len(spectra[0])))
144+
spectrum_labels = spectrum_labels or [f"Spectrum {i}" for i in range(len(spectra))]
143145
vlines = vlines or []
144146
hlines = hlines or []
145147

146148
# Function to update the plot
147149
def plot_array(index=0):
148150
plt.close("all")
149151
plt.figure(figsize=(8, 4))
150-
for i, array in enumerate(spectrums):
152+
for i, array in enumerate(spectra):
151153
plt.plot(wavelengths, array[index], lw=2, label=spectrum_labels[i])
152154
for x_val in vlines:
153155
plt.axvline(
@@ -166,6 +168,36 @@ def plot_array(index=0):
166168
plt.show()
167169

168170
slider = IntSlider(
169-
value=0, min=0, max=len(spectrums[0]) - 1, step=1, description="Index"
171+
value=0, min=0, max=len(spectra[0]) - 1, step=1, description="Index"
170172
)
171173
return interactive(plot_array, index=slider)
174+
175+
176+
def plot_parameter_history(parameters: list[Parameter], parameter_history: list[dict]):
177+
"""
178+
Plots the history of specified parameters over iterations.
179+
Args:
180+
parameters (list): A list of parameter objects, each having a 'path' attribute.
181+
parameter_history (list): A list of dictionaries containing parameter values
182+
for each iteration. Each dictionary should be
183+
structured such that the keys correspond to the
184+
first part of the parameter path, and the values
185+
are dictionaries where keys correspond to the
186+
second part of the parameter path.
187+
Returns:
188+
None: This function displays the plots and does not return any value.
189+
"""
190+
191+
for param in parameters:
192+
plt.figure(figsize=(10, 5))
193+
plt.title(f"Parameter {param.path} vs. Iterations")
194+
plt.xlabel("Iterations")
195+
plt.ylabel(param.path)
196+
split_param = param.path.split(",")
197+
plt.plot(
198+
[
199+
parameter_history[i][split_param[0]][split_param[1]]
200+
for i in range(len(parameter_history))
201+
]
202+
)
203+
plt.show()

0 commit comments

Comments
 (0)