11#
22# Class for quick plotting of variables from models
33#
4- from __future__ import annotations
54import os
65import numpy as np
76import pybamm
@@ -480,24 +479,24 @@ def reset_axis(self):
480479 ): # pragma: no cover
481480 raise ValueError (f"Axis limits cannot be NaN for variables '{ key } '" )
482481
483- def plot (self , t : float | list [ float ] , dynamic : bool = False ):
482+ def plot (self , t , dynamic = False ):
484483 """Produces a quick plot with the internal states at time t.
485484
486485 Parameters
487486 ----------
488- t : float or list of float
489- Dimensional time (in 'time_units') at which to plot. Can be a single time or a list of times.
487+ t : float
488+ Dimensional time (in 'time_units') at which to plot.
490489 dynamic : bool, optional
491490 Determine whether to allocate space for a slider at the bottom of the plot when generating a dynamic plot.
492491 If True, creates a dynamic plot with a slider.
493492 """
494493
495494 plt = import_optional_dependency ("matplotlib.pyplot" )
496495 gridspec = import_optional_dependency ("matplotlib.gridspec" )
496+ cm = import_optional_dependency ("matplotlib" , "cm" )
497+ colors = import_optional_dependency ("matplotlib" , "colors" )
497498
498- if not isinstance (t , list ):
499- t = [t ]
500-
499+ t_in_seconds = t * self .time_scaling_factor
501500 self .fig = plt .figure (figsize = self .figsize )
502501
503502 self .gridspec = gridspec .GridSpec (self .n_rows , self .n_cols )
@@ -509,11 +508,6 @@ def plot(self, t: float | list[float], dynamic: bool = False):
509508 # initialize empty handles, to be created only if the appropriate plots are made
510509 solution_handles = []
511510
512- # Generate distinct colors for each time point
513- time_colors = plt .cm .coolwarm (
514- np .linspace (0 , 1 , len (t ))
515- ) # Use a colormap for distinct colors
516-
517511 for k , (key , variable_lists ) in enumerate (self .variables .items ()):
518512 ax = self .fig .add_subplot (self .gridspec [k ])
519513 self .axes .add (key , ax )
@@ -524,17 +518,19 @@ def plot(self, t: float | list[float], dynamic: bool = False):
524518 ax .xaxis .set_major_locator (plt .MaxNLocator (3 ))
525519 self .plots [key ] = defaultdict (dict )
526520 variable_handles = []
527-
521+ # Set labels for the first subplot only (avoid repetition)
528522 if variable_lists [0 ][0 ].dimensions == 0 :
529- # 0D plot: plot as a function of time, indicating multiple times with lines
523+ # 0D plot: plot as a function of time, indicating time t with a line
530524 ax .set_xlabel (f"Time [{ self .time_unit } ]" )
531525 for i , variable_list in enumerate (variable_lists ):
532526 for j , variable in enumerate (variable_list ):
533- linestyle = (
534- self .linestyles [i ]
535- if len (variable_list ) == 1
536- else self .linestyles [j ]
537- )
527+ if len (variable_list ) == 1 :
528+ # single variable -> use linestyle to differentiate model
529+ linestyle = self .linestyles [i ]
530+ else :
531+ # multiple variables -> use linestyle to differentiate
532+ # variables (color differentiates models)
533+ linestyle = self .linestyles [j ]
538534 full_t = self .ts_seconds [i ]
539535 (self .plots [key ][i ][j ],) = ax .plot (
540536 full_t / self .time_scaling_factor ,
@@ -546,104 +542,128 @@ def plot(self, t: float | list[float], dynamic: bool = False):
546542 solution_handles .append (self .plots [key ][i ][0 ])
547543 y_min , y_max = ax .get_ylim ()
548544 ax .set_ylim (y_min , y_max )
549-
550- # Add vertical lines for each time in the list, using different colors for each time
551- for idx , t_single in enumerate (t ):
552- t_in_seconds = t_single * self .time_scaling_factor
553- (self .time_lines [key ],) = ax .plot (
554- [
555- t_in_seconds / self .time_scaling_factor ,
556- t_in_seconds / self .time_scaling_factor ,
557- ],
558- [y_min , y_max ],
559- "--" , # Dashed lines
560- lw = 1.5 ,
561- color = time_colors [idx ], # Different color for each time
562- label = f"t = { t_single :.2f} { self .time_unit } " ,
563- )
564- ax .legend ()
565-
545+ (self .time_lines [key ],) = ax .plot (
546+ [
547+ t_in_seconds / self .time_scaling_factor ,
548+ t_in_seconds / self .time_scaling_factor ,
549+ ],
550+ [y_min , y_max ],
551+ "k--" ,
552+ lw = 1.5 ,
553+ )
566554 elif variable_lists [0 ][0 ].dimensions == 1 :
567- # 1D plot: plot as a function of x at different times
555+ # 1D plot: plot as a function of x at time t
556+ # Read dictionary of spatial variables
568557 spatial_vars = self .spatial_variable_dict [key ]
569558 spatial_var_name = next (iter (spatial_vars .keys ()))
570- ax .set_xlabel (f"{ spatial_var_name } [{ self .spatial_unit } ]" )
571-
572- for idx , t_single in enumerate (t ):
573- t_in_seconds = t_single * self .time_scaling_factor
574-
575- for i , variable_list in enumerate (variable_lists ):
576- for j , variable in enumerate (variable_list ):
577- linestyle = (
578- self .linestyles [i ]
579- if len (variable_list ) == 1
580- else self .linestyles [j ]
581- )
582- (self .plots [key ][i ][j ],) = ax .plot (
583- self .first_spatial_variable [key ],
584- variable (t_in_seconds , ** spatial_vars ),
585- color = time_colors [idx ], # Different color for each time
586- linestyle = linestyle ,
587- label = f"t = { t_single :.2f} { self .time_unit } " , # Add time label
588- zorder = 10 ,
589- )
590- variable_handles .append (self .plots [key ][0 ][j ])
591- solution_handles .append (self .plots [key ][i ][0 ])
592-
593- # Add a legend to indicate which plot corresponds to which time
594- ax .legend ()
595-
559+ ax .set_xlabel (
560+ f"{ spatial_var_name } [{ self .spatial_unit } ]" ,
561+ )
562+ for i , variable_list in enumerate (variable_lists ):
563+ for j , variable in enumerate (variable_list ):
564+ if len (variable_list ) == 1 :
565+ # single variable -> use linestyle to differentiate model
566+ linestyle = self .linestyles [i ]
567+ else :
568+ # multiple variables -> use linestyle to differentiate
569+ # variables (color differentiates models)
570+ linestyle = self .linestyles [j ]
571+ (self .plots [key ][i ][j ],) = ax .plot (
572+ self .first_spatial_variable [key ],
573+ variable (t_in_seconds , ** spatial_vars ),
574+ color = self .colors [i ],
575+ linestyle = linestyle ,
576+ zorder = 10 ,
577+ )
578+ variable_handles .append (self .plots [key ][0 ][j ])
579+ solution_handles .append (self .plots [key ][i ][0 ])
580+ # add lines for boundaries between subdomains
581+ for boundary in variable_lists [0 ][0 ].internal_boundaries :
582+ boundary_scaled = boundary * self .spatial_factor
583+ ax .axvline (boundary_scaled , color = "0.5" , lw = 1 , zorder = 0 )
596584 elif variable_lists [0 ][0 ].dimensions == 2 :
597- # 2D plot: superimpose plots at different times
585+ # Read dictionary of spatial variables
598586 spatial_vars = self .spatial_variable_dict [key ]
587+ # there can only be one entry in the variable list
599588 variable = variable_lists [0 ][0 ]
600-
601- for t_single in t :
602- t_in_seconds = t_single * self .time_scaling_factor
589+ # different order based on whether the domains are x-r, x-z or y-z, etc
590+ if self .x_first_and_y_second [key ] is False :
591+ x_name = list (spatial_vars .keys ())[1 ][0 ]
592+ y_name = next (iter (spatial_vars .keys ()))[0 ]
593+ x = self .second_spatial_variable [key ]
594+ y = self .first_spatial_variable [key ]
595+ var = variable (t_in_seconds , ** spatial_vars )
596+ else :
597+ x_name = next (iter (spatial_vars .keys ()))[0 ]
598+ y_name = list (spatial_vars .keys ())[1 ][0 ]
603599 x = self .first_spatial_variable [key ]
604600 y = self .second_spatial_variable [key ]
605601 var = variable (t_in_seconds , ** spatial_vars ).T
606-
607- ax .set_xlabel (
608- f"{ next (iter (spatial_vars .keys ()))[0 ]} [{ self .spatial_unit } ]"
609- )
610- ax .set_ylabel (
611- f"{ list (spatial_vars .keys ())[1 ][0 ]} [{ self .spatial_unit } ]"
602+ ax .set_xlabel (f"{ x_name } [{ self .spatial_unit } ]" )
603+ ax .set_ylabel (f"{ y_name } [{ self .spatial_unit } ]" )
604+ vmin , vmax = self .variable_limits [key ]
605+ # store the plot and the var data (for testing) as cant access
606+ # z data from QuadMesh or QuadContourSet object
607+ if self .is_y_z [key ] is True :
608+ self .plots [key ][0 ][0 ] = ax .pcolormesh (
609+ x ,
610+ y ,
611+ var ,
612+ vmin = vmin ,
613+ vmax = vmax ,
614+ shading = self .shading ,
612615 )
613- vmin , vmax = self .variable_limits [key ]
614-
615- # Use contourf and colorbars to represent the values
616- contour_plot = ax .contourf (
617- x , y , var , levels = 100 , vmin = vmin , vmax = vmax , cmap = "coolwarm"
616+ else :
617+ self .plots [key ][0 ][0 ] = ax .contourf (
618+ x , y , var , levels = 100 , vmin = vmin , vmax = vmax
618619 )
619- self .plots [key ][0 ][0 ] = contour_plot
620- self .colorbars [key ] = self .fig .colorbar (contour_plot , ax = ax )
621-
622- self .plots [key ][0 ][1 ] = var
623-
624- ax .set_title (f"t = { t_single :.2f} { self .time_unit } " )
620+ self .plots [key ][0 ][1 ] = var
621+ if vmin is None and vmax is None :
622+ vmin = ax_min (var )
623+ vmax = ax_max (var )
624+ self .colorbars [key ] = self .fig .colorbar (
625+ cm .ScalarMappable (colors .Normalize (vmin = vmin , vmax = vmax )),
626+ ax = ax ,
627+ )
628+ # Set either y label or legend entries
629+ if len (key ) == 1 :
630+ title = split_long_string (key [0 ])
631+ ax .set_title (title , fontsize = "medium" )
632+ else :
633+ ax .legend (
634+ variable_handles ,
635+ [split_long_string (s , 6 ) for s in key ],
636+ bbox_to_anchor = (0.5 , 1 ),
637+ loc = "lower center" ,
638+ )
625639
626- # Set global legend if there are multiple models
640+ # Set global legend
627641 if len (self .labels ) > 1 :
628642 fig_legend = self .fig .legend (
629643 solution_handles , self .labels , loc = "lower right"
630644 )
645+ # Get the position of the top of the legend in relative figure units
646+ # There may be a better way ...
647+ try :
648+ legend_top_inches = fig_legend .get_window_extent (
649+ renderer = self .fig .canvas .get_renderer ()
650+ ).get_points ()[1 , 1 ]
651+ fig_height_inches = (self .fig .get_size_inches () * self .fig .dpi )[1 ]
652+ legend_top = legend_top_inches / fig_height_inches
653+ except AttributeError : # pragma: no cover
654+ # When testing the examples we set the matplotlib backend to "Template"
655+ # which means that the above code doesn't work. Since this is just for
656+ # that particular test we can just skip it
657+ legend_top = 0
631658 else :
632- fig_legend = None
659+ legend_top = 0
633660
634- # Fix layout for sliders if dynamic
661+ # Fix layout
635662 if dynamic :
636663 slider_top = 0.05
637664 else :
638665 slider_top = 0
639- bottom = max (
640- fig_legend .get_window_extent (
641- renderer = self .fig .canvas .get_renderer ()
642- ).get_points ()[1 , 1 ]
643- if fig_legend
644- else 0 ,
645- slider_top ,
646- )
666+ bottom = max (legend_top , slider_top )
647667 self .gridspec .tight_layout (self .fig , rect = [0 , bottom , 1 , 1 ])
648668
649669 def dynamic_plot (self , show_plot = True , step = None ):
0 commit comments