Skip to content

Commit 3c3e1c6

Browse files
committed
feat: added tests for cot interface
1 parent 7d2232d commit 3c3e1c6

File tree

2 files changed

+36
-4
lines changed

2 files changed

+36
-4
lines changed

agentic_rag/local_rag_agent.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,8 +358,17 @@ def _process_query_with_cot(self, query: str) -> Dict[str, Any]:
358358
logger.info("Falling back to general response")
359359
return self._generate_general_response(query)
360360

361+
# Handle string response from synthesis
362+
if isinstance(synthesis_result, str):
363+
return {
364+
"answer": synthesis_result,
365+
"reasoning_steps": reasoning_steps,
366+
"context": context
367+
}
368+
369+
# Handle dictionary response
361370
return {
362-
"answer": synthesis_result["answer"],
371+
"answer": synthesis_result.get("answer", synthesis_result) if isinstance(synthesis_result, dict) else synthesis_result,
363372
"reasoning_steps": reasoning_steps,
364373
"context": context
365374
}

agentic_rag/tests/test_cot_chat.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,17 @@ def test_cot_chat():
6767
logger.info("Raw response received")
6868
debug_response_structure(raw_response, "Raw response: ")
6969

70+
# Verify response structure
71+
if not isinstance(raw_response, dict):
72+
logger.error(f"Unexpected response type: {type(raw_response)}")
73+
raise TypeError(f"Expected dict response, got {type(raw_response)}")
74+
75+
required_keys = ["answer", "reasoning_steps", "context"]
76+
missing_keys = [key for key in required_keys if key not in raw_response]
77+
if missing_keys:
78+
logger.error(f"Missing required keys in response: {missing_keys}")
79+
raise KeyError(f"Response missing required keys: {missing_keys}")
80+
7081
# Process through chat function
7182
logger.info("Processing through chat function...")
7283
result = chat(
@@ -91,9 +102,21 @@ def test_cot_chat():
91102
# Save debug information to file
92103
debug_info = {
93104
"test_message": test_message,
94-
"raw_response": str(raw_response),
95-
"final_result": str(result),
96-
"history": str(history)
105+
"raw_response": {
106+
"type": str(type(raw_response)),
107+
"keys": list(raw_response.keys()) if isinstance(raw_response, dict) else None,
108+
"content": str(raw_response)
109+
},
110+
"final_result": {
111+
"type": str(type(result)),
112+
"length": len(result) if isinstance(result, list) else None,
113+
"content": str(result)
114+
},
115+
"history": {
116+
"type": str(type(history)),
117+
"length": len(history),
118+
"content": str(history)
119+
}
97120
}
98121

99122
with open("cot_chat_debug.json", "w") as f:

0 commit comments

Comments
 (0)