Skip to content

Commit 59a9573

Browse files
Added unit tests
1 parent 63239e3 commit 59a9573

File tree

10 files changed

+481
-10
lines changed

10 files changed

+481
-10
lines changed

CLAUDE.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,9 @@
1010
- After significant code changes run `make format` and `make typecheck` to make sure the code follows best practices
1111

1212
## Code Maintenance Guidelines
13-
- When there is a new plot type added in plot.py we need to update the mcp tool in server.py
13+
- When there is a new plot type added in plot.py we need to update the mcp tool in server.py
14+
15+
## Dependency Management
16+
- This project uses uv for managing dependencies. In case you need to add a dependency use the command `uv add <dependency>`
17+
- If the dependency is for developing, e.g. for testing use `uv add <dependency> --dev`
18+
- All of the configuration of the project should be centralized in the pyproject.toml file

Makefile

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,25 @@
1-
.PHONY: install format lint typecheck help
1+
.PHONY: install format lint typecheck test help
22

33
help:
44
@echo "Available commands:"
55
@echo " install - Install dependencies using uv"
66
@echo " format - Format code using ruff"
77
@echo " lint - Run linting using ruff"
88
@echo " typecheck - Run type checking using ty"
9+
@echo " test - Run tests using pytest"
910
@echo " help - Show this help message"
1011

1112
install:
1213
uv sync
1314

1415
format:
1516
uv run ruff format .
16-
uv run ruff check --fix .
1717

1818
lint:
1919
uv run ruff check .
2020

2121
typecheck:
2222
uv run ty check
23+
24+
test:
25+
uv run pytest

pyproject.toml

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ requires-python = ">=3.13"
1414
classifiers = [
1515
"Development Status :: 3 - Alpha",
1616
"Intended Audience :: Developers",
17-
"License :: OSI Approved :: MIT License",
17+
"License :: OSI Approved :: Apache Software License",
1818
"Programming Language :: Python :: 3",
1919
"Programming Language :: Python :: 3.13",
2020
]
@@ -32,6 +32,7 @@ dependencies = [
3232

3333
[dependency-groups]
3434
dev = [
35+
"pytest>=8.4.1",
3536
"ruff>=0.12.5",
3637
"ty>=0.0.1a16",
3738
]
@@ -52,5 +53,17 @@ lint.select = [
5253
"I", # isort
5354
"C", # flake8-comprehensions
5455
"B", # flake8-bugbear
55-
"S", # flake8-bandit
5656
]
57+
lint.ignore = []
58+
59+
[tool.pytest.ini_options]
60+
minversion = "8.0"
61+
addopts = [
62+
"-ra",
63+
"--strict-config",
64+
"-v"
65+
]
66+
testpaths = ["tests"]
67+
python_files = ["test_*.py"]
68+
python_classes = ["Test*"]
69+
python_functions = ["test_*"]

src/plotting_mcp/plot.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def _auto_rotate_labels(ax: plt.Axes, axis: Literal["x", "y"] = "x") -> None:
4545
ax.tick_params(axis=axis, labelrotation=90)
4646

4747

48-
def _create_world_map(ax: GeoAxes, df: pd.DataFrame, **kwargs) -> None:
48+
def _create_world_map_plot(ax: GeoAxes, df: pd.DataFrame, **kwargs) -> None:
4949
"""Create a world map with coordinate points."""
5050
# Add map features
5151
ax.add_feature(cfeature.COASTLINE)
@@ -135,13 +135,17 @@ def _create_pie_plot(ax: plt.Axes, df: pd.DataFrame, **kwargs) -> None:
135135
)
136136

137137

