Skip to content

Commit e262e65

Browse files
LGDiMaggioclaude
andcommitted
test: add unit tests for RAG, document_reader, signal_loader, CLI, models, DOCX reports
Also fix chunk_text() infinite loop when chunk_overlap >= chunk_size (clamp overlap), and skip parquet test when pyarrow/fastparquet is not installed. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 28d657e commit e262e65

File tree

7 files changed

+1175
-0
lines changed

7 files changed

+1175
-0
lines changed

src/rag.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ def chunk_text(
6969
- ``index`` : sequential chunk number
7070
- ``start`` / ``end`` : character offsets in the original text
7171
"""
72+
if chunk_overlap >= chunk_size:
73+
chunk_overlap = max(0, chunk_size - 1)
7274
chunks: list[dict[str, Any]] = []
7375
start = 0
7476
idx = 0

tests/test_cli_transport.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
"""
2+
Tests for CLI argument parsing and transport configuration.
3+
4+
Covers:
5+
- argparse defaults (transport=stdio, host=127.0.0.1, port=8000)
6+
- Explicit flag values (--transport sse, --host, --port)
7+
- Short flag aliases (-t, -p)
8+
- Environment variable fallbacks (MCP_TRANSPORT, MCP_HOST, MCP_PORT)
9+
- Invalid transport choice → error
10+
- Invalid port (non-numeric) → error
11+
- CLI args override env vars
12+
"""
13+
14+
import argparse
15+
import os
16+
import pytest
17+
from unittest.mock import patch, MagicMock
18+
19+
20+
def _build_parser(env_transport="stdio", env_host="127.0.0.1", env_port="8000"):
21+
"""Replicate the argparse setup from main() for isolated testing."""
22+
parser = argparse.ArgumentParser(description="Predictive Maintenance MCP Server")
23+
parser.add_argument(
24+
"--transport", "-t",
25+
choices=["stdio", "sse", "streamable-http"],
26+
default=env_transport,
27+
help="Transport protocol",
28+
)
29+
parser.add_argument(
30+
"--host",
31+
default=env_host,
32+
help="Bind address for SSE/HTTP",
33+
)
34+
parser.add_argument(
35+
"--port", "-p",
36+
type=int,
37+
default=int(env_port),
38+
help="Port for SSE/HTTP transport",
39+
)
40+
return parser
41+
42+
43+
# ── Default values ─────────────────────────────────────────────────────────
44+
45+
class TestCLIDefaults:
46+
47+
def test_default_transport_stdio(self):
48+
args = _build_parser().parse_args([])
49+
assert args.transport == "stdio"
50+
51+
def test_default_host(self):
52+
args = _build_parser().parse_args([])
53+
assert args.host == "127.0.0.1"
54+
55+
def test_default_port(self):
56+
args = _build_parser().parse_args([])
57+
assert args.port == 8000
58+
59+
60+
# ── Explicit flags ─────────────────────────────────────────────────────────
61+
62+
class TestCLIExplicitFlags:
63+
64+
def test_transport_sse(self):
65+
args = _build_parser().parse_args(["--transport", "sse"])
66+
assert args.transport == "sse"
67+
68+
def test_transport_streamable_http(self):
69+
args = _build_parser().parse_args(["--transport", "streamable-http"])
70+
assert args.transport == "streamable-http"
71+
72+
def test_host_custom(self):
73+
args = _build_parser().parse_args(["--host", "0.0.0.0"])
74+
assert args.host == "0.0.0.0"
75+
76+
def test_port_custom(self):
77+
args = _build_parser().parse_args(["--port", "9090"])
78+
assert args.port == 9090
79+
80+
def test_short_flag_transport(self):
81+
args = _build_parser().parse_args(["-t", "sse"])
82+
assert args.transport == "sse"
83+
84+
def test_short_flag_port(self):
85+
args = _build_parser().parse_args(["-p", "7777"])
86+
assert args.port == 7777
87+
88+
def test_all_flags_combined(self):
89+
args = _build_parser().parse_args([
90+
"-t", "sse", "--host", "10.0.0.1", "-p", "3000"
91+
])
92+
assert args.transport == "sse"
93+
assert args.host == "10.0.0.1"
94+
assert args.port == 3000
95+
96+
97+
# ── Invalid arguments ──────────────────────────────────────────────────────
98+
99+
class TestCLIInvalidArgs:
100+
101+
def test_invalid_transport_rejected(self):
102+
with pytest.raises(SystemExit):
103+
_build_parser().parse_args(["--transport", "websocket"])
104+
105+
def test_non_numeric_port_rejected(self):
106+
with pytest.raises(SystemExit):
107+
_build_parser().parse_args(["--port", "abc"])
108+
109+
110+
# ── Environment variable fallbacks ─────────────────────────────────────────
111+
112+
class TestCLIEnvVars:
113+
114+
def test_mcp_transport_env(self, monkeypatch):
115+
monkeypatch.setenv("MCP_TRANSPORT", "sse")
116+
env_t = os.environ.get("MCP_TRANSPORT", "stdio")
117+
args = _build_parser(env_transport=env_t).parse_args([])
118+
assert args.transport == "sse"
119+
120+
def test_mcp_host_env(self, monkeypatch):
121+
monkeypatch.setenv("MCP_HOST", "192.168.1.100")
122+
env_h = os.environ.get("MCP_HOST", "127.0.0.1")
123+
args = _build_parser(env_host=env_h).parse_args([])
124+
assert args.host == "192.168.1.100"
125+
126+
def test_mcp_port_env(self, monkeypatch):
127+
monkeypatch.setenv("MCP_PORT", "5555")
128+
env_p = os.environ.get("MCP_PORT", "8000")
129+
args = _build_parser(env_port=env_p).parse_args([])
130+
assert args.port == 5555
131+
132+
def test_cli_overrides_env(self, monkeypatch):
133+
monkeypatch.setenv("MCP_TRANSPORT", "sse")
134+
monkeypatch.setenv("MCP_PORT", "5555")
135+
env_t = os.environ.get("MCP_TRANSPORT", "stdio")
136+
env_p = os.environ.get("MCP_PORT", "8000")
137+
args = _build_parser(env_transport=env_t, env_port=env_p).parse_args(
138+
["--transport", "streamable-http", "--port", "9999"]
139+
)
140+
assert args.transport == "streamable-http"
141+
assert args.port == 9999
142+
143+
144+
# ── main() integration (smoke test) ───────────────────────────────────────
145+
146+
def test_main_help_exits_cleanly():
147+
"""Verify --help works without crashing (exercises the real main parser)."""
148+
import sys
149+
from unittest.mock import patch as _patch
150+
151+
with _patch.object(sys, 'argv', ['test', '--help']):
152+
with pytest.raises(SystemExit) as exc_info:
153+
from predictive_maintenance_mcp.machinery_diagnostics_server import main
154+
main()
155+
assert exc_info.value.code == 0

tests/test_document_reader.py

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
"""
2+
Tests for the Document Reader module.
3+
4+
Covers:
5+
- Bearing characteristic frequency calculations (BPFO, BPFI, BSF, FTF)
6+
- PDF text extraction (basic, missing file, OCR fallback)
7+
- Edge cases: zero contact angle, high contact angle, zero RPM
8+
"""
9+
10+
import math
11+
import pytest
12+
from pathlib import Path
13+
14+
from predictive_maintenance_mcp.document_reader import (
15+
calculate_bearing_frequencies,
16+
extract_text_from_pdf,
17+
HAS_PDF,
18+
HAS_OCR,
19+
)
20+
21+
22+
# ── Bearing frequency calculations ─────────────────────────────────────────
23+
24+
class TestBearingFrequencies:
25+
"""Tests against known bearing geometry values."""
26+
27+
def test_standard_deep_groove_6205(self):
28+
"""SKF 6205: Z=9, Bd=7.94mm, Pd=34.55mm, alpha=0, RPM=1797."""
29+
freqs = calculate_bearing_frequencies(
30+
num_balls=9,
31+
ball_diameter_mm=7.94,
32+
pitch_diameter_mm=34.55,
33+
contact_angle_deg=0.0,
34+
shaft_speed_rpm=1797,
35+
)
36+
shaft_freq = 1797 / 60.0 # 29.95 Hz
37+
38+
assert "BPFO" in freqs
39+
assert "BPFI" in freqs
40+
assert "BSF" in freqs
41+
assert "FTF" in freqs
42+
assert "shaft_freq_hz" in freqs
43+
assert abs(freqs["shaft_freq_hz"] - shaft_freq) < 0.01
44+
45+
# BPFO ~ (Z/2)*fs*(1 - Bd/Pd) = 4.5 * 29.95 * (1-0.2298) ≈ 103.8
46+
assert 80 < freqs["BPFO"] < 110
47+
# BPFI > BPFO always (inner race spins faster relative to rolling elements)
48+
assert freqs["BPFI"] > freqs["BPFO"]
49+
# FTF < shaft frequency
50+
assert freqs["FTF"] < shaft_freq
51+
52+
def test_all_frequencies_positive(self):
53+
freqs = calculate_bearing_frequencies(
54+
num_balls=8,
55+
ball_diameter_mm=12.0,
56+
pitch_diameter_mm=50.0,
57+
contact_angle_deg=0.0,
58+
shaft_speed_rpm=3000,
59+
)
60+
assert all(freqs[k] > 0 for k in ("BPFO", "BPFI", "BSF", "FTF"))
61+
62+
def test_angular_contact_bearing(self):
63+
"""Non-zero contact angle (angular contact bearing)."""
64+
freqs_zero = calculate_bearing_frequencies(
65+
num_balls=16, ball_diameter_mm=12.7,
66+
pitch_diameter_mm=71.5, contact_angle_deg=0.0,
67+
shaft_speed_rpm=2000,
68+
)
69+
freqs_15 = calculate_bearing_frequencies(
70+
num_balls=16, ball_diameter_mm=12.7,
71+
pitch_diameter_mm=71.5, contact_angle_deg=15.0,
72+
shaft_speed_rpm=2000,
73+
)
74+
# cos(15°) < cos(0°), so d_ratio*cos(alpha) is smaller
75+
# → BPFO should increase, BPFI should decrease compared to alpha=0
76+
# Actually: BPFO = (Z/2)*fs*(1 - d_ratio*cos(a)) → larger with smaller cos
77+
assert freqs_15["BPFO"] > freqs_zero["BPFO"]
78+
assert freqs_15["BPFI"] < freqs_zero["BPFI"]
79+
80+
def test_high_contact_angle(self):
81+
"""45° contact angle (thrust bearing)."""
82+
freqs = calculate_bearing_frequencies(
83+
num_balls=12, ball_diameter_mm=10.0,
84+
pitch_diameter_mm=60.0, contact_angle_deg=45.0,
85+
shaft_speed_rpm=1500,
86+
)
87+
assert all(freqs[k] > 0 for k in ("BPFO", "BPFI", "BSF", "FTF"))
88+
89+
def test_frequency_ratios(self):
90+
"""BPFI/BPFO ratio should be > 1 for typical bearings."""
91+
freqs = calculate_bearing_frequencies(
92+
num_balls=8, ball_diameter_mm=10.0,
93+
pitch_diameter_mm=50.0, shaft_speed_rpm=1500,
94+
)
95+
assert freqs["BPFI"] / freqs["BPFO"] > 1.0
96+
97+
def test_input_parameters_returned(self):
98+
"""Verify input parameters are echoed in result dict."""
99+
freqs = calculate_bearing_frequencies(
100+
num_balls=9, ball_diameter_mm=7.94,
101+
pitch_diameter_mm=34.55, contact_angle_deg=0.0,
102+
shaft_speed_rpm=1797,
103+
)
104+
params = freqs.get("input_parameters", {})
105+
assert params.get("num_balls") == 9
106+
assert params.get("shaft_speed_rpm") == 1797
107+
108+
def test_different_speeds(self):
109+
"""Frequencies should scale linearly with RPM."""
110+
freqs_1000 = calculate_bearing_frequencies(
111+
num_balls=8, ball_diameter_mm=10.0,
112+
pitch_diameter_mm=50.0, shaft_speed_rpm=1000,
113+
)
114+
freqs_2000 = calculate_bearing_frequencies(
115+
num_balls=8, ball_diameter_mm=10.0,
116+
pitch_diameter_mm=50.0, shaft_speed_rpm=2000,
117+
)
118+
# Double RPM → double all frequencies
119+
assert abs(freqs_2000["BPFO"] / freqs_1000["BPFO"] - 2.0) < 0.01
120+
assert abs(freqs_2000["BPFI"] / freqs_1000["BPFI"] - 2.0) < 0.01
121+
assert abs(freqs_2000["BSF"] / freqs_1000["BSF"] - 2.0) < 0.01
122+
123+
124+
# ── PDF text extraction ────────────────────────────────────────────────────
125+
126+
class TestExtractTextFromPdf:
127+
128+
@pytest.mark.skipif(not HAS_PDF, reason="pypdf not installed")
129+
def test_missing_file_raises(self):
130+
with pytest.raises(FileNotFoundError):
131+
extract_text_from_pdf(Path("nonexistent_file.pdf"))
132+
133+
@pytest.mark.skipif(not HAS_PDF, reason="pypdf not installed")
134+
def test_real_catalog_if_available(self):
135+
"""Extract text from a real bearing catalog in the repo."""
136+
catalogs_dir = Path(__file__).parent.parent / "resources" / "bearing_catalogs"
137+
pdfs = list(catalogs_dir.glob("*.pdf")) if catalogs_dir.exists() else []
138+
if not pdfs:
139+
pytest.skip("No PDF catalogs found in resources/")
140+
141+
text = extract_text_from_pdf(pdfs[0], max_pages=3)
142+
assert isinstance(text, str)
143+
assert len(text) > 50 # Should have some content
144+
145+
@pytest.mark.skipif(not HAS_PDF, reason="pypdf not installed")
146+
def test_max_pages_limits_extraction(self):
147+
"""max_pages should limit processing."""
148+
catalogs_dir = Path(__file__).parent.parent / "resources" / "bearing_catalogs"
149+
pdfs = list(catalogs_dir.glob("*.pdf")) if catalogs_dir.exists() else []
150+
if not pdfs:
151+
pytest.skip("No PDF catalogs found in resources/")
152+
153+
text_1 = extract_text_from_pdf(pdfs[0], max_pages=1)
154+
text_3 = extract_text_from_pdf(pdfs[0], max_pages=3)
155+
# More pages should produce more (or equal) text
156+
assert len(text_3) >= len(text_1)
157+
158+
def test_no_pypdf_raises_import_error(self):
159+
"""When pypdf not installed, should raise ImportError."""
160+
if HAS_PDF:
161+
pytest.skip("pypdf IS installed — can't test missing path")
162+
with pytest.raises(ImportError, match="pypdf"):
163+
extract_text_from_pdf(Path("test.pdf"))

0 commit comments

Comments
 (0)