11import matplotlib .pyplot as plt
2- import maxplotlib .subfigure .line_plot as lp
3- import maxplotlib .subfigure .tikz_figure as tf
4- import maxplotlib .backends .matplotlib .utils as plt_utils
52import plotly .graph_objects as go
63from plotly .subplots import make_subplots
4+
5+ import maxplotlib .backends .matplotlib .utils as plt_utils
6+ import maxplotlib .subfigure .line_plot as lp
7+ import maxplotlib .subfigure .tikz_figure as tf
8+
9+
710class Canvas :
811 def __init__ (self , ** kwargs ):
912 """
@@ -15,27 +18,27 @@ def __init__(self, **kwargs):
1518 figsize (tuple): Figure size.
1619 """
1720
18- #nrows=1, ncols=1, caption=None, description=None, label=None, figsize=None
19- self ._nrows = kwargs .get (' nrows' , 1 )
20- self ._ncols = kwargs .get (' ncols' , 1 )
21- self ._figsize = kwargs .get (' figsize' , None )
22- self ._caption = kwargs .get (' caption' , None )
23- self ._description = kwargs .get (' description' , None )
24- self ._label = kwargs .get (' label' , None )
25-
26- self ._dpi = kwargs .get (' dpi' , 300 )
27- self ._width = kwargs .get (' width' , 426.79135 )
28- self ._ratio = kwargs .get (' ratio' , "golden" )
29- self ._gridspec_kw = kwargs .get (' gridspec_kw' , {"wspace" : 0.08 , "hspace" : 0.1 })
30-
21+ # nrows=1, ncols=1, caption=None, description=None, label=None, figsize=None
22+ self ._nrows = kwargs .get (" nrows" , 1 )
23+ self ._ncols = kwargs .get (" ncols" , 1 )
24+ self ._figsize = kwargs .get (" figsize" , None )
25+ self ._caption = kwargs .get (" caption" , None )
26+ self ._description = kwargs .get (" description" , None )
27+ self ._label = kwargs .get (" label" , None )
28+
29+ self ._dpi = kwargs .get (" dpi" , 300 )
30+ self ._width = kwargs .get (" width" , 426.79135 )
31+ self ._ratio = kwargs .get (" ratio" , "golden" )
32+ self ._gridspec_kw = kwargs .get (" gridspec_kw" , {"wspace" : 0.08 , "hspace" : 0.1 })
33+
3134 # Dictionary to store lines for each subplot
3235 # Key: (row, col), Value: list of lines with their data and kwargs
3336 self .subplots = {}
3437 self ._num_subplots = 0
3538
3639 self ._subplot_matrix = [[None ] * self .ncols for _ in range (self .nrows )]
37-
38- def generate_new_rowcol (self ,row ,col ):
40+
41+ def generate_new_rowcol (self , row , col ):
3942 if row is None :
4043 for irow in range (self .nrows ):
4144 has_none = any (item is None for item in self ._subplot_matrix [irow ])
@@ -62,9 +65,9 @@ def add_tikzfigure(self, **kwargs):
6265 - row (int): Row index for the subplot.
6366 - label (str): Label to identify the subplot.
6467 """
65- col = kwargs .get (' col' , None )
66- row = kwargs .get (' row' , None )
67- label = kwargs .get (' label' , None )
68+ col = kwargs .get (" col" , None )
69+ row = kwargs .get (" row" , None )
70+ label = kwargs .get (" label" , None )
6871
6972 row , col = self .generate_new_rowcol (row , col )
7073
@@ -89,12 +92,12 @@ def add_subplot(self, **kwargs):
8992 - row (int): Row index for the subplot.
9093 - label (str): Label to identify the subplot.
9194 """
92- col = kwargs .get (' col' , None )
93- row = kwargs .get (' row' , None )
94- label = kwargs .get (' label' , None )
95+ col = kwargs .get (" col" , None )
96+ row = kwargs .get (" row" , None )
97+ label = kwargs .get (" label" , None )
9598
9699 row , col = self .generate_new_rowcol (row , col )
97-
100+
98101 # Initialize the LinePlot for the given subplot position
99102 line_plot = lp .LinePlot (** kwargs )
100103 self ._subplot_matrix [row ][col ] = line_plot
@@ -105,17 +108,20 @@ def add_subplot(self, **kwargs):
105108 else :
106109 self .subplots [label ] = line_plot
107110 return line_plot
108- def savefig (self , filename , backend = 'matplotlib' ):
109- if backend == 'matplotlib' :
110- fig , axs = self .plot (show = False , backend = 'matplotlib' , savefig = True )
111+
112+ def savefig (self , filename , backend = "matplotlib" ):
113+ if backend == "matplotlib" :
114+ fig , axs = self .plot (show = False , backend = "matplotlib" , savefig = True )
111115 fig .savefig (filename )
116+
112117 # def add_line(self, label, x_data, y_data, **kwargs):
113118
114- def plot (self , backend = ' matplotlib' , show = True , savefig = False ):
115- if backend == ' matplotlib' :
119+ def plot (self , backend = " matplotlib" , show = True , savefig = False ):
120+ if backend == " matplotlib" :
116121 return self .plot_matplotlib (show = show , savefig = savefig )
117- elif backend == ' plotly' :
122+ elif backend == " plotly" :
118123 self .plot_plotly (show = show , savefig = savefig )
124+
119125 def plot_matplotlib (self , show = True , savefig = False ):
120126 """
121127 Generate and optionally display the subplots.
@@ -137,18 +143,26 @@ def plot_matplotlib(self, show=True, savefig=False):
137143 if self ._figsize is not None :
138144 fig_width , fig_height = self ._figsize
139145 else :
140- fig_width , fig_height = plt_utils .set_size (width = self ._width , ratio = self ._ratio )
141-
142- fig , axes = plt .subplots (self .nrows , self .ncols , figsize = (fig_width / self ._dpi , fig_height / self ._dpi ), squeeze = False , dpi = self ._dpi )
143-
146+ fig_width , fig_height = plt_utils .set_size (
147+ width = self ._width , ratio = self ._ratio
148+ )
149+
150+ fig , axes = plt .subplots (
151+ self .nrows ,
152+ self .ncols ,
153+ figsize = (fig_width / self ._dpi , fig_height / self ._dpi ),
154+ squeeze = False ,
155+ dpi = self ._dpi ,
156+ )
157+
144158 for (row , col ), subplot in self .subplots .items ():
145159 ax = axes [row ][col ]
146160 subplot .plot_matplotlib (ax )
147161 # ax.set_title(f"Subplot ({row}, {col})")
148162
149163 # Set caption, labels, etc., if needed
150164 plt .tight_layout ()
151-
165+
152166 if show :
153167 plt .show ()
154168 else :
@@ -164,18 +178,26 @@ def plot_plotly(self, show=True, savefig=None):
164178 savefig (str, optional): Filename to save the figure if provided.
165179 """
166180 fontsize = 14
167- tex_fonts = plt_utils .setup_tex_fonts (fontsize = fontsize ) # adjust or redefine for Plotly if needed
181+ tex_fonts = plt_utils .setup_tex_fonts (
182+ fontsize = fontsize
183+ ) # adjust or redefine for Plotly if needed
168184
169185 # Set default width and height if not specified
170186 if self ._figsize is not None :
171187 fig_width , fig_height = self ._figsize
172188 else :
173- fig_width , fig_height = plt_utils .set_size (width = self ._width , ratio = self ._ratio )
189+ fig_width , fig_height = plt_utils .set_size (
190+ width = self ._width , ratio = self ._ratio
191+ )
174192 print (self ._width , fig_width , fig_height )
175193 # Create subplots
176- fig = make_subplots (rows = self .nrows , cols = self .ncols , subplot_titles = [
177- f"Subplot ({ row } , { col } )" for (row , col ) in self .subplots .keys ()
178- ])
194+ fig = make_subplots (
195+ rows = self .nrows ,
196+ cols = self .ncols ,
197+ subplot_titles = [
198+ f"Subplot ({ row } , { col } )" for (row , col ) in self .subplots .keys ()
199+ ],
200+ )
179201
180202 # Plot each subplot
181203 for (row , col ), line_plot in self .subplots .items ():
@@ -200,7 +222,6 @@ def plot_plotly(self, show=True, savefig=None):
200222 fig .show ()
201223 return fig
202224
203-
204225 # Property getters
205226
206227 @property
@@ -239,7 +260,7 @@ def subplot_matrix(self):
239260 @nrows .setter
240261 def dpi (self , value ):
241262 self ._dpi = value
242-
263+
243264 @nrows .setter
244265 def nrows (self , value ):
245266 self ._nrows = value
@@ -354,11 +375,11 @@ def __setitem__(self, key, value):
354375 # latex_code += "\\caption{Multiple Subplots}\n"
355376 # latex_code += "\\end{figure}\n"
356377 # return latex_code
357-
358378
359- if __name__ == '__main__' :
360- c = Canvas (ncols = 2 ,nrows = 2 )
379+
380+ if __name__ == "__main__" :
381+ c = Canvas (ncols = 2 , nrows = 2 )
361382 sp = c .add_subplot ()
362383 sp .add_line ("Line 1" , [0 , 1 , 2 , 3 ], [0 , 1 , 4 , 9 ])
363384 c .plot ()
364- print (' done' )
385+ print (" done" )
0 commit comments