@@ -37,11 +37,15 @@ def __init__(self, **kwargs):
3737
3838 # Dictionary to store lines for each subplot
3939 # Key: (row, col), Value: list of lines with their data and kwargs
40- self .subplots = {}
40+ self ._subplots = {}
4141 self ._num_subplots = 0
4242
4343 self ._subplot_matrix = [[None ] * self .ncols for _ in range (self .nrows )]
4444
45+ @property
46+ def subplots (self ):
47+ return self ._subplots
48+
4549 @property
4650 def layers (self ):
4751 layers = []
@@ -88,12 +92,12 @@ def add_tikzfigure(self, **kwargs):
8892
8993 # Store the LinePlot instance by its position for easy access
9094 if label is None :
91- self .subplots [(row , col )] = tikz_figure
95+ self ._subplots [(row , col )] = tikz_figure
9296 else :
93- self .subplots [label ] = tikz_figure
97+ self ._subplots [label ] = tikz_figure
9498 return tikz_figure
9599
96- def add_subplot (self , ** kwargs ):
100+ def add_subplot (self , col : int | None = None , row : int | None = None , label : str | None = None , ** kwargs ):
97101 """
98102 Adds a subplot to the figure.
99103
@@ -103,21 +107,19 @@ def add_subplot(self, **kwargs):
103107 - row (int): Row index for the subplot.
104108 - label (str): Label to identify the subplot.
105109 """
106- col = kwargs .get ("col" , None )
107- row = kwargs .get ("row" , None )
108- label = kwargs .get ("label" , None )
110+
109111
110112 row , col = self .generate_new_rowcol (row , col )
111113
112114 # Initialize the LinePlot for the given subplot position
113- line_plot = lp .LinePlot (** kwargs )
115+ line_plot = lp .LinePlot (col = col , row = row , label = label , ** kwargs )
114116 self ._subplot_matrix [row ][col ] = line_plot
115117
116118 # Store the LinePlot instance by its position for easy access
117119 if label is None :
118- self .subplots [(row , col )] = line_plot
120+ self ._subplots [(row , col )] = line_plot
119121 else :
120- self .subplots [label ] = line_plot
122+ self ._subplots [label ] = line_plot
121123 return line_plot
122124
123125 def savefig (
@@ -159,19 +161,31 @@ def savefig(
159161 if verbose :
160162 print (f"Saved { full_filepath } " )
161163
162- def plot (self , backend = "matplotlib" , show = True , savefig = False , layers = None ):
164+ def plot (self , backend = "matplotlib" , savefig = False , layers = None ):
165+ if backend == "matplotlib" :
166+ return self .plot_matplotlib (savefig = savefig , layers = layers )
167+ elif backend == "plotly" :
168+ return self .plot_plotly (savefig = savefig )
169+ else :
170+ raise ValueError ("Invalid backend" )
171+
172+ def show (self , backend = "matplotlib" ):
163173 if backend == "matplotlib" :
164- return self .plot_matplotlib (show = show , savefig = savefig , layers = layers )
174+ fig , axs = self .plot (backend = "matplotlib" , savefig = False , layers = None )
175+ print ('hmm' )
176+ self ._matplotlib_fig .show ()
177+ plt .show ()
165178 elif backend == "plotly" :
166- self .plot_plotly (show = show , savefig = savefig )
179+ plot = self .plot_plotly (savefig = False )
180+ else :
181+ raise ValueError ("Invalid backend" )
167182
168- def plot_matplotlib (self , show = True , savefig = False , layers = None , usetex = False ):
183+ def plot_matplotlib (self , savefig = False , layers = None , usetex = False ):
169184 """
170185 Generate and optionally display the subplots.
171186
172187 Parameters:
173188 filename (str, optional): Filename to save the figure.
174- show (bool): Whether to display the plot.
175189 """
176190
177191 tex_fonts = plt_utils .setup_tex_fonts (fontsize = self .fontsize , usetex = usetex )
@@ -205,17 +219,11 @@ def plot_matplotlib(self, show=True, savefig=False, layers=None, usetex=False):
205219
206220 for (row , col ), subplot in self .subplots .items ():
207221 ax = axes [row ][col ]
208- # print(f"{subplot = }")
209222 subplot .plot_matplotlib (ax , layers = layers )
210223 # ax.set_title(f"Subplot ({row}, {col})")
211224 ax .grid ()
212- # Set caption, labels, etc., if needed
213- # plt.tight_layout()
214225
215- if show :
216- plt .show ()
217- # else:
218- # plt.close()
226+ # Set caption, labels, etc., if needed
219227 self ._plotted = True
220228 self ._matplotlib_fig = fig
221229 self ._matplotlib_axes = axes
@@ -271,8 +279,8 @@ def plot_plotly(self, show=True, savefig=None, usetex=False):
271279 fig .write_image (savefig )
272280
273281 # Show or return the figure
274- if show :
275- fig .show ()
282+ # if show:
283+ # fig.show()
276284 return fig
277285
278286 # Property getters
0 commit comments