11import matplotlib .pyplot as plt
22import maxplotlib .subfigure .line_plot as lp
33import maxplotlib .backends .matplotlib .utils as plt_utils
4+ import plotly .graph_objects as go
5+ from plotly .subplots import make_subplots
46class Canvas :
5- def __init__ (self , nrows = 1 , ncols = 1 , caption = None , description = None , label = None , figsize = ( 10 , 6 ) ):
7+ def __init__ (self , ** kwargs ):
68 """
79 Initialize the Canvas class for multiple subplots.
810
@@ -11,94 +13,25 @@ def __init__(self, nrows=1, ncols=1, caption=None, description=None, label=None,
1113 ncols (int): Number of subplot columns. Default is 1.
1214 figsize (tuple): Figure size.
1315 """
14- self ._nrows = nrows
15- self ._ncols = ncols
16- self ._figsize = figsize
17- self ._caption = caption
18- self ._description = description
19- self ._label = label
16+
17+ #nrows=1, ncols=1, caption=None, description=None, label=None, figsize=None
18+ self ._nrows = kwargs .get ('nrows' , 1 )
19+ self ._ncols = kwargs .get ('ncols' , 1 )
20+ self ._figsize = kwargs .get ('figsize' , None )
21+ self ._caption = kwargs .get ('caption' , None )
22+ self ._description = kwargs .get ('description' , None )
23+ self ._label = kwargs .get ('label' , None )
24+
25+ self ._width = kwargs .get ('width' , 426.79135 )
26+ self ._ratio = kwargs .get ('ratio' , "golden" )
27+ self ._gridspec_kw = kwargs .get ('gridspec_kw' , {"wspace" : 0.08 , "hspace" : 0.1 })
2028
2129 # Dictionary to store lines for each subplot
2230 # Key: (row, col), Value: list of lines with their data and kwargs
2331 self .subplots = {}
2432 self ._num_subplots = 0
2533
26- self ._subplot_matrix = [[None ] * ncols for _ in range (nrows )]
27-
28- # Property getters
29- @property
30- def nrows (self ):
31- return self ._nrows
32-
33- @property
34- def ncols (self ):
35- return self ._ncols
36-
37- @property
38- def caption (self ):
39- return self ._caption
40-
41- @property
42- def description (self ):
43- return self ._description
44-
45- @property
46- def label (self ):
47- return self ._label
48-
49- @property
50- def figsize (self ):
51- return self ._figsize
52-
53- @property
54- def subplot_matrix (self ):
55- return self ._subplot_matrix
56-
57- # Property setters
58- @nrows .setter
59- def nrows (self , value ):
60- self ._nrows = value
61-
62- @ncols .setter
63- def ncols (self , value ):
64- self ._ncols = value
65-
66- @caption .setter
67- def caption (self , value ):
68- self ._caption = value
69-
70- @description .setter
71- def description (self , value ):
72- self ._description = value
73-
74- @label .setter
75- def label (self , value ):
76- self ._label = value
77-
78- @figsize .setter
79- def figsize (self , value ):
80- self ._figsize = value
81-
82- # Magic methods
83- def __str__ (self ):
84- return f"Canvas(nrows={ self .nrows } , ncols={ self .ncols } , figsize={ self .figsize } )"
85-
86- def __repr__ (self ):
87- return f"Canvas(nrows={ self .nrows } , ncols={ self .ncols } , caption={ self .caption } , label={ self .label } )"
88-
89- def __getitem__ (self , key ):
90- """Allows accessing subplots by tuple index."""
91- row , col = key
92- if row >= self .nrows or col >= self .ncols :
93- raise IndexError ("Subplot index out of range" )
94- return self ._subplot_matrix [row ][col ]
95-
96- def __setitem__ (self , key , value ):
97- """Allows setting a subplot by tuple index."""
98- row , col = key
99- if row >= self .nrows or col >= self .ncols :
100- raise IndexError ("Subplot index out of range" )
101- self ._subplot_matrix [row ][col ] = value
34+ self ._subplot_matrix = [[None ] * self .ncols for _ in range (self .nrows )]
10235
10336 def add_subplot (self , ** kwargs ):
10437 """
@@ -141,14 +74,15 @@ def add_subplot(self, **kwargs):
14174 return line_plot
14275 def savefig (self , filename , backend = 'matplotlib' ):
14376 if backend == 'matplotlib' :
144- fig = self .plot (show = False , savefig = True )
77+ fig , axs = self .plot (show = False , backend = 'matplotlib' , savefig = True )
14578 fig .savefig (filename )
146- #plt.savefig(filename)
14779 # def add_line(self, label, x_data, y_data, **kwargs):
14880
14981 def plot (self , backend = 'matplotlib' , show = True , savefig = False ):
15082 if backend == 'matplotlib' :
15183 return self .plot_matplotlib (show = show , savefig = savefig )
84+ elif backend == 'plotly' :
85+ self .plot_plotly (show = show , savefig = savefig )
15286 def plot_matplotlib (self , show = True , savefig = False ):
15387 """
15488 Generate and optionally display the subplots.
@@ -166,7 +100,13 @@ def plot_matplotlib(self, show=True, savefig=False):
166100 grid_alpha = 1.0 ,
167101 grid_linestyle = "dotted" ,
168102 )
169- fig , axes = plt .subplots (self .nrows , self .ncols , figsize = self .figsize , squeeze = False )
103+
104+ if self ._figsize is not None :
105+ fig_width , fig_height = self ._figsize
106+ else :
107+ fig_width , fig_height = plt_utils .set_size (width = self ._width , ratio = self ._ratio )
108+
109+ fig , axes = plt .subplots (self .nrows , self .ncols , figsize = (fig_width , fig_height ), squeeze = False )
170110
171111 for (row , col ), line_plot in self .subplots .items ():
172112 ax = axes [row ][col ]
@@ -180,8 +120,129 @@ def plot_matplotlib(self, show=True, savefig=False):
180120 plt .show ()
181121 else :
182122 plt .close ()
123+ return fig , axes
124+
125+ def plot_plotly (self , show = True , savefig = None ):
126+ """
127+ Generate and optionally display the subplots using Plotly.
128+
129+ Parameters:
130+ show (bool): Whether to display the plot.
131+ savefig (str, optional): Filename to save the figure if provided.
132+ """
133+ fontsize = 14
134+ tex_fonts = plt_utils .setup_tex_fonts (fontsize = fontsize ) # adjust or redefine for Plotly if needed
135+
136+ # Set default width and height if not specified
137+ if self ._figsize is not None :
138+ fig_width , fig_height = self ._figsize
139+ else :
140+ fig_width , fig_height = plt_utils .set_size (width = self ._width , ratio = self ._ratio )
141+
142+ # Create subplots
143+ fig = make_subplots (rows = self .nrows , cols = self .ncols , subplot_titles = [
144+ f"Subplot ({ row } , { col } )" for (row , col ) in self .subplots .keys ()
145+ ])
146+
147+ # Plot each subplot
148+ for (row , col ), line_plot in self .subplots .items ():
149+ traces = line_plot .plot_plotly () # Generate Plotly traces for the line_plot
150+ for trace in traces :
151+ fig .add_trace (trace , row = row + 1 , col = col + 1 )
152+
153+ # Update layout settings
154+ fig .update_layout (
155+ width = fig_width ,
156+ height = fig_height ,
157+ font = dict (size = fontsize ),
158+ margin = dict (l = 10 , r = 10 , t = 40 , b = 10 ), # Adjust margins if needed
159+ )
160+
161+ # Optionally save the figure
162+ if savefig :
163+ fig .write_image (savefig )
164+
165+ # Show or return the figure
166+ if show :
167+ fig .show ()
183168 return fig
184169
170+
171+ # Property getters
172+ @property
173+ def nrows (self ):
174+ return self ._nrows
175+
176+ @property
177+ def ncols (self ):
178+ return self ._ncols
179+
180+ @property
181+ def caption (self ):
182+ return self ._caption
183+
184+ @property
185+ def description (self ):
186+ return self ._description
187+
188+ @property
189+ def label (self ):
190+ return self ._label
191+
192+ @property
193+ def figsize (self ):
194+ return self ._figsize
195+
196+ @property
197+ def subplot_matrix (self ):
198+ return self ._subplot_matrix
199+
200+ # Property setters
201+ @nrows .setter
202+ def nrows (self , value ):
203+ self ._nrows = value
204+
205+ @ncols .setter
206+ def ncols (self , value ):
207+ self ._ncols = value
208+
209+ @caption .setter
210+ def caption (self , value ):
211+ self ._caption = value
212+
213+ @description .setter
214+ def description (self , value ):
215+ self ._description = value
216+
217+ @label .setter
218+ def label (self , value ):
219+ self ._label = value
220+
221+ @figsize .setter
222+ def figsize (self , value ):
223+ self ._figsize = value
224+
225+ # Magic methods
226+ def __str__ (self ):
227+ return f"Canvas(nrows={ self .nrows } , ncols={ self .ncols } , figsize={ self .figsize } )"
228+
229+ def __repr__ (self ):
230+ return f"Canvas(nrows={ self .nrows } , ncols={ self .ncols } , caption={ self .caption } , label={ self .label } )"
231+
232+ def __getitem__ (self , key ):
233+ """Allows accessing subplots by tuple index."""
234+ row , col = key
235+ if row >= self .nrows or col >= self .ncols :
236+ raise IndexError ("Subplot index out of range" )
237+ return self ._subplot_matrix [row ][col ]
238+
239+ def __setitem__ (self , key , value ):
240+ """Allows setting a subplot by tuple index."""
241+ row , col = key
242+ if row >= self .nrows or col >= self .ncols :
243+ raise IndexError ("Subplot index out of range" )
244+ self ._subplot_matrix [row ][col ] = value
245+
185246 # def generate_matplotlib_code(self):
186247 # """Generate code for plotting the data using matplotlib."""
187248 # code = "import matplotlib.pyplot as plt\n\n"
0 commit comments