Skip to content

Commit ccb648c

Browse files
committed
Refactor codebase_search to return error strings instead of dictionaries and add corresponding async test coverage.
1 parent 5bb8f1a commit ccb648c

File tree

2 files changed

+90
-5
lines changed

2 files changed

+90
-5
lines changed

src/tests/test_search_tool.py

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,4 +61,89 @@ async def test_codebase_search_returns_xml_string(mock_get_api_key):
6161

6262
# Verify it contains expected XML structure
6363
assert "<results>" in result, "Should contain results tag"
64-
assert "<search_result" in result, "Should contain search_result tag"
64+
assert "<search_result" in result, "Should contain search_result tag"
65+
66+
67+
@pytest.mark.asyncio
68+
async def test_codebase_search_empty_query_returns_error_string():
69+
"""Test that empty query returns an error string (not dict)."""
70+
# Create mock context
71+
ctx = MagicMock(spec=Context)
72+
ctx.info = AsyncMock()
73+
ctx.warning = AsyncMock()
74+
ctx.error = AsyncMock()
75+
76+
# Create mock context with proper structure
77+
mock_codealive_context = MagicMock()
78+
mock_codealive_context.base_url = "https://app.codealive.ai"
79+
ctx.request_context.lifespan_context = mock_codealive_context
80+
81+
# Call with empty query
82+
result = await codebase_search(
83+
ctx=ctx,
84+
query="",
85+
data_source_ids=["test_id"],
86+
mode="auto",
87+
include_content=False
88+
)
89+
90+
# Verify result is a string (not a dict)
91+
assert isinstance(result, str), "Error should be returned as a string"
92+
assert "<error>" in result, "Error string should contain <error> tag"
93+
assert "Query cannot be empty" in result, "Should contain error message"
94+
95+
96+
@pytest.mark.asyncio
97+
@patch('tools.search.get_api_key_from_context')
98+
async def test_codebase_search_api_error_returns_error_string(mock_get_api_key):
99+
"""Test that API errors return an error string (not dict)."""
100+
import httpx
101+
102+
# Mock the API key function
103+
mock_get_api_key.return_value = "test_key"
104+
105+
# Create mock context
106+
ctx = MagicMock(spec=Context)
107+
ctx.info = AsyncMock()
108+
ctx.warning = AsyncMock()
109+
ctx.error = AsyncMock()
110+
111+
# Create mock response that raises 404
112+
mock_response = MagicMock()
113+
mock_response.status_code = 404
114+
mock_response.text = "Not found"
115+
116+
def raise_404():
117+
raise httpx.HTTPStatusError(
118+
"Not found",
119+
request=MagicMock(),
120+
response=mock_response
121+
)
122+
123+
mock_response.raise_for_status = raise_404
124+
125+
# Create mock client that returns the error response
126+
mock_client = AsyncMock()
127+
mock_client.get.return_value = mock_response
128+
129+
# Create mock context with proper structure
130+
mock_codealive_context = MagicMock()
131+
mock_codealive_context.client = mock_client
132+
mock_codealive_context.base_url = "https://app.codealive.ai"
133+
134+
ctx.request_context.lifespan_context = mock_codealive_context
135+
ctx.request_context.headers = {"authorization": "Bearer test_key"}
136+
137+
# Call codebase_search
138+
result = await codebase_search(
139+
ctx=ctx,
140+
query="test query",
141+
data_source_ids=["invalid_id"],
142+
mode="auto",
143+
include_content=False
144+
)
145+
146+
# Verify result is a string (not a dict)
147+
assert isinstance(result, str), "Error should be returned as a string"
148+
assert "<error>" in result, "Error string should contain <error> tag"
149+
assert "404" in result, "Should contain error details"

src/tools/search.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Search tool implementation."""
22

3-
from typing import Any, Dict, List, Optional, Union
3+
from typing import List, Optional, Union
44
from urllib.parse import urljoin
55

66
import httpx
@@ -16,7 +16,7 @@ async def codebase_search(
1616
data_source_ids: Optional[Union[str, List[str]]] = None,
1717
mode: str = "auto",
1818
include_content: bool = False
19-
) -> Union[str, Dict[str, Any]]:
19+
) -> str:
2020
"""
2121
Use `codebase_search` tool to search for code in the codebase.
2222
@@ -99,7 +99,7 @@ async def codebase_search(
9999

100100
# Validate inputs
101101
if not query or not query.strip():
102-
return {"error": "Query cannot be empty. Please provide a search term, function name, or description of the code you're looking for."}
102+
return "<error>Query cannot be empty. Please provide a search term, function name, or description of the code you're looking for.</error>"
103103

104104
if not data_source_ids or len(data_source_ids) == 0:
105105
await ctx.info("No data source IDs provided. If the API key has exactly one assigned data source, that will be used as default.")
@@ -161,4 +161,4 @@ async def codebase_search(
161161
error_msg = await handle_api_error(ctx, e, "code search")
162162
if isinstance(e, httpx.HTTPStatusError) and e.response.status_code == 404:
163163
error_msg = f"Error: Not found (404): One or more data sources could not be found. Check your data_source_ids."
164-
return {"error": error_msg}
164+
return f"<error>{error_msg}</error>"

0 commit comments

Comments
 (0)