Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/maxplotlib/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from maxplotlib.canvas.canvas import Canvas

__all__ = ["Canvas"]
97 changes: 70 additions & 27 deletions src/maxplotlib/canvas/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from plotly.subplots import make_subplots

import maxplotlib.backends.matplotlib.utils as plt_utils
import maxplotlib.subfigure.line_plot as lp
import maxplotlib.subfigure.tikz_figure as tf
from maxplotlib.subfigure.line_plot import LinePlot
from maxplotlib.subfigure.tikz_figure import TikzFigure


class Canvas:
Expand Down Expand Up @@ -37,11 +37,15 @@ def __init__(self, **kwargs):

# Dictionary to store lines for each subplot
# Key: (row, col), Value: list of lines with their data and kwargs
self.subplots = {}
self._subplots = {}
self._num_subplots = 0

self._subplot_matrix = [[None] * self.ncols for _ in range(self.nrows)]

@property
def subplots(self):
return self._subplots

@property
def layers(self):
layers = []
Expand All @@ -66,6 +70,38 @@ def generate_new_rowcol(self, row, col):
assert col is not None, "Not enough columns!"
return row, col

def add_line(
self,
x_data,
y_data,
layer=0,
subplot: LinePlot | None = None,
row: int | None = None,
col: int | None = None,
plot_type="plot",
**kwargs,
):
if row is not None and col is not None:
try:
subplot = self._subplot_matrix[row][col]
except KeyError:
raise ValueError("Invalid subplot position.")
else:
row, col = 0, 0
subplot = self._subplot_matrix[row][col]

if subplot is None:
row, col = self.generate_new_rowcol(row, col)
subplot = self.add_subplot(col=col, row=row)

subplot.add_line(
x_data=x_data,
y_data=y_data,
layer=layer,
plot_type=plot_type,
**kwargs,
)

def add_tikzfigure(self, **kwargs):
"""
Adds a subplot to the figure.
Expand All @@ -83,17 +119,23 @@ def add_tikzfigure(self, **kwargs):
row, col = self.generate_new_rowcol(row, col)

# Initialize the LinePlot for the given subplot position
tikz_figure = tf.TikzFigure(**kwargs)
tikz_figure = TikzFigure(**kwargs)
self._subplot_matrix[row][col] = tikz_figure

# Store the LinePlot instance by its position for easy access
if label is None:
self.subplots[(row, col)] = tikz_figure
self._subplots[(row, col)] = tikz_figure
else:
self.subplots[label] = tikz_figure
self._subplots[label] = tikz_figure
return tikz_figure

def add_subplot(self, **kwargs):
def add_subplot(
self,
col: int | None = None,
row: int | None = None,
label: str | None = None,
**kwargs,
):
"""
Adds a subplot to the figure.

Expand All @@ -103,21 +145,18 @@ def add_subplot(self, **kwargs):
- row (int): Row index for the subplot.
- label (str): Label to identify the subplot.
"""
col = kwargs.get("col", None)
row = kwargs.get("row", None)
label = kwargs.get("label", None)

row, col = self.generate_new_rowcol(row, col)

# Initialize the LinePlot for the given subplot position
line_plot = lp.LinePlot(**kwargs)
line_plot = LinePlot(col=col, row=row, label=label, **kwargs)
self._subplot_matrix[row][col] = line_plot

# Store the LinePlot instance by its position for easy access
if label is None:
self.subplots[(row, col)] = line_plot
self._subplots[(row, col)] = line_plot
else:
self.subplots[label] = line_plot
self._subplots[label] = line_plot
return line_plot

def savefig(
Expand Down Expand Up @@ -159,19 +198,29 @@ def savefig(
if verbose:
print(f"Saved {full_filepath}")

def plot(self, backend="matplotlib", show=True, savefig=False, layers=None):
def plot(self, backend="matplotlib", savefig=False, layers=None):
if backend == "matplotlib":
return self.plot_matplotlib(savefig=savefig, layers=layers)
elif backend == "plotly":
return self.plot_plotly(savefig=savefig)
else:
raise ValueError("Invalid backend")

def show(self, backend="matplotlib"):
if backend == "matplotlib":
return self.plot_matplotlib(show=show, savefig=savefig, layers=layers)
self.plot(backend="matplotlib", savefig=False, layers=None)
self._matplotlib_fig.show()
elif backend == "plotly":
self.plot_plotly(show=show, savefig=savefig)
plot = self.plot_plotly(savefig=False)
else:
raise ValueError("Invalid backend")

def plot_matplotlib(self, show=True, savefig=False, layers=None, usetex=False):
def plot_matplotlib(self, savefig=False, layers=None, usetex=False):
"""
Generate and optionally display the subplots.

