Skip to content

Commit 5d41305

Browse files
author
deec
committed
Increase unit test coverage to 99% and cover MCP guard branches
1 parent 90c96ca commit 5d41305

File tree

1 file changed

+170
-0
lines changed

1 file changed

+170
-0
lines changed

tests/test_server.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,176 @@ async def test_get_schema_for_table_uses_param_query(self, mock_exec):
245245
# execute_query_async(query, params=[table_name]) -> params should be in kwargs
246246
self.assertEqual(kwargs.get("params"), ["Sales"])
247247

248+
async def test_query_connx_rejects_semicolons(self):
249+
out = await mod.query_connx("SELECT 1; SELECT 2")
250+
self.assertIn("error", out)
251+
self.assertIn("single sql statement", out["error"].lower())
252+
253+
async def test_query_connx_rejects_non_select(self):
254+
out = await mod.query_connx("UPDATE T SET A=1")
255+
self.assertIn("error", out)
256+
self.assertIn("only select", out["error"].lower())
257+
258+
@patch(f"{MODULE_UNDER_TEST}.execute_query_async")
259+
async def test_query_connx_returns_error_dict_on_value_error(self, mock_exec):
260+
mock_exec.side_effect = ValueError("no db")
261+
out = await mod.query_connx("SELECT * FROM T")
262+
self.assertIn("error", out)
263+
self.assertIn("no db", out["error"].lower())
264+
265+
async def test_update_connx_rejects_when_writes_disabled(self):
266+
# CI will hit this branch unless you patch CONNX_ALLOW_WRITES=True
267+
out = await mod.update_connx("update", "UPDATE T SET A=1")
268+
self.assertIn("error", out)
269+
self.assertIn("writes are disabled", out["error"].lower())
270+
271+
@patch(f"{MODULE_UNDER_TEST}.CONNX_ALLOW_WRITES", True)
272+
async def test_update_connx_invalid_operation_when_writes_enabled(self):
273+
# Covers the "invalid operation" branch (only reachable if writes enabled)
274+
out = await mod.update_connx("merge", "UPDATE T SET A=1")
275+
self.assertIn("error", out)
276+
self.assertIn("invalid operation", out["error"].lower())
277+
278+
@patch(f"{MODULE_UNDER_TEST}.CONNX_ALLOW_WRITES", True)
279+
async def test_update_connx_rejects_semicolons_when_writes_enabled(self):
280+
# Covers "single statement" guard inside update tool
281+
out = await mod.update_connx("update", "UPDATE T SET A=1; UPDATE T SET A=2")
282+
self.assertIn("error", out)
283+
self.assertIn("single sql statement", out["error"].lower())
284+
285+
@patch(f"{MODULE_UNDER_TEST}.CONNX_ALLOW_WRITES", True)
286+
@patch(f"{MODULE_UNDER_TEST}.execute_update_async")
287+
async def test_update_connx_error_when_writes_enabled(self, mock_exec):
288+
# Covers update_connx except ValueError -> {"error": ...}
289+
mock_exec.side_effect = ValueError("bad update")
290+
out = await mod.update_connx("delete", "DELETE FROM T")
291+
self.assertIn("error", out)
292+
self.assertIn("bad update", out["error"].lower())
293+
294+
@patch(f"{MODULE_UNDER_TEST}.execute_query_async")
295+
async def test_get_schema_returns_error_dict_on_value_error(self, mock_exec):
296+
# Covers get_schema except ValueError -> {"error": ...}
297+
mock_exec.side_effect = ValueError("schema fail")
298+
out = await mod.get_schema()
299+
self.assertIn("error", out)
300+
self.assertIn("schema fail", out["error"].lower())
301+
302+
@patch(f"{MODULE_UNDER_TEST}.execute_query_async")
303+
async def test_get_schema_for_table_returns_error_dict_on_value_error(self, mock_exec):
304+
# Covers get_schema_for_table except ValueError -> {"error": ...}
305+
mock_exec.side_effect = ValueError("schema table fail")
306+
out = await mod.get_schema_for_table("CUSTOMERS_VSAM")
307+
self.assertIn("error", out)
308+
self.assertIn("schema table fail", out["error"].lower())
309+
310+
@patch(f"{MODULE_UNDER_TEST}.execute_query_async")
311+
async def test_find_customers_error_returns_error_dict(self, mock_exec):
312+
# Covers find_customers except ValueError -> {"error": ...}
313+
mock_exec.side_effect = ValueError("boom")
314+
out = await mod.find_customers("VA")
315+
self.assertIn("error", out)
316+
self.assertIn("boom", out["error"].lower())
317+
318+
@patch(f"{MODULE_UNDER_TEST}.execute_query_async")
319+
async def test_find_customers_truncates_results(self, mock_exec):
320+
# Covers truncation branch
321+
mock_exec.return_value = [{"CUSTOMERID": "1"}] * 105
322+
out = await mod.find_customers("VA", max_rows=100)
323+
self.assertEqual(out["count"], 100)
324+
self.assertTrue(out["truncated"])
325+
self.assertEqual(len(out["results"]), 100)
326+
327+
class TestConfig(unittest.TestCase):
328+
def test_assert_config_raises_when_missing(self):
329+
with patch.dict(os.environ, {}, clear=True):
330+
with self.assertRaises(RuntimeError) as ctx:
331+
mod._assert_config()
332+
self.assertIn("missing required config values", str(ctx.exception).lower())
333+
334+
class TestFindCustomers(unittest.IsolatedAsyncioTestCase):
335+
@patch(f"{MODULE_UNDER_TEST}.execute_query_async")
336+
async def test_find_customers_builds_query_and_params_state_only(self, mock_exec):
337+
mock_exec.return_value = [{"CUSTOMERID": "A"}]
338+
339+
out = await mod.find_customers("Virginia") # should normalize to VA
340+
self.assertIn("results", out)
341+
self.assertEqual(out["count"], 1)
342+
343+
args, kwargs = mock_exec.call_args
344+
sql_sent = args[0]
345+
params_sent = kwargs.get("params")
346+
347+
self.assertIn("FROM daea_Mainframe_VSAM.dbo.CUSTOMERS_VSAM", sql_sent)
348+
self.assertIn("CUSTOMERSTATE", sql_sent.upper())
349+
self.assertEqual(params_sent, ["VA"])
350+
351+
@patch(f"{MODULE_UNDER_TEST}.execute_query_async")
352+
async def test_find_customers_includes_city_filter_when_provided(self, mock_exec):
353+
mock_exec.return_value = [{"CUSTOMERID": "A"}]
354+
355+
out = await mod.find_customers("VA", city="Richmond")
356+
self.assertEqual(out["count"], 1)
357+
358+
args, kwargs = mock_exec.call_args
359+
sql_sent = args[0].upper()
360+
params_sent = kwargs.get("params")
361+
362+
self.assertIn("CUSTOMERCITY", sql_sent)
363+
self.assertEqual(params_sent, ["VA", "Richmond"])
364+
365+
@patch(f"{MODULE_UNDER_TEST}.execute_query_async")
366+
async def test_find_customers_truncates_results(self, mock_exec):
367+
mock_exec.return_value = [{"CUSTOMERID": str(i)} for i in range(200)]
368+
369+
out = await mod.find_customers("VA", max_rows=10)
370+
self.assertEqual(out["count"], 10)
371+
self.assertTrue(out["truncated"])
372+
373+
class TestFindCustomers(unittest.IsolatedAsyncioTestCase):
374+
@patch(f"{MODULE_UNDER_TEST}.execute_query_async")
375+
async def test_find_customers_builds_query_and_params_state_only(self, mock_exec):
376+
mock_exec.return_value = [{"CUSTOMERID": "A"}]
377+
378+
out = await mod.find_customers("Virginia") # should normalize to VA
379+
self.assertIn("results", out)
380+
self.assertEqual(out["count"], 1)
381+
382+
args, kwargs = mock_exec.call_args
383+
sql_sent = args[0]
384+
params_sent = kwargs.get("params")
385+
386+
self.assertIn("FROM daea_Mainframe_VSAM.dbo.CUSTOMERS_VSAM", sql_sent)
387+
self.assertIn("CUSTOMERSTATE", sql_sent.upper())
388+
self.assertEqual(params_sent, ["VA"])
389+
390+
@patch(f"{MODULE_UNDER_TEST}.execute_query_async")
391+
async def test_find_customers_includes_city_filter_when_provided(self, mock_exec):
392+
mock_exec.return_value = [{"CUSTOMERID": "A"}]
393+
394+
out = await mod.find_customers("VA", city="Richmond")
395+
self.assertEqual(out["count"], 1)
396+
397+
args, kwargs = mock_exec.call_args
398+
sql_sent = args[0].upper()
399+
params_sent = kwargs.get("params")
400+
401+
self.assertIn("CUSTOMERCITY", sql_sent)
402+
self.assertEqual(params_sent, ["VA", "Richmond"])
403+
404+
@patch(f"{MODULE_UNDER_TEST}.execute_query_async")
405+
async def test_find_customers_truncates_results(self, mock_exec):
406+
mock_exec.return_value = [{"CUSTOMERID": str(i)} for i in range(200)]
407+
408+
out = await mod.find_customers("VA", max_rows=10)
409+
self.assertEqual(out["count"], 10)
410+
self.assertTrue(out["truncated"])
411+
412+
class TestSqlFingerprint(unittest.TestCase):
413+
def test_sql_fingerprint_is_stable_and_short(self):
414+
a = mod._sql_fingerprint("SELECT 1")
415+
b = mod._sql_fingerprint("SELECT 1")
416+
self.assertEqual(a, b)
417+
self.assertEqual(len(a), 12)
248418

249419
if __name__ == "__main__":
250420
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)