@@ -222,16 +222,19 @@ def plot_results(
222222 """
223223 Generate a plot of 'E fraction' vs each input variable from
224224 self.simulate_from_product(...) and variable names at.
225- Optionally, a set of variables can be specified via parameter ' plot_keys: set' .
226- Defaults to plot all available and .
225+ Optionally, a set of variables can be specified via parameter `` plot_keys`` .
226+ Defaults to plot all available and ``relative_airmass`` .
227227 """
228228 start_time = time () # Initialize start time of block
229+ # cast plot_keys to set of strings to plot E fraction against
229230 if plot_keys is None : # default to add relative_airmass
230231 plot_keys = {"relative_airmass" , * self .input_keys }
231232 elif isinstance (plot_keys , str ):
232233 plot_keys = {
233234 plot_keys ,
234- } # cast to set
235+ }
236+ elif not isinstance (plot_keys , set ):
237+ plot_keys = set (plot_keys )
235238
236239 # variable guard: only allow valid keys:
237240 # * self.input_keys & self.time_params
@@ -268,23 +271,26 @@ def plot_results(
268271
269272 # for each axes, plot a relationship
270273 # Case: time
271- for ax , var_name in zip (axs , plot_keys .intersection ({"datetime" })):
274+ for var_name in plot_keys .intersection ({"datetime" }):
275+ ax = next (axs )
272276 ax .set_title (r"$\frac{E_{λ<λ_0}}{E}$ vs. " + var_name )
273277 x = self .datetimes if var_name == "datetime" else None
274278 for _ , row in self .results .iloc [n_inputs :].iterrows ():
275279 ax .scatter (x , row [n_inputs :])
276280 plot_keys .remove (var_name )
277281
278282 # Case: time-dependant variables in plot_keys
279- for ax , var_name in zip (axs , plot_keys .intersection (self .time_params .keys ())):
283+ for var_name in plot_keys .intersection (self .time_params .keys ()):
284+ ax = next (axs )
280285 ax .set_title (r"$\frac{E_{λ<λ_0}}{E}$ vs. " + var_name )
281286 x = self .time_params [var_name ]
282287 for _ , row in self .results .iloc [n_inputs :].iterrows ():
283288 ax .scatter (x , row [n_inputs :])
284289 plot_keys .remove (var_name )
285290
286291 # Case: SPECTRL2 generator input parameters
287- for ax , var_name in zip (axs , plot_keys ):
292+ for var_name in plot_keys :
293+ ax = next (axs )
288294 ax .set_title (r"$\frac{E_{λ<λ_0}}{E}$ vs. " + var_name )
289295 x = self .results [var_name ]
290296 y_df = self .results .iloc [:, n_inputs :]
@@ -297,5 +303,6 @@ def plot_results(
297303 + datetime .now ().strftime ("%Y-%m-%dT%H-%M-%S" )
298304 + ".png"
299305 )
306+ plt .close ()
300307
301308 self .processing_time ["plot_results" ] = time () - start_time
0 commit comments