138-
def _create_matplotlib_plot( # noqa: C901
138+
def _create_plot( # noqa: C901
139139
df: pd.DataFrame, plot_type: str, **kwargs
140140
) -> tuple[plt.Figure, plt.Axes]:
141141
"""Create a plot using matplotlib/seaborn."""
142142
if df.empty:
143143
raise ValueError("CSV data is empty")
144144

145+
# Validate that the DataFrame contains no NaN values
146+
if df.isnull().any().any():
147+
raise ValueError("CSV data contains NaN/null values. Please ensure all data is complete.")
148+
145149
supported_plot_types = ["line", "bar", "pie", "worldmap"]
146150
if plot_type not in supported_plot_types:
147151
raise ValueError(
@@ -169,7 +173,7 @@ def _create_matplotlib_plot( # noqa: C901
169173
_create_pie_plot(ax, df, **kwargs)
170174
elif plot_type == "worldmap":
171175
# Cartopy doesn't return correct Axes type, so we ignore type checking
172-
_create_world_map(ax, df, **kwargs) # ty: ignore[invalid-argument-type]
176+
_create_world_map_plot(ax, df, **kwargs) # ty: ignore[invalid-argument-type]
173177

174178
# Auto-rotate x-axis labels if needed (not applicable for pie charts or world maps)
175179
if plot_type not in ["pie", "worldmap"]:
@@ -190,7 +194,7 @@ def _create_matplotlib_plot( # noqa: C901
190194

191195
def plot_to_bytes(df: pd.DataFrame, plot_type: str, **kwargs) -> bytes:
192196
"""Generate a plot and return it as bytes."""
193-
fig, _ = _create_matplotlib_plot(df, plot_type, **kwargs)
197+
fig, _ = _create_plot(df, plot_type, **kwargs)
194198
buffer = io.BytesIO()
195199
fig.savefig(buffer, format="png", bbox_inches="tight")
196200
plt.close(fig)
@@ -200,7 +204,7 @@ def plot_to_bytes(df: pd.DataFrame, plot_type: str, **kwargs) -> bytes:
200204

201205
def plot_and_show(df: pd.DataFrame, plot_type: str, **kwargs) -> None:
202206
"""Generate a plot and display it."""
203-
fig, _ = _create_matplotlib_plot(df, plot_type, **kwargs)
207+
fig, _ = _create_plot(df, plot_type, **kwargs)
204208
plt.show()
205209
plt.close(fig)
206210

tests/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Tests package

tests/conftest.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""Pytest configuration and shared fixtures."""
2+
3+
import matplotlib
4+
5+
6+
def pytest_configure(config):
7+
"""Configure pytest with custom settings."""
8+
# Use non-interactive backend for matplotlib to avoid GUI issues in tests
9+
matplotlib.use("Agg")

tests/test_plot.py

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
"""Tests for plotting functionality."""
2+
3+
import matplotlib.pyplot as plt
4+
import pandas as pd
5+
import pytest
6+
from matplotlib.figure import Figure
7+
8+
from plotting_mcp.plot import (
9+
_auto_rotate_labels,
10+
_create_pie_plot,
11+
_create_plot,
12+
plot_to_bytes,
13+
)
14+
15+
16+
class TestAutoRotateLabels:
17+
"""Test the _auto_rotate_labels function."""
18+
19+
def test_auto_rotate_labels_x_axis(self):
20+
"""Test auto rotation for x-axis labels."""
21+
fig, ax = plt.subplots()
22+
# Create long labels that should trigger rotation
23+
long_labels = [f"Very Long Label {i}" for i in range(10)]
24+
ax.set_xticks(range(len(long_labels)))
25+
ax.set_xticklabels(long_labels)
26+
27+
_auto_rotate_labels(ax, axis="x")
28+
29+
# Check that rotation was applied
30+
for label in ax.get_xticklabels():
31+
assert label.get_rotation() == 90
32+
33+
plt.close(fig)
34+
35+
def test_auto_rotate_labels_y_axis(self):
36+
"""Test auto rotation for y-axis labels."""
37+
fig, ax = plt.subplots()
38+
# Create long labels that should trigger rotation
39+
long_labels = [f"Very Long Label {i}" for i in range(10)]
40+
ax.set_yticks(range(len(long_labels)))
41+
ax.set_yticklabels(long_labels)
42+
43+
_auto_rotate_labels(ax, axis="y")
44+
45+
# Check that rotation was applied
46+
for label in ax.get_yticklabels():
47+
assert label.get_rotation() == 90
48+
49+
plt.close(fig)
50+
51+
def test_auto_rotate_labels_short_labels(self):
52+
"""Test that short labels don't get rotated."""
53+
fig, ax = plt.subplots()
54+
# Create short labels that shouldn't trigger rotation
55+
short_labels = ["A", "B", "C", "D"]
56+
ax.set_xticks(range(len(short_labels)))
57+
ax.set_xticklabels(short_labels)
58+
59+
_auto_rotate_labels(ax, axis="x")
60+
61+
# Check that rotation was not applied (default rotation is 0)
62+
for label in ax.get_xticklabels():
63+
assert label.get_rotation() == 0
64+
65+
plt.close(fig)
66+
67+
def test_auto_rotate_labels_invalid_axis(self):
68+
"""Test that invalid axis raises ValueError."""
69+
fig, ax = plt.subplots()
70+
71+
with pytest.raises(ValueError, match="Axis must be 'x' or 'y'"):
72+
_auto_rotate_labels(ax, axis="z")
73+
74+
plt.close(fig)
75+
76+
def test_auto_rotate_labels_empty_labels(self):
77+
"""Test that empty labels don't cause errors."""
78+
fig, ax = plt.subplots()
79+
# No labels set
80+
81+
# Should not raise any errors
82+
_auto_rotate_labels(ax, axis="x")
83+
_auto_rotate_labels(ax, axis="y")
84+
85+
plt.close(fig)
86+
87+
88+
class TestCreatePiePlot:
89+
"""Test the _create_pie_plot function."""
90+
91+
def test_create_pie_plot_single_column_value_counts(self):
92+
"""Test pie plot with single column using value counts."""
93+
df = pd.DataFrame({"category": ["A", "B", "A", "C", "B", "A"]})
94+
fig, ax = plt.subplots()
95+
96+
_create_pie_plot(ax, df)
97+
98+
# Check that pie chart was created (wedges should exist)
99+
assert len(ax.patches) > 0
100+
plt.close(fig)
101+
102+
def test_create_pie_plot_two_columns(self):
103+
"""Test pie plot with two columns (labels and values)."""
104+
df = pd.DataFrame({"category": ["A", "B", "C"], "values": [30, 45, 25]})
105+
fig, ax = plt.subplots()
106+
107+
_create_pie_plot(ax, df)
108+
109+
# Check that pie chart was created with 3 wedges
110+
assert len(ax.patches) == 3
111+
plt.close(fig)
112+
113+
def test_create_pie_plot_too_many_columns(self):
114+
"""Test that pie plot with too many columns raises ValueError."""
115+
df = pd.DataFrame({"col1": [1, 2, 3], "col2": [4, 5, 6], "col3": [7, 8, 9]})
116+
fig, ax = plt.subplots()
117+
118+
with pytest.raises(ValueError, match="Pie chart requires either one column"):
119+
_create_pie_plot(ax, df)
120+
121+
plt.close(fig)
122+
123+
def test_create_pie_plot_two_columns_with_labels_param(self):
124+
"""Test that pie plot with two columns rejects labels parameter."""
125+
df = pd.DataFrame({"category": ["A", "B", "C"], "values": [30, 45, 25]})
126+
fig, ax = plt.subplots()
127+
128+
with pytest.raises(ValueError, match="does not accept 'labels' parameter"):
129+
_create_pie_plot(ax, df, labels=["X", "Y", "Z"])
130+
131+
plt.close(fig)
132+
133+
134+
class TestCreateMatplotlibPlot:
135+
"""Test the _create_matplotlib_plot function."""
136+
137+
def test_create_line_plot(self):
138+
"""Test creation of line plot."""
139+
df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [2, 4, 6, 8, 10]})
140+
141+
fig, ax = _create_plot(df, "line", x="x", y="y")
142+
143+
assert isinstance(fig, Figure)
144+
assert len(ax.lines) > 0 # Line plot should have lines
145+
plt.close(fig)
146+
147+
def test_create_bar_plot(self):
148+
"""Test creation of bar plot."""
149+
df = pd.DataFrame({"category": ["A", "B", "C", "D"], "values": [10, 15, 8, 12]})
150+
151+
fig, ax = _create_plot(df, "bar", x="category", y="values")
152+
153+
assert isinstance(fig, Figure)
154+
assert len(ax.patches) > 0 # Bar plot should have bars (patches)
155+
plt.close(fig)
156+
157+
def test_create_pie_plot(self):
158+
"""Test creation of pie plot through matplotlib function."""
159+
df = pd.DataFrame({"category": ["A", "B", "C"], "values": [30, 45, 25]})
160+
161+
fig, ax = _create_plot(df, "pie")
162+
163+
assert isinstance(fig, Figure)
164+
assert len(ax.patches) > 0 # Pie plot should have wedges (patches)
165+
plt.close(fig)
166+
167+
def test_empty_dataframe_raises_error(self):
168+
"""Test that empty DataFrame raises ValueError."""
169+
df = pd.DataFrame()
170+
171+
with pytest.raises(ValueError, match="CSV data is empty"):
172+
_create_plot(df, "line")
173+
174+
def test_unsupported_plot_type_raises_error(self):
175+
"""Test that unsupported plot type raises ValueError."""
176+
df = pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})
177+
178+
with pytest.raises(ValueError, match="Unsupported plot type"):
179+
_create_plot(df, "scatter3d")
180+
181+
def test_plot_with_title_and_labels(self):
182+
"""Test plot creation with title and axis labels."""
183+
df = pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})
184+
185+
fig, ax = _create_plot(
186+
df, "line", x="x", y="y", title="Test Plot", xlabel="X Axis", ylabel="Y Axis"
187+
)
188+
189+
assert ax.get_title() == "Test Plot"
190+
assert ax.get_xlabel() == "X Axis"
191+
assert ax.get_ylabel() == "Y Axis"
192+
plt.close(fig)
193+
194+
195+
class TestPlotToBytes:
196+
"""Test the plot_to_bytes function."""
197+
198+
def test_plot_to_bytes_returns_bytes(self):
199+
"""Test that plot_to_bytes returns bytes."""
200+
df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [2, 4, 6, 8, 10]})
201+
202+
result = plot_to_bytes(df, "line", x="x", y="y")
203+
204+
assert isinstance(result, bytes)
205+
assert len(result) > 0
206+
# Check PNG header
207+
assert result.startswith(b"\x89PNG")
208+
209+
def test_plot_to_bytes_different_plot_types(self):
210+
"""Test plot_to_bytes with different plot types."""
211+
df_line = pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})
212+
df_bar = pd.DataFrame({"cat": ["A", "B", "C"], "val": [1, 2, 3]})
213+
df_pie = pd.DataFrame({"category": ["A", "B", "C"], "values": [30, 45, 25]})
214+
215+
line_bytes = plot_to_bytes(df_line, "line", x="x", y="y")
216+
bar_bytes = plot_to_bytes(df_bar, "bar", x="cat", y="val")
217+
pie_bytes = plot_to_bytes(df_pie, "pie")
218+
219+
assert all(isinstance(b, bytes) for b in [line_bytes, bar_bytes, pie_bytes])
220+
assert all(len(b) > 0 for b in [line_bytes, bar_bytes, pie_bytes])
221+
assert all(b.startswith(b"\x89PNG") for b in [line_bytes, bar_bytes, pie_bytes])

0 commit comments

Comments
 (0)