@@ -33,6 +33,7 @@ def __init__(self, **kwargs):
3333 self ._width = kwargs .get ("width" , "17cm" )
3434 self ._ratio = kwargs .get ("ratio" , "golden" )
3535 self ._gridspec_kw = kwargs .get ("gridspec_kw" , {"wspace" : 0.08 , "hspace" : 0.1 })
36+ self ._plotted = False
3637
3738 # Dictionary to store lines for each subplot
3839 # Key: (row, col), Value: list of lines with their data and kwargs
@@ -126,6 +127,7 @@ def savefig(
126127 layers = None ,
127128 layer_by_layer = False ,
128129 verbose = False ,
130+ plot = True ,
129131 ):
130132 filename_no_extension , extension = os .path .splitext (filename )
131133 if backend == "matplotlib" :
@@ -145,10 +147,16 @@ def savefig(
145147 full_filepath = filename
146148 else :
147149 full_filepath = f"{ filename_no_extension } _{ layers } .{ extension } "
148- fig , axs = self .plot (
149- show = False , backend = "matplotlib" , savefig = True , layers = layers
150- )
151- fig .savefig (full_filepath )
150+ # print(f"Save to {full_filepath}")
151+ if self ._plotted :
152+ self ._matplotlib_fig .savefig (full_filepath )
153+ else :
154+
155+ fig , axs = self .plot (
156+ show = False , backend = "matplotlib" , savefig = True , layers = layers
157+ )
158+ # print('done plotting')
159+ fig .savefig (full_filepath )
152160 if verbose :
153161 print (f"Saved { full_filepath } " )
154162
@@ -158,16 +166,15 @@ def plot(self, backend="matplotlib", show=True, savefig=False, layers=None):
158166 elif backend == "plotly" :
159167 self .plot_plotly (show = show , savefig = savefig )
160168
161- def plot_matplotlib (self , show = True , savefig = False , layers = None ):
169+ def plot_matplotlib (self , show = True , savefig = False , layers = None , usetex = False ):
162170 """
163171 Generate and optionally display the subplots.
164172
165173 Parameters:
166174 filename (str, optional): Filename to save the figure.
167175 show (bool): Whether to display the plot.
168176 """
169-
170- tex_fonts = plt_utils .setup_tex_fonts (fontsize = self .fontsize )
177+ tex_fonts = plt_utils .setup_tex_fonts (fontsize = self .fontsize , usetex = usetex )
171178
172179 plt_utils .setup_plotstyle (
173180 tex_fonts = tex_fonts ,
@@ -208,9 +215,12 @@ def plot_matplotlib(self, show=True, savefig=False, layers=None):
208215 plt .show ()
209216 # else:
210217 # plt.close()
218+ self ._plotted = True
219+ self ._matplotlib_fig = fig
220+ self ._matplotlib_axes = axes
211221 return fig , axes
212222
213- def plot_plotly (self , show = True , savefig = None ):
223+ def plot_plotly (self , show = True , savefig = None , usetex = False ):
214224 """
215225 Generate and optionally display the subplots using Plotly.
216226
@@ -220,7 +230,8 @@ def plot_plotly(self, show=True, savefig=None):
220230 """
221231
222232 tex_fonts = plt_utils .setup_tex_fonts (
223- fontsize = self .fontsize
233+ fontsize = self .fontsize ,
234+ usetex = usetex ,
224235 ) # adjust or redefine for Plotly if needed
225236
226237 # Set default width and height if not specified
0 commit comments