Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/maxplotlib/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from maxplotlib.canvas.canvas import Canvas

__all__ = ["Canvas"]
33 changes: 24 additions & 9 deletions src/maxplotlib/backends/matplotlib/utils_old.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,12 @@ def setup_plotstyle(

def set_common_xlabel(self, xlabel="common X"):
self.fig.text(
0.5, -0.075, xlabel, va="center", ha="center", fontsize=self.fontsize
0.5,
-0.075,
xlabel,
va="center",
ha="center",
fontsize=self.fontsize,
)
# fig.text(0.04, 0.5, 'common Y', va='center', ha='center', rotation='vertical', fontsize=rcParams['axes.labelsize'])

Expand Down Expand Up @@ -438,7 +443,9 @@ def scale_axis(
i0 = int(xmin / delta)
i1 = int(xmax / delta + 1)
locs = np.arange(
includepoint - width, includepoint + width + delta, delta
includepoint - width,
includepoint + width + delta,
delta,
)
locs = locs[locs >= xmin - 1e-12]
locs = locs[locs <= xmax + 1e-12]
Expand Down Expand Up @@ -473,7 +480,9 @@ def scale_axis(
i0 = int(ymin / delta)
i1 = int(ymax / delta + 1)
locs = np.arange(
includepoint - width, includepoint + width + delta, delta
includepoint - width,
includepoint + width + delta,
delta,
)
locs = locs[locs >= ymin - 1e-12]
locs = locs[locs <= ymax + 1e-12]
Expand Down Expand Up @@ -507,7 +516,10 @@ def adjustFigAspect(self, aspect=1):
else:
ylim /= aspect
self.fig.subplots_adjust(
left=0.5 - xlim, right=0.5 + xlim, bottom=0.5 - ylim, top=0.5 + ylim
left=0.5 - xlim,
right=0.5 + xlim,
bottom=0.5 - ylim,
top=0.5 + ylim,
)

def add_figure_label(
Expand Down Expand Up @@ -619,14 +631,16 @@ def savefig(
# self.fig.savefig(self.directory + filename + '.' + format,bbox_inches='tight', transparent=False)
if tight_layout:
self.fig.savefig(
self.directory + filename + "." + format, bbox_inches="tight"
self.directory + filename + "." + format,
bbox_inches="tight",
)
else:
self.fig.savefig(self.directory + filename + "." + format)
elif format == "pgf":
# Save pgf figure
self.fig.savefig(
self.directory + filename + "." + format, bbox_inches="tight"
self.directory + filename + "." + format,
bbox_inches="tight",
)

# Replace pgf figure colors with colorlet
Expand Down Expand Up @@ -672,15 +686,16 @@ def savefig(
else:
try:
plt.savefig(
self.directory + filename + "." + format, bbox_inches="tight"
self.directory + filename + "." + format,
bbox_inches="tight",
)
except Exception as e:
print(
"ERROR: Could not save figure: "
+ self.directory
+ filename
+ "."
+ format
+ format,
)
print(e)

Expand All @@ -690,7 +705,7 @@ def savefig(
for format in formats:
if format in imgcat_formats:
f.write(
"imgcat " + self.directory + filename + "." + format + "\n"
"imgcat " + self.directory + filename + "." + format + "\n",
)

if print_imgcat and ("png" in formats or "pdf" in formats):
Expand Down
158 changes: 121 additions & 37 deletions src/maxplotlib/canvas/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from plotly.subplots import make_subplots

import maxplotlib.backends.matplotlib.utils as plt_utils
import maxplotlib.subfigure.line_plot as lp
import maxplotlib.subfigure.tikz_figure as tf
from maxplotlib.subfigure.line_plot import LinePlot
from maxplotlib.subfigure.tikz_figure import TikzFigure


class Canvas:
Expand Down Expand Up @@ -37,11 +37,15 @@ def __init__(self, **kwargs):

# Dictionary to store lines for each subplot
# Key: (row, col), Value: list of lines with their data and kwargs
self.subplots = {}
self._subplots = {}
self._num_subplots = 0

self._subplot_matrix = [[None] * self.ncols for _ in range(self.nrows)]

@property
def subplots(self):
return self._subplots

@property
def layers(self):
layers = []
Expand All @@ -66,34 +70,92 @@ def generate_new_rowcol(self, row, col):
assert col is not None, "Not enough columns!"
return row, col

def add_tikzfigure(self, **kwargs):
def add_line(
self,
x_data,
y_data,
layer=0,
subplot: LinePlot | None = None,
row: int | None = None,
col: int | None = None,
plot_type="plot",
**kwargs,
):
if row is not None and col is not None:
try:
subplot = self._subplot_matrix[row][col]
except KeyError:
raise ValueError("Invalid subplot position.")
else:
row, col = 0, 0
subplot = self._subplot_matrix[row][col]

if subplot is None:
row, col = self.generate_new_rowcol(row, col)
subplot = self.add_subplot(col=col, row=row)

subplot.add_line(
x_data=x_data,
y_data=y_data,
layer=layer,
plot_type=plot_type,
**kwargs,
)

def add_tikzfigure(
self,
col=None,
row=None,
label=None,
**kwargs,
):
"""
Adds a subplot to the figure.

Parameters:
**kwargs: Arbitrary keyword arguments.
- col (int): Column index for the subplot.
- row (int): Row index for the subplot.
- label (str): Label to identify the subplot.
"""
col = kwargs.get("col", None)
row = kwargs.get("row", None)
label = kwargs.get("label", None)

row, col = self.generate_new_rowcol(row, col)

# Initialize the LinePlot for the given subplot position
tikz_figure = tf.TikzFigure(**kwargs)
tikz_figure = TikzFigure(
col=col,
row=row,
label=label,
**kwargs,
)
self._subplot_matrix[row][col] = tikz_figure

# Store the LinePlot instance by its position for easy access
if label is None:
self.subplots[(row, col)] = tikz_figure
self._subplots[(row, col)] = tikz_figure
else:
self.subplots[label] = tikz_figure
self._subplots[label] = tikz_figure
return tikz_figure

def add_subplot(self, **kwargs):
def add_subplot(
self,
col: int | None = None,
row: int | None = None,
figsize: tuple = (10, 6),
title: str | None = None,
caption: str | None = None,
description: str | None = None,
label: str | None = None,
grid: bool = False,
legend: bool = False,
xmin: float | int | None = None,
xmax: float | int | None = None,
ymin: float | int | None = None,
ymax: float | int | None = None,
xlabel: str | None = None,
ylabel: str | None = None,
xscale: float | int = 1.0,
yscale: float | int = 1.0,
xshift: float | int = 0.0,
yshift: float | int = 0.0,
):
"""
Adds a subplot to the figure.

Expand All @@ -103,21 +165,32 @@ def add_subplot(self, **kwargs):
- row (int): Row index for the subplot.
- label (str): Label to identify the subplot.
"""
col = kwargs.get("col", None)
row = kwargs.get("row", None)
label = kwargs.get("label", None)

row, col = self.generate_new_rowcol(row, col)

# Initialize the LinePlot for the given subplot position
line_plot = lp.LinePlot(**kwargs)
line_plot = LinePlot(
title=title,
grid=grid,
legend=legend,
xmin=xmin,
xmax=xmax,
ymin=ymin,
ymax=ymax,
xlabel=xlabel,
ylabel=ylabel,
xscale=xscale,
yscale=yscale,
xshift=xshift,
yshift=yshift,
)
self._subplot_matrix[row][col] = line_plot

# Store the LinePlot instance by its position for easy access
if label is None:
self.subplots[(row, col)] = line_plot
self._subplots[(row, col)] = line_plot
else:
self.subplots[label] = line_plot
self._subplots[label] = line_plot
return line_plot

def savefig(
Expand All @@ -136,7 +209,10 @@ def savefig(
for layer in self.layers:
layers.append(layer)
fig, axs = self.plot(
show=False, backend="matplotlib", savefig=True, layers=layers
show=False,
backend="matplotlib",
savefig=True,
layers=layers,
)
_fn = f"{filename_no_extension}_{layers}.{extension}"
fig.savefig(_fn)
Expand All @@ -153,25 +229,38 @@ def savefig(
else:

fig, axs = self.plot(
show=False, backend="matplotlib", savefig=True, layers=layers
show=False,
backend="matplotlib",
savefig=True,
layers=layers,
)
fig.savefig(full_filepath)
if verbose:
print(f"Saved {full_filepath}")

def plot(self, backend="matplotlib", show=True, savefig=False, layers=None):
def plot(self, backend="matplotlib", savefig=False, layers=None):
if backend == "matplotlib":
return self.plot_matplotlib(show=show, savefig=savefig, layers=layers)
return self.plot_matplotlib(savefig=savefig, layers=layers)
elif backend == "plotly":
self.plot_plotly(show=show, savefig=savefig)
return self.plot_plotly(savefig=savefig)
else:
raise ValueError(f"Invalid backend: {backend}")

def plot_matplotlib(self, show=True, savefig=False, layers=None, usetex=False):
def show(self, backend="matplotlib"):
if backend == "matplotlib":
self.plot(backend="matplotlib", savefig=False, layers=None)
self._matplotlib_fig.show()
elif backend == "plotly":
plot = self.plot_plotly(savefig=False)
else:
raise ValueError("Invalid backend")

def plot_matplotlib(self, savefig=False, layers=None, usetex=False):
"""
Generate and optionally display the subplots.

Parameters:
filename (str, optional): Filename to save the figure.
show (bool): Whether to display the plot.
"""

tex_fonts = plt_utils.setup_tex_fonts(fontsize=self.fontsize, usetex=usetex)
Expand Down Expand Up @@ -205,17 +294,11 @@ def plot_matplotlib(self, show=True, savefig=False, layers=None, usetex=False):

for (row, col), subplot in self.subplots.items():
ax = axes[row][col]
# print(f"{subplot = }")
subplot.plot_matplotlib(ax, layers=layers)
# ax.set_title(f"Subplot ({row}, {col})")
ax.grid()
# Set caption, labels, etc., if needed
# plt.tight_layout()

if show:
plt.show()
# else:
# plt.close()
# Set caption, labels, etc., if needed
self._plotted = True
self._matplotlib_fig = fig
self._matplotlib_axes = axes
Expand All @@ -240,7 +323,8 @@ def plot_plotly(self, show=True, savefig=None, usetex=False):
fig_width, fig_height = self._figsize
else:
fig_width, fig_height = plt_utils.set_size(
width=self._width, ratio=self._ratio
width=self._width,
ratio=self._ratio,
)
# print(self._width, fig_width, fig_height)
# Create subplots
Expand Down Expand Up @@ -271,8 +355,8 @@ def plot_plotly(self, show=True, savefig=None, usetex=False):
fig.write_image(savefig)

# Show or return the figure
if show:
fig.show()
# if show:
# fig.show()
return fig

# Property getters
Expand Down
Loading
Loading