Parameters:
filename (str, optional): Filename to save the figure.
show (bool): Whether to display the plot.
"""

tex_fonts = plt_utils.setup_tex_fonts(fontsize=self.fontsize, usetex=usetex)
Expand Down Expand Up @@ -205,17 +254,11 @@ def plot_matplotlib(self, show=True, savefig=False, layers=None, usetex=False):

for (row, col), subplot in self.subplots.items():
ax = axes[row][col]
# print(f"{subplot = }")
subplot.plot_matplotlib(ax, layers=layers)
# ax.set_title(f"Subplot ({row}, {col})")
ax.grid()
# Set caption, labels, etc., if needed
# plt.tight_layout()

if show:
plt.show()
# else:
# plt.close()
# Set caption, labels, etc., if needed
self._plotted = True
self._matplotlib_fig = fig
self._matplotlib_axes = axes
Expand Down Expand Up @@ -271,8 +314,8 @@ def plot_plotly(self, show=True, savefig=None, usetex=False):
fig.write_image(savefig)

# Show or return the figure
if show:
fig.show()
# if show:
# fig.show()
return fig

# Property getters
Expand Down
9 changes: 8 additions & 1 deletion src/maxplotlib/subfigure/line_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,14 @@ def _add(self, obj, layer):
else:
self.layered_line_data[layer] = [obj]

def add_line(self, x_data, y_data, layer=0, plot_type="plot", **kwargs):
def add_line(
self,
x_data,
y_data,
layer=0,
plot_type="plot",
**kwargs,
):
"""
Add a line to the plot.

