Skip to content

Commit 18c78d5

Browse files
committed
Added plotly backend
1 parent 8aeb8f0 commit 18c78d5

File tree

9 files changed

+401
-182
lines changed

9 files changed

+401
-182
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,4 +183,4 @@ cython_debug/
183183
env*
184184

185185
# VS code
186-
.vscode/
186+
.vscode/*

.vscode/settings.json

Lines changed: 0 additions & 7 deletions
This file was deleted.

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@ classifiers = [
1616
]
1717
dependencies = [
1818
"matplotlib",
19+
"plotly",
1920
"pytest",
2021
"black",
2122
"isort",
23+
"jupyterlab",
2224
]
2325

2426
[project.optional-dependencies]

src/maxplotlib/backends/matplotlib/utils.py

Lines changed: 1 addition & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -108,63 +108,4 @@ def create_lineplot(
108108
constrained_layout=False,
109109
gridspec_kw=gridspec_kw,
110110
)
111-
return fig, axs
112-
113-
def create_3dplot(width=426.79135, dpi=300, ratio="golden"):
114-
"""
115-
Creates a 3D plot figure and axis.
116-
"""
117-
fig_width, fig_height = set_size(width, ratio=ratio)
118-
fig = plt.figure(figsize=(fig_width, fig_height), dpi=dpi)
119-
ax = fig.add_subplot(111, projection="3d")
120-
return fig, ax
121-
122-
def set_common_xlabel(fig, xlabel="common X", fontsize=14):
123-
"""
124-
Sets a common X label for the figure.
125-
"""
126-
fig.text(0.5, -0.075, xlabel, va="center", ha="center", fontsize=fontsize)
127-
128-
def get_axis(axs, subfigure):
129-
"""
130-
Retrieves the specified axis from the axes array.
131-
"""
132-
if subfigure == -1:
133-
return axs
134-
elif not isinstance(subfigure, list):
135-
return axs[subfigure]
136-
elif isinstance(subfigure, list) and len(subfigure) == 2:
137-
return axs[subfigure[0], subfigure[1]]
138-
else:
139-
raise ValueError("Invalid subfigure index.")
140-
141-
def get_limits(ax=None):
142-
"""
143-
Gets the current axis limits.
144-
"""
145-
if ax is None:
146-
ax = plt.gca()
147-
xxmin, xxmax = ax.get_xlim()
148-
yymin, yymax = ax.get_ylim()
149-
return [xxmin, xxmax, yymin, yymax]
150-
151-
def set_labels(ax, delta, point, axis="x"):
152-
"""
153-
Sets custom tick labels on the specified axis.
154-
"""
155-
if axis == "x":
156-
xmin, xmax = ax.get_xlim()
157-
width = int((xmax - xmin) / delta + 1) * delta
158-
xvec = np.arange(point - width, point + width + delta, delta)
159-
xvec = xvec[(xvec >= xmin) & (xvec <= xmax)]
160-
ax.set_xticks(xvec)
161-
ax.set_xticklabels([f"{x:.2f}" for x in xvec])
162-
elif axis == "y":
163-
ymin, ymax = ax.get_ylim()
164-
width = int((ymax - ymin) / delta + 1) * delta
165-
yvec = np.arange(point - width, point + width + delta, delta)
166-
yvec = yvec[(yvec >= ymin) & (yvec <= ymax)]
167-
ax.set_yticks(yvec)
168-
ax.set_yticklabels([f"{y:.2f}" for y in yvec])
169-
else:
170-
raise ValueError("Axis must be 'x' or 'y'.")
111+
return fig, axs

src/maxplotlib/backends/plotly/__init__.py

Whitespace-only changes.

src/maxplotlib/backends/plotly/utils.py

Whitespace-only changes.

src/maxplotlib/canvas/canvas.py

Lines changed: 147 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import matplotlib.pyplot as plt
22
import maxplotlib.subfigure.line_plot as lp
33
import maxplotlib.backends.matplotlib.utils as plt_utils
4+
import plotly.graph_objects as go
5+
from plotly.subplots import make_subplots
46
class Canvas:
5-
def __init__(self, nrows=1, ncols=1, caption=None, description=None, label=None, figsize=(10, 6)):
7+
def __init__(self, **kwargs):
68
"""
79
Initialize the Canvas class for multiple subplots.
810
@@ -11,94 +13,25 @@ def __init__(self, nrows=1, ncols=1, caption=None, description=None, label=None,
1113
ncols (int): Number of subplot columns. Default is 1.
1214
figsize (tuple): Figure size.
1315
"""
14-
self._nrows = nrows
15-
self._ncols = ncols
16-
self._figsize = figsize
17-
self._caption = caption
18-
self._description = description
19-
self._label = label
16+
17+
#nrows=1, ncols=1, caption=None, description=None, label=None, figsize=None
18+
self._nrows = kwargs.get('nrows', 1)
19+
self._ncols = kwargs.get('ncols', 1)
20+
self._figsize = kwargs.get('figsize', None)
21+
self._caption = kwargs.get('caption', None)
22+
self._description = kwargs.get('description', None)
23+
self._label = kwargs.get('label', None)
24+
25+
self._width=kwargs.get('width', 426.79135)
26+
self._ratio=kwargs.get('ratio', "golden")
27+
self._gridspec_kw = kwargs.get('gridspec_kw', {"wspace": 0.08, "hspace": 0.1})
2028

2129
# Dictionary to store lines for each subplot
2230
# Key: (row, col), Value: list of lines with their data and kwargs
2331
self.subplots = {}
2432
self._num_subplots = 0
2533

26-
self._subplot_matrix = [[None] * ncols for _ in range(nrows)]
27-
28-
# Property getters
29-
@property
30-
def nrows(self):
31-
return self._nrows
32-
33-
@property
34-
def ncols(self):
35-
return self._ncols
36-
37-
@property
38-
def caption(self):
39-
return self._caption
40-
41-
@property
42-
def description(self):
43-
return self._description
44-
45-
@property
46-
def label(self):
47-
return self._label
48-
49-
@property
50-
def figsize(self):
51-
return self._figsize
52-
53-
@property
54-
def subplot_matrix(self):
55-
return self._subplot_matrix
56-
57-
# Property setters
58-
@nrows.setter
59-
def nrows(self, value):
60-
self._nrows = value
61-
62-
@ncols.setter
63-
def ncols(self, value):
64-
self._ncols = value
65-
66-
@caption.setter
67-
def caption(self, value):
68-
self._caption = value
69-
70-
@description.setter
71-
def description(self, value):
72-
self._description = value
73-
74-
@label.setter
75-
def label(self, value):
76-
self._label = value
77-
78-
@figsize.setter
79-
def figsize(self, value):
80-
self._figsize = value
81-
82-
# Magic methods
83-
def __str__(self):
84-
return f"Canvas(nrows={self.nrows}, ncols={self.ncols}, figsize={self.figsize})"
85-
86-
def __repr__(self):
87-
return f"Canvas(nrows={self.nrows}, ncols={self.ncols}, caption={self.caption}, label={self.label})"
88-
89-
def __getitem__(self, key):
90-
"""Allows accessing subplots by tuple index."""
91-
row, col = key
92-
if row >= self.nrows or col >= self.ncols:
93-
raise IndexError("Subplot index out of range")
94-
return self._subplot_matrix[row][col]
95-
96-
def __setitem__(self, key, value):
97-
"""Allows setting a subplot by tuple index."""
98-
row, col = key
99-
if row >= self.nrows or col >= self.ncols:
100-
raise IndexError("Subplot index out of range")
101-
self._subplot_matrix[row][col] = value
34+
self._subplot_matrix = [[None] * self.ncols for _ in range(self.nrows)]
10235

10336
def add_subplot(self, **kwargs):
10437
"""
@@ -141,14 +74,15 @@ def add_subplot(self, **kwargs):
14174
return line_plot
14275
def savefig(self, filename, backend = 'matplotlib'):
14376
if backend == 'matplotlib':
144-
fig = self.plot(show=False, savefig=True)
77+
fig, axs = self.plot(show=False, backend='matplotlib', savefig=True)
14578
fig.savefig(filename)
146-
#plt.savefig(filename)
14779
# def add_line(self, label, x_data, y_data, **kwargs):
14880

14981
def plot(self, backend='matplotlib', show=True, savefig=False):
15082
if backend == 'matplotlib':
15183
return self.plot_matplotlib(show=show, savefig=savefig)
84+
elif backend == 'plotly':
85+
self.plot_plotly(show=show, savefig=savefig)
15286
def plot_matplotlib(self, show=True, savefig=False):
15387
"""
15488
Generate and optionally display the subplots.
@@ -166,7 +100,13 @@ def plot_matplotlib(self, show=True, savefig=False):
166100
grid_alpha=1.0,
167101
grid_linestyle="dotted",
168102
)
169-
fig, axes = plt.subplots(self.nrows, self.ncols, figsize=self.figsize, squeeze=False)
103+
104+
if self._figsize is not None:
105+
fig_width, fig_height = self._figsize
106+
else:
107+
fig_width, fig_height = plt_utils.set_size(width=self._width, ratio=self._ratio)
108+
109+
fig, axes = plt.subplots(self.nrows, self.ncols, figsize=(fig_width, fig_height), squeeze=False)
170110

171111
for (row, col), line_plot in self.subplots.items():
172112
ax = axes[row][col]
@@ -180,8 +120,129 @@ def plot_matplotlib(self, show=True, savefig=False):
180120
plt.show()
181121
else:
182122
plt.close()
123+
return fig, axes
124+
125+
def plot_plotly(self, show=True, savefig=None):
126+
"""
127+
Generate and optionally display the subplots using Plotly.
128+
129+
Parameters:
130+
show (bool): Whether to display the plot.
131+
savefig (str, optional): Filename to save the figure if provided.
132+
"""
133+
fontsize = 14
134+
tex_fonts = plt_utils.setup_tex_fonts(fontsize=fontsize) # adjust or redefine for Plotly if needed
135+
136+
# Set default width and height if not specified
137+
if self._figsize is not None:
138+
fig_width, fig_height = self._figsize
139+
else:
140+
fig_width, fig_height = plt_utils.set_size(width=self._width, ratio=self._ratio)
141+
142+
# Create subplots
143+
fig = make_subplots(rows=self.nrows, cols=self.ncols, subplot_titles=[
144+
f"Subplot ({row}, {col})" for (row, col) in self.subplots.keys()
145+
])
146+
147+
# Plot each subplot
148+
for (row, col), line_plot in self.subplots.items():
149+
traces = line_plot.plot_plotly() # Generate Plotly traces for the line_plot
150+
for trace in traces:
151+
fig.add_trace(trace, row=row + 1, col=col + 1)
152+
153+
# Update layout settings
154+
fig.update_layout(
155+
width=fig_width,
156+
height=fig_height,
157+
font=dict(size=fontsize),
158+
margin=dict(l=10, r=10, t=40, b=10), # Adjust margins if needed
159+
)
160+
161+
# Optionally save the figure
162+
if savefig:
163+
fig.write_image(savefig)
164+
165+
# Show or return the figure
166+
if show:
167+
fig.show()
183168
return fig
184169

170+
171+
# Property getters
172+
@property
173+
def nrows(self):
174+
return self._nrows
175+
176+
@property
177+
def ncols(self):
178+
return self._ncols
179+
180+
@property
181+
def caption(self):
182+
return self._caption
183+
184+
@property
185+
def description(self):
186+
return self._description
187+
188+
@property
189+
def label(self):
190+
return self._label
191+
192+
@property
193+
def figsize(self):
194+
return self._figsize
195+
196+
@property
197+
def subplot_matrix(self):
198+
return self._subplot_matrix
199+
200+
# Property setters
201+
@nrows.setter
202+
def nrows(self, value):
203+
self._nrows = value
204+
205+
@ncols.setter
206+
def ncols(self, value):
207+
self._ncols = value
208+
209+
@caption.setter
210+
def caption(self, value):
211+
self._caption = value
212+
213+
@description.setter
214+
def description(self, value):
215+
self._description = value
216+
217+
@label.setter
218+
def label(self, value):
219+
self._label = value
220+
221+
@figsize.setter
222+
def figsize(self, value):
223+
self._figsize = value
224+
225+
# Magic methods
226+
def __str__(self):
227+
return f"Canvas(nrows={self.nrows}, ncols={self.ncols}, figsize={self.figsize})"
228+
229+
def __repr__(self):
230+
return f"Canvas(nrows={self.nrows}, ncols={self.ncols}, caption={self.caption}, label={self.label})"
231+
232+
def __getitem__(self, key):
233+
"""Allows accessing subplots by tuple index."""
234+
row, col = key
235+
if row >= self.nrows or col >= self.ncols:
236+
raise IndexError("Subplot index out of range")
237+
return self._subplot_matrix[row][col]
238+
239+
def __setitem__(self, key, value):
240+
"""Allows setting a subplot by tuple index."""
241+
row, col = key
242+
if row >= self.nrows or col >= self.ncols:
243+
raise IndexError("Subplot index out of range")
244+
self._subplot_matrix[row][col] = value
245+
185246
# def generate_matplotlib_code(self):
186247
# """Generate code for plotting the data using matplotlib."""
187248
# code = "import matplotlib.pyplot as plt\n\n"

0 commit comments

Comments
 (0)