Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/maxplotlib/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from maxplotlib.canvas.canvas import Canvas

__all__ = ["Canvas"]
33 changes: 24 additions & 9 deletions src/maxplotlib/backends/matplotlib/utils_old.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,12 @@ def setup_plotstyle(

def set_common_xlabel(self, xlabel="common X"):
self.fig.text(
0.5, -0.075, xlabel, va="center", ha="center", fontsize=self.fontsize
0.5,
-0.075,
xlabel,
va="center",
ha="center",
fontsize=self.fontsize,
)
# fig.text(0.04, 0.5, 'common Y', va='center', ha='center', rotation='vertical', fontsize=rcParams['axes.labelsize'])

Expand Down Expand Up @@ -438,7 +443,9 @@ def scale_axis(
i0 = int(xmin / delta)
i1 = int(xmax / delta + 1)
locs = np.arange(
includepoint - width, includepoint + width + delta, delta
includepoint - width,
includepoint + width + delta,
delta,
)
locs = locs[locs >= xmin - 1e-12]
locs = locs[locs <= xmax + 1e-12]
Expand Down Expand Up @@ -473,7 +480,9 @@ def scale_axis(
i0 = int(ymin / delta)
i1 = int(ymax / delta + 1)
locs = np.arange(
includepoint - width, includepoint + width + delta, delta
includepoint - width,
includepoint + width + delta,
delta,
)
locs = locs[locs >= ymin - 1e-12]
locs = locs[locs <= ymax + 1e-12]
Expand Down Expand Up @@ -507,7 +516,10 @@ def adjustFigAspect(self, aspect=1):
else:
ylim /= aspect
self.fig.subplots_adjust(
left=0.5 - xlim, right=0.5 + xlim, bottom=0.5 - ylim, top=0.5 + ylim
left=0.5 - xlim,
right=0.5 + xlim,
bottom=0.5 - ylim,
top=0.5 + ylim,
)

def add_figure_label(
Expand Down Expand Up @@ -619,14 +631,16 @@ def savefig(
# self.fig.savefig(self.directory + filename + '.' + format,bbox_inches='tight', transparent=False)
if tight_layout:
self.fig.savefig(
self.directory + filename + "." + format, bbox_inches="tight"
self.directory + filename + "." + format,
bbox_inches="tight",
)
else:
self.fig.savefig(self.directory + filename + "." + format)
elif format == "pgf":
# Save pgf figure
self.fig.savefig(
self.directory + filename + "." + format, bbox_inches="tight"
self.directory + filename + "." + format,
bbox_inches="tight",
)

# Replace pgf figure colors with colorlet
Expand Down Expand Up @@ -672,15 +686,16 @@ def savefig(
else:
try:
plt.savefig(
self.directory + filename + "." + format, bbox_inches="tight"
self.directory + filename + "." + format,
bbox_inches="tight",
)
except Exception as e:
print(
"ERROR: Could not save figure: "
+ self.directory
+ filename
+ "."
+ format
+ format,
)
print(e)

Expand All @@ -690,7 +705,7 @@ def savefig(
for format in formats:
if format in imgcat_formats:
f.write(
"imgcat " + self.directory + filename + "." + format + "\n"
"imgcat " + self.directory + filename + "." + format + "\n",
)

if print_imgcat and ("png" in formats or "pdf" in formats):
Expand Down
158 changes: 121 additions & 37 deletions src/maxplotlib/canvas/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from plotly.subplots import make_subplots

import maxplotlib.backends.matplotlib.utils as plt_utils
import maxplotlib.subfigure.line_plot as lp
import maxplotlib.subfigure.tikz_figure as tf
from maxplotlib.subfigure.line_plot import LinePlot
from maxplotlib.subfigure.tikz_figure import TikzFigure


class Canvas:
Expand Down Expand Up @@ -37,11 +37,15 @@ def __init__(self, **kwargs):

# Dictionary to store lines for each subplot
# Key: (row, col), Value: list of lines with their data and kwargs
self.subplots = {}
self._subplots = {}
self._num_subplots = 0

self._subplot_matrix = [[None] * self.ncols for _ in range(self.nrows)]

@property
def subplots(self):
return self._subplots

@property
def layers(self):
layers = []
Expand All @@ -66,34 +70,92 @@ def generate_new_rowcol(self, row, col):
assert col is not None, "Not enough columns!"
return row, col

def add_tikzfigure(self, **kwargs):
def add_line(
self,
x_data,
y_data,
layer=0,
subplot: LinePlot | None = None,
row: int | None = None,
col: int | None = None,
plot_type="plot",
**kwargs,
):
if row is not None and col is not None:
try:
subplot = self._subplot_matrix[row][col]
except KeyError:
raise ValueError("Invalid subplot position.")
else:
row, col = 0, 0
subplot = self._subplot_matrix[row][col]

if subplot is None:
row, col = self.generate_new_rowcol(row, col)
subplot = self.add_subplot(col=col, row=row)

subplot.add_line(
x_data=x_data,
y_data=y_data,
layer=layer,
plot_type=plot_type,
**kwargs,
)

def add_tikzfigure(
self,
col=None,
row=None,
label=None,
**kwargs,
):
"""
Adds a subplot to the figure.

Parameters:
**kwargs: Arbitrary keyword arguments.
- col (int): Column index for the subplot.
- row (int): Row index for the subplot.
- label (str): Label to identify the subplot.
"""
col = kwargs.get("col", None)
row = kwargs.get("row", None)
label = kwargs.get("label", None)

row, col = self.generate_new_rowcol(row, col)

# Initialize the LinePlot for the given subplot position
tikz_figure = tf.TikzFigure(**kwargs)
tikz_figure = TikzFigure(
col=col,
row=row,
label=label,
**kwargs,
)
self._subplot_matrix[row][col] = tikz_figure

# Store the LinePlot instance by its position for easy access
if label is None:
self.subplots[(row, col)] = tikz_figure
self._subplots[(row, col)] = tikz_figure
else:
self.subplots[label] = tikz_figure
self._subplots[label] = tikz_figure
return tikz_figure

def add_subplot(self, **kwargs):
def add_subplot(
self,
col: int | None = None,
row: int | None = None,
figsize: tuple = (10, 6),
title: str | None = None,
caption: str | None = None,
description: str | None = None,
label: str | None = None,
grid: bool = False,
legend: bool = False,
xmin: float | int | None = None,
xmax: float | int | None = None,
ymin: float | int | None = None,
ymax: float | int | None = None,
xlabel: str | None = None,
ylabel: str | None = None,
xscale: float | int = 1.0,
yscale: float | int = 1.0,
xshift: float | int = 0.0,
yshift: float | int = 0.0,
):
"""
Adds a subplot to the figure.

Expand All @@ -103,21 +165,32 @@ def add_subplot(self, **kwargs):
- row (int): Row index for the subplot.
- label (str): Label to identify the subplot.
"""
col = kwargs.get("col", None)
row = kwargs.get("row", None)
label = kwargs.get("label", None)

row, col = self.generate_new_rowcol(row, col)

# Initialize the LinePlot for the given subplot position
line_plot = lp.LinePlot(**kwargs)
line_plot = LinePlot(
title=title,
grid=grid,
legend=legend,
xmin=xmin,
xmax=xmax,
ymin=ymin,
ymax=ymax,
xlabel=xlabel,
ylabel=ylabel,
xscale=xscale,
yscale=yscale,
xshift=xshift,
yshift=yshift,
)
self._subplot_matrix[row][col] = line_plot

# Store the LinePlot instance by its position for easy access
if label is None:
self.subplots[(row, col)] = line_plot
self._subplots[(row, col)] = line_plot
else:
self.subplots[label] = line_plot
self._subplots[label] = line_plot
return line_plot

def savefig(
Expand All @@ -136,7 +209,10 @@ def savefig(
for layer in self.layers:
layers.append(layer)
fig, axs = self.plot(
show=False, backend="matplotlib", savefig=True, layers=layers
show=False,
backend="matplotlib",
savefig=True,
layers=layers,
)
_fn = f"{filename_no_extension}_{layers}.{extension}"
fig.savefig(_fn)
Expand All @@ -153,25 +229,38 @@ def savefig(
else:

fig, axs = self.plot(
show=False, backend="matplotlib", savefig=True, layers=layers
show=False,
backend="matplotlib",
savefig=True,
layers=layers,
)
fig.savefig(full_filepath)
if verbose:
print(f"Saved {full_filepath}")

def plot(self, backend="matplotlib", show=True, savefig=False, layers=None):
def plot(self, backend="matplotlib", savefig=False, layers=None):
if backend == "matplotlib":
return self.plot_matplotlib(show=show, savefig=savefig, layers=layers)
return self.plot_matplotlib(savefig=savefig, layers=layers)
elif backend == "plotly":
self.plot_plotly(show=show, savefig=savefig)
return self.plot_plotly(savefig=savefig)
else:
raise ValueError("Invalid backend")

def plot_matplotlib(self, show=True, savefig=False, layers=None, usetex=False):
def show(self, backend="matplotlib"):
if backend == "matplotlib":
self.plot(backend="matplotlib", savefig=False, layers=None)
self._matplotlib_fig.show()
elif backend == "plotly":
plot = self.plot_plotly(savefig=False)
else:
raise ValueError("Invalid backend")

def plot_matplotlib(self, savefig=False, layers=None, usetex=False):
"""
Generate and optionally display the subplots.

Parameters:
filename (str, optional): Filename to save the figure.
show (bool): Whether to display the plot.
"""

tex_fonts = plt_utils.setup_tex_fonts(fontsize=self.fontsize, usetex=usetex)
Expand Down Expand Up @@ -205,17 +294,11 @@ def plot_matplotlib(self, show=True, savefig=False, layers=None, usetex=False):

for (row, col), subplot in self.subplots.items():
ax = axes[row][col]
# print(f"{subplot = }")
subplot.plot_matplotlib(ax, layers=layers)
# ax.set_title(f"Subplot ({row}, {col})")
ax.grid()
# Set caption, labels, etc., if needed
# plt.tight_layout()

if show:
plt.show()
# else:
# plt.close()
# Set caption, labels, etc., if needed
self._plotted = True
self._matplotlib_fig = fig
self._matplotlib_axes = axes
Expand All @@ -240,7 +323,8 @@ def plot_plotly(self, show=True, savefig=None, usetex=False):
fig_width, fig_height = self._figsize
else:
fig_width, fig_height = plt_utils.set_size(
width=self._width, ratio=self._ratio
width=self._width,
ratio=self._ratio,
)
# print(self._width, fig_width, fig_height)
# Create subplots
Expand Down Expand Up @@ -271,8 +355,8 @@ def plot_plotly(self, show=True, savefig=None, usetex=False):
fig.write_image(savefig)

# Show or return the figure
if show:
fig.show()
# if show:
# fig.show()
return fig

# Property getters
Expand Down
Loading
Loading