Expand Down
43 changes: 28 additions & 15 deletions tutorials/tutorial_01.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"metadata": {},
"outputs": [],
"source": [
"import maxplotlib.canvas.canvas as canvas\n",
"from maxplotlib import Canvas\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2"
Expand All @@ -29,53 +29,66 @@
"metadata": {},
"outputs": [],
"source": [
"c = canvas.Canvas(width=\"17cm\", ratio=0.5, fontsize=12)\n",
"c = Canvas(width=\"17cm\", ratio=0.5, fontsize=12)\n",
"c.add_line([0, 1, 2, 3], [0, 1, 4, 9], label=\"Line 1\")\n",
"c.add_line([0, 1, 2, 3], [0, 2, 3, 4], linestyle=\"dashed\", color=\"red\", label=\"Line 2\")\n",
"c.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3",
"metadata": {},
"outputs": [],
"source": [
"# You can also explicitly create a subplot and add lines to it\n",
"\n",
"c = Canvas(width=\"17cm\", ratio=0.5, fontsize=12)\n",
"sp = c.add_subplot(\n",
" grid=True, xlabel=\"(x - 10) * 0.1\", ylabel=\"10y\", yscale=10, xshift=-10, xscale=0.1\n",
")\n",
"\n",
"sp.add_line([0, 1, 2, 3], [0, 1, 4, 9], label=\"Line 1\")\n",
"sp.add_line([0, 1, 2, 3], [0, 2, 3, 4], linestyle=\"dashed\", color=\"red\", label=\"Line 2\")\n",
"c.plot()\n",
"c.savefig(filename=\"tutorial_01_01.pdf\")"
"c.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3",
"id": "4",
"metadata": {},
"outputs": [],
"source": [
"c = canvas.Canvas(width=\"17cm\", ncols=2, nrows=2, ratio=0.5)\n",
"# Example with multiple subplots\n",
"\n",
"c = Canvas(width=\"17cm\", ncols=2, nrows=2, ratio=0.5)\n",
"sp = c.add_subplot(grid=True)\n",
"c.add_subplot(row=1)\n",
"sp2 = c.add_subplot(row=1, legend=False)\n",
"sp.add_line([0, 1, 2, 3], [0, 1, 4, 9], label=\"Line 1\")\n",
"sp2.add_line(\n",
" [0, 1, 2, 3], [0, 2, 3, 4], linestyle=\"dashed\", color=\"red\", label=\"Line 2\"\n",
")\n",
"c.plot(backend=\"matplotlib\")\n",
"c.plot(backend=\"plotly\")\n",
"c.savefig(filename=\"tutorial_01_02.pdf\")"
"c.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4",
"id": "5",
"metadata": {},
"outputs": [],
"source": [
"# Test with plotly backend\n",
"c = canvas.Canvas(width=\"17cm\", ratio=0.5)\n",
"c = Canvas(width=\"17cm\", ratio=0.5)\n",
"sp = c.add_subplot(\n",
" grid=True, xlabel=\"x (mm)\", ylabel=\"10y\", yscale=10, xshift=-10, xscale=0.1\n",
")\n",
"sp.add_line([0, 1, 2, 3], [0, 1, 4, 9], label=\"Line 1\", linestyle=\"-.\")\n",
"sp.add_line([0, 1, 2, 3], [0, 2, 3, 4], linestyle=\"dashed\", color=\"red\", label=\"Line 2\")\n",
"c.plot(backend=\"matplotlib\")\n",
"c.plot(backend=\"plotly\")\n",
"c.savefig(filename=\"tutorial_01_03.pdf\")"
"c.show()"
]
}
],
Expand All @@ -95,7 +108,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.3"
"version": "3.12.3"
}
},
"nbformat": 4,
Expand Down
10 changes: 5 additions & 5 deletions tutorials/tutorial_02.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"metadata": {},
"outputs": [],
"source": [
"import maxplotlib.canvas.canvas as canvas\n",
"from maxplotlib import Canvas\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2"
Expand All @@ -28,7 +28,7 @@
"metadata": {},
"outputs": [],
"source": [
"c = canvas.Canvas(width=800, ratio=0.5)\n",
"c = Canvas(width=800, ratio=0.5)\n",
"tikz = c.add_tikzfigure(grid=False)\n",
"\n",
"# Add nodes\n",
Expand Down Expand Up @@ -61,7 +61,7 @@
"metadata": {},
"outputs": [],
"source": [
"c = canvas.Canvas(ncols=2, width=\"20cm\", ratio=0.5)\n",
"c = Canvas(ncols=2, width=\"20cm\", ratio=0.5)\n",
"tikz = c.add_tikzfigure(grid=False)\n",
"\n",
"# Add nodes\n",
Expand Down Expand Up @@ -111,7 +111,7 @@
"metadata": {},
"outputs": [],
"source": [
"c = canvas.Canvas(width=800, ratio=0.5)\n",
"c = Canvas(width=800, ratio=0.5)\n",
"tikz = c.add_tikzfigure(grid=False)\n",
"\n",
"# Add nodes\n",
Expand Down Expand Up @@ -146,7 +146,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.3"
"version": "3.12.3"
}
},
"nbformat": 4,
Expand Down
9 changes: 4 additions & 5 deletions tutorials/tutorial_03.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"metadata": {},
"outputs": [],
"source": [
"import maxplotlib.canvas.canvas as canvas"
"from maxplotlib import Canvas"
]
},
{
Expand All @@ -25,16 +25,15 @@
"metadata": {},
"outputs": [],
"source": [
"c = canvas.Canvas(width=\"17cm\", ratio=0.5)\n",
"c = Canvas(width=\"17cm\", ratio=0.5)\n",
"sp = c.add_subplot(\n",
" grid=False, xlabel=\"(x - 10) * 0.1\", ylabel=\"10y\", yscale=10, xshift=-10, xscale=0.1\n",
")\n",
"sp.add_line([0, 1, 2, 3], [0, 1, 4, 9], label=\"Line 1\", layer=0)\n",
"sp.add_line(\n",
" [0, 1, 2, 3], [0, 2, 3, 4], linestyle=\"dashed\", color=\"red\", label=\"Line 2\", layer=1\n",
")\n",
"# c.plot()\n",
"c.savefig(layer_by_layer=True, filename=\"tutorial_03_01.pdf\")"
"c.show()"
]
}
],
Expand All @@ -59,7 +58,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.3"
"version": "3.12.3"
}
},
"nbformat": 4,
Expand Down
Loading
Loading