|
1 | 1 | """Tests for tools module.""" |
2 | 2 |
|
| 3 | +import os |
3 | 4 | import pytest |
4 | 5 | from unittest.mock import AsyncMock, patch, Mock |
5 | 6 | import httpx |
|
8 | 9 | handle_pia_search_content_facets, |
9 | 10 | handle_pia_search_titles, |
10 | 11 | handle_pia_search_titles_facets, |
| 12 | + handle_pia_search_content_gao, |
| 13 | + handle_pia_search_content_oig, |
| 14 | + handle_pia_search_content_crs, |
| 15 | + handle_pia_search_content_doj, |
| 16 | + handle_pia_search_content_congress, |
| 17 | + handle_pia_search_content_executive_orders, |
| 18 | + handle_search, |
| 19 | + handle_fetch, |
11 | 20 | ) |
12 | 21 | from pia_mcp_server.config import Settings |
13 | 22 |
|
|
18 | 27 | async def test_pia_search_content_no_api_key(): |
19 | 28 | """Test PIA content search without API key.""" |
20 | 29 | with patch.object(Settings, "_get_api_key_from_args", return_value=None): |
21 | | - result = await handle_pia_search_content({"query": "test"}) |
| 30 | + with patch.dict(os.environ, {}, clear=True): # Clear all environment variables |
| 31 | + result = await handle_pia_search_content({"query": "test"}) |
22 | 32 |
|
23 | | - assert len(result) == 1 |
24 | | - assert "PIA API key is required" in result[0].text |
| 33 | + assert len(result) == 1 |
| 34 | + assert "PIA API key is required" in result[0].text |
25 | 35 |
|
26 | 36 |
|
27 | 37 | @pytest.mark.asyncio |
@@ -196,10 +206,11 @@ async def test_pia_search_content_http_error(): |
196 | 206 | async def test_pia_search_content_facets_no_api_key(): |
197 | 207 | """Test PIA search facets without API key.""" |
198 | 208 | with patch.object(Settings, "_get_api_key_from_args", return_value=None): |
199 | | - result = await handle_pia_search_content_facets({"query": "test"}) |
| 209 | + with patch.dict(os.environ, {}, clear=True): # Clear all environment variables |
| 210 | + result = await handle_pia_search_content_facets({"query": "test"}) |
200 | 211 |
|
201 | | - assert len(result) == 1 |
202 | | - assert "PIA API key is required" in result[0].text |
| 212 | + assert len(result) == 1 |
| 213 | + assert "PIA API key is required" in result[0].text |
203 | 214 |
|
204 | 215 |
|
205 | 216 | @pytest.mark.asyncio |
@@ -453,3 +464,282 @@ async def test_pia_search_content_facets_empty_filter(): |
453 | 464 |
|
454 | 465 | assert len(result) == 1 |
455 | 466 | assert "SourceDocumentDataSource" in result[0].text |
| 467 | + |
| 468 | + |
| 469 | +# Agency-specific search tool tests |
| 470 | +@pytest.mark.asyncio |
| 471 | +async def test_pia_search_content_gao_success(): |
| 472 | + """Test successful PIA GAO content search.""" |
| 473 | + mock_response = { |
| 474 | + "jsonrpc": "2.0", |
| 475 | + "id": 1, |
| 476 | + "result": { |
| 477 | + "documents": [ |
| 478 | + {"title": "GAO Report", "id": "gao-123", "data_source": "GAO"} |
| 479 | + ], |
| 480 | + "total": 1, |
| 481 | + }, |
| 482 | + } |
| 483 | + |
| 484 | + with patch.object(Settings, "_get_api_key_from_args", return_value="test_key"): |
| 485 | + with patch("httpx.AsyncClient") as mock_client: |
| 486 | + mock_response_obj = Mock() |
| 487 | + mock_response_obj.json.return_value = mock_response |
| 488 | + mock_response_obj.raise_for_status.return_value = None |
| 489 | + |
| 490 | + mock_client_instance = AsyncMock() |
| 491 | + mock_client_instance.post.return_value = mock_response_obj |
| 492 | + mock_client.return_value.__aenter__.return_value = mock_client_instance |
| 493 | + |
| 494 | + result = await handle_pia_search_content_gao({"query": "audit"}) |
| 495 | + |
| 496 | + assert len(result) == 1 |
| 497 | + assert "GAO Report" in result[0].text |
| 498 | + |
| 499 | + |
| 500 | +@pytest.mark.asyncio |
| 501 | +async def test_pia_search_content_oig_success(): |
| 502 | + """Test successful PIA OIG content search.""" |
| 503 | + mock_response = { |
| 504 | + "jsonrpc": "2.0", |
| 505 | + "id": 1, |
| 506 | + "result": { |
| 507 | + "documents": [ |
| 508 | + {"title": "OIG Investigation", "id": "oig-123", "data_source": "OIG"} |
| 509 | + ], |
| 510 | + "total": 1, |
| 511 | + }, |
| 512 | + } |
| 513 | + |
| 514 | + with patch.object(Settings, "_get_api_key_from_args", return_value="test_key"): |
| 515 | + with patch("httpx.AsyncClient") as mock_client: |
| 516 | + mock_response_obj = Mock() |
| 517 | + mock_response_obj.json.return_value = mock_response |
| 518 | + mock_response_obj.raise_for_status.return_value = None |
| 519 | + |
| 520 | + mock_client_instance = AsyncMock() |
| 521 | + mock_client_instance.post.return_value = mock_response_obj |
| 522 | + mock_client.return_value.__aenter__.return_value = mock_client_instance |
| 523 | + |
| 524 | + result = await handle_pia_search_content_oig({"query": "oversight"}) |
| 525 | + |
| 526 | + assert len(result) == 1 |
| 527 | + assert "OIG Investigation" in result[0].text |
| 528 | + |
| 529 | + |
| 530 | +@pytest.mark.asyncio |
| 531 | +async def test_pia_search_content_crs_success(): |
| 532 | + """Test successful PIA CRS content search.""" |
| 533 | + mock_response = { |
| 534 | + "jsonrpc": "2.0", |
| 535 | + "id": 1, |
| 536 | + "result": { |
| 537 | + "documents": [ |
| 538 | + {"title": "CRS Report", "id": "crs-123", "data_source": "CRS"} |
| 539 | + ], |
| 540 | + "total": 1, |
| 541 | + }, |
| 542 | + } |
| 543 | + |
| 544 | + with patch.object(Settings, "_get_api_key_from_args", return_value="test_key"): |
| 545 | + with patch("httpx.AsyncClient") as mock_client: |
| 546 | + mock_response_obj = Mock() |
| 547 | + mock_response_obj.json.return_value = mock_response |
| 548 | + mock_response_obj.raise_for_status.return_value = None |
| 549 | + |
| 550 | + mock_client_instance = AsyncMock() |
| 551 | + mock_client_instance.post.return_value = mock_response_obj |
| 552 | + mock_client.return_value.__aenter__.return_value = mock_client_instance |
| 553 | + |
| 554 | + result = await handle_pia_search_content_crs({"query": "research"}) |
| 555 | + |
| 556 | + assert len(result) == 1 |
| 557 | + assert "CRS Report" in result[0].text |
| 558 | + |
| 559 | + |
| 560 | +@pytest.mark.asyncio |
| 561 | +async def test_pia_search_content_doj_success(): |
| 562 | + """Test successful PIA DOJ content search.""" |
| 563 | + mock_response = { |
| 564 | + "jsonrpc": "2.0", |
| 565 | + "id": 1, |
| 566 | + "result": { |
| 567 | + "documents": [ |
| 568 | + { |
| 569 | + "title": "DOJ Press Release", |
| 570 | + "id": "doj-123", |
| 571 | + "data_source": "Department of Justice", |
| 572 | + } |
| 573 | + ], |
| 574 | + "total": 1, |
| 575 | + }, |
| 576 | + } |
| 577 | + |
| 578 | + with patch.object(Settings, "_get_api_key_from_args", return_value="test_key"): |
| 579 | + with patch("httpx.AsyncClient") as mock_client: |
| 580 | + mock_response_obj = Mock() |
| 581 | + mock_response_obj.json.return_value = mock_response |
| 582 | + mock_response_obj.raise_for_status.return_value = None |
| 583 | + |
| 584 | + mock_client_instance = AsyncMock() |
| 585 | + mock_client_instance.post.return_value = mock_response_obj |
| 586 | + mock_client.return_value.__aenter__.return_value = mock_client_instance |
| 587 | + |
| 588 | + result = await handle_pia_search_content_doj({"query": "enforcement"}) |
| 589 | + |
| 590 | + assert len(result) == 1 |
| 591 | + assert "DOJ Press Release" in result[0].text |
| 592 | + |
| 593 | + |
| 594 | +@pytest.mark.asyncio |
| 595 | +async def test_pia_search_content_congress_success(): |
| 596 | + """Test successful PIA Congress content search.""" |
| 597 | + mock_response = { |
| 598 | + "jsonrpc": "2.0", |
| 599 | + "id": 1, |
| 600 | + "result": { |
| 601 | + "documents": [ |
| 602 | + { |
| 603 | + "title": "Congressional Bill", |
| 604 | + "id": "congress-123", |
| 605 | + "data_source": "Congress.gov", |
| 606 | + } |
| 607 | + ], |
| 608 | + "total": 1, |
| 609 | + }, |
| 610 | + } |
| 611 | + |
| 612 | + with patch.object(Settings, "_get_api_key_from_args", return_value="test_key"): |
| 613 | + with patch("httpx.AsyncClient") as mock_client: |
| 614 | + mock_response_obj = Mock() |
| 615 | + mock_response_obj.json.return_value = mock_response |
| 616 | + mock_response_obj.raise_for_status.return_value = None |
| 617 | + |
| 618 | + mock_client_instance = AsyncMock() |
| 619 | + mock_client_instance.post.return_value = mock_response_obj |
| 620 | + mock_client.return_value.__aenter__.return_value = mock_client_instance |
| 621 | + |
| 622 | + result = await handle_pia_search_content_congress({"query": "legislation"}) |
| 623 | + |
| 624 | + assert len(result) == 1 |
| 625 | + assert "Congressional Bill" in result[0].text |
| 626 | + |
| 627 | + |
| 628 | +@pytest.mark.asyncio |
| 629 | +async def test_pia_search_content_executive_orders_success(): |
| 630 | + """Test successful PIA Executive Orders content search.""" |
| 631 | + mock_response = { |
| 632 | + "jsonrpc": "2.0", |
| 633 | + "id": 1, |
| 634 | + "result": { |
| 635 | + "documents": [ |
| 636 | + { |
| 637 | + "title": "Executive Order 12345", |
| 638 | + "id": "eo-123", |
| 639 | + "data_source": "Federal Register", |
| 640 | + } |
| 641 | + ], |
| 642 | + "total": 1, |
| 643 | + }, |
| 644 | + } |
| 645 | + |
| 646 | + with patch.object(Settings, "_get_api_key_from_args", return_value="test_key"): |
| 647 | + with patch("httpx.AsyncClient") as mock_client: |
| 648 | + mock_response_obj = Mock() |
| 649 | + mock_response_obj.json.return_value = mock_response |
| 650 | + mock_response_obj.raise_for_status.return_value = None |
| 651 | + |
| 652 | + mock_client_instance = AsyncMock() |
| 653 | + mock_client_instance.post.return_value = mock_response_obj |
| 654 | + mock_client.return_value.__aenter__.return_value = mock_client_instance |
| 655 | + |
| 656 | + result = await handle_pia_search_content_executive_orders( |
| 657 | + {"query": "cybersecurity"} |
| 658 | + ) |
| 659 | + |
| 660 | + assert len(result) == 1 |
| 661 | + assert "Executive Order 12345" in result[0].text |
| 662 | + |
| 663 | + |
| 664 | +@pytest.mark.asyncio |
| 665 | +async def test_fetch_success(): |
| 666 | + """Test successful document fetch.""" |
| 667 | + mock_response = { |
| 668 | + "jsonrpc": "2.0", |
| 669 | + "id": 1, |
| 670 | + "result": { |
| 671 | + "id": "doc-123", |
| 672 | + "title": "Test Document", |
| 673 | + "content": "Full document content here", |
| 674 | + "url": "https://example.com/doc-123", |
| 675 | + }, |
| 676 | + } |
| 677 | + |
| 678 | + with patch.object(Settings, "_get_api_key_from_args", return_value="test_key"): |
| 679 | + with patch("httpx.AsyncClient") as mock_client: |
| 680 | + mock_response_obj = Mock() |
| 681 | + mock_response_obj.json.return_value = mock_response |
| 682 | + mock_response_obj.raise_for_status.return_value = None |
| 683 | + |
| 684 | + mock_client_instance = AsyncMock() |
| 685 | + mock_client_instance.post.return_value = mock_response_obj |
| 686 | + mock_client.return_value.__aenter__.return_value = mock_client_instance |
| 687 | + |
| 688 | + result = await handle_fetch({"id": "doc-123"}) |
| 689 | + |
| 690 | + assert len(result) == 1 |
| 691 | + assert "Test Document" in result[0].text |
| 692 | + assert "Full document content here" in result[0].text |
| 693 | + |
| 694 | + |
| 695 | +@pytest.mark.asyncio |
| 696 | +async def test_agency_tools_no_api_key(): |
| 697 | + """Test agency-specific tools without API key.""" |
| 698 | + tools_to_test = [ |
| 699 | + (handle_pia_search_content_gao, {"query": "test"}), |
| 700 | + (handle_pia_search_content_oig, {"query": "test"}), |
| 701 | + (handle_pia_search_content_crs, {"query": "test"}), |
| 702 | + (handle_pia_search_content_doj, {"query": "test"}), |
| 703 | + (handle_pia_search_content_congress, {"query": "test"}), |
| 704 | + (handle_pia_search_content_executive_orders, {"query": "test"}), |
| 705 | + (handle_fetch, {"id": "test-123"}), |
| 706 | + ] |
| 707 | + |
| 708 | + for tool_handler, args in tools_to_test: |
| 709 | + with patch.object(Settings, "_get_api_key_from_args", return_value=None): |
| 710 | + with patch.dict( |
| 711 | + os.environ, {}, clear=True |
| 712 | + ): # Clear all environment variables |
| 713 | + result = await tool_handler(args) |
| 714 | + assert len(result) == 1 |
| 715 | + assert "PIA API key is required" in result[0].text |
| 716 | + |
| 717 | + |
| 718 | +@pytest.mark.asyncio |
| 719 | +async def test_agency_tools_http_error(): |
| 720 | + """Test agency-specific tools with HTTP error.""" |
| 721 | + tools_to_test = [ |
| 722 | + (handle_pia_search_content_gao, {"query": "test"}), |
| 723 | + (handle_pia_search_content_oig, {"query": "test"}), |
| 724 | + (handle_pia_search_content_crs, {"query": "test"}), |
| 725 | + (handle_pia_search_content_doj, {"query": "test"}), |
| 726 | + (handle_pia_search_content_congress, {"query": "test"}), |
| 727 | + (handle_pia_search_content_executive_orders, {"query": "test"}), |
| 728 | + (handle_fetch, {"id": "test-123"}), |
| 729 | + ] |
| 730 | + |
| 731 | + for tool_handler, args in tools_to_test: |
| 732 | + with patch.object(Settings, "_get_api_key_from_args", return_value="test_key"): |
| 733 | + with patch("httpx.AsyncClient") as mock_client: |
| 734 | + mock_client_instance = AsyncMock() |
| 735 | + mock_client_instance.post.side_effect = httpx.HTTPStatusError( |
| 736 | + "Server Error", |
| 737 | + request=Mock(), |
| 738 | + response=Mock(status_code=500, text="Server Error"), |
| 739 | + ) |
| 740 | + mock_client.return_value.__aenter__.return_value = mock_client_instance |
| 741 | + |
| 742 | + result = await tool_handler(args) |
| 743 | + |
| 744 | + assert len(result) == 1 |
| 745 | + assert "HTTP Error 500" in result[0].text |
0 commit comments