From 52ae02bdb5a66981ff9909bf410e1253867a6688 Mon Sep 17 00:00:00 2001 From: Chinmay Bansal Date: Thu, 11 Sep 2025 23:29:19 -0700 Subject: [PATCH] Add reasoning support for Cohere chat generator --- .../generators/cohere/chat/chat_generator.py | 139 +++++++++- .../cohere/tests/test_chat_generator.py | 254 +++++++++++++++++- 2 files changed, 387 insertions(+), 6 deletions(-) diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py index 278e47a1e..2437698e9 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py @@ -1,9 +1,10 @@ import json +import re from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Union, get_args from haystack import component, default_from_dict, default_to_dict, logging from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message -from haystack.dataclasses import ChatMessage, ComponentInfo, ImageContent, TextContent, ToolCall +from haystack.dataclasses import ChatMessage, ComponentInfo, ImageContent, ReasoningContent, TextContent, ToolCall from haystack.dataclasses.streaming_chunk import ( AsyncStreamingCallbackT, FinishReason, @@ -202,11 +203,20 @@ def _parse_response(chat_response: ChatResponse, model: str) -> ChatMessage: ) ) + # Extract reasoning from content if present, even with tool calls + reasoning_content = None + if chat_response.message.content and hasattr(chat_response.message.content[0], "text"): + raw_content = chat_response.message.content[0].text + reasoning_content, _ = _extract_reasoning_from_response(raw_content) + # Create message with tool plan as text and tool calls in the format Haystack expects tool_plan = chat_response.message.tool_plan or "" - message = ChatMessage.from_assistant(text=tool_plan, tool_calls=tool_calls) + message = ChatMessage.from_assistant(text=tool_plan, tool_calls=tool_calls, reasoning=reasoning_content) elif chat_response.message.content and hasattr(chat_response.message.content[0], "text"): - message = ChatMessage.from_assistant(chat_response.message.content[0].text) + raw_content = chat_response.message.content[0].text + # Extract reasoning content if present + reasoning_content, cleaned_content = _extract_reasoning_from_response(raw_content) + message = ChatMessage.from_assistant(cleaned_content, reasoning=reasoning_content) else: # Handle the case where neither tool_calls nor content exists logger.warning(f"Received empty response from Cohere API: {chat_response.message}") @@ -350,6 +360,125 @@ def _convert_cohere_chunk_to_streaming_chunk( ) +def _extract_reasoning_from_response(response_text: str) -> tuple[Optional[ReasoningContent], str]: + """ + Extract reasoning content from Cohere's response if present. + + Cohere's reasoning-capable models (like Command A Reasoning) may include reasoning content + in various formats. This function attempts to identify and extract such content. + + :param response_text: The raw response text from Cohere + :returns: A tuple of (ReasoningContent or None, cleaned_response_text) + """ + if not response_text or not isinstance(response_text, str): + return None, response_text + + # Check for reasoning markers that Cohere might use + + # Pattern 1: Look for thinking/reasoning tags + thinking_patterns = [ + r"(.*?)", + r"(.*?)", + r"## Reasoning\s*\n(.*?)(?=\n## |$)", + r"## Thinking\s*\n(.*?)(?=\n## |$)", + ] + + for pattern in thinking_patterns: + match = re.search(pattern, response_text, re.DOTALL | re.IGNORECASE) + if match: + reasoning_text = match.group(1).strip() + cleaned_content = re.sub(pattern, "", response_text, flags=re.DOTALL | re.IGNORECASE).strip() + # Apply minimum length threshold for tag-based reasoning + min_reasoning_length = 30 + if len(reasoning_text) > min_reasoning_length: + return ReasoningContent(reasoning_text=reasoning_text), cleaned_content + else: + # Content too short, but still clean the tags + return None, cleaned_content + + # Pattern 2: Look for step-by-step reasoning at start + lines = response_text.split("\n") + reasoning_lines = [] + content_lines = [] + found_separator = False + + for i, line in enumerate(lines): + stripped_line = line.strip() + # Look for reasoning indicators at the beginning of lines (more precise) + if ( + stripped_line.startswith(("Step ", "First,", "Let me think", "I need to solve", "To solve")) + or stripped_line.startswith(("## Reasoning", "## Thinking", "## My reasoning")) + or ( + len(stripped_line) > 0 + and stripped_line.endswith(":") + and ("reasoning" in stripped_line.lower() or "thinking" in stripped_line.lower()) + ) + ): + # Look for a clear separator to determine where reasoning ends + reasoning_end = len(lines) # Default to end of text + for j in range(i + 1, len(lines)): + next_line = lines[j].strip() + if next_line.startswith( + ("Based on", "Therefore", "In conclusion", "So,", "Thus,", "## Solution", "## Answer") + ): + reasoning_end = j + break + + reasoning_lines = lines[:reasoning_end] + content_lines = lines[reasoning_end:] + found_separator = True + break + # Stop looking after first few lines + max_lines_to_check = 10 + if i > max_lines_to_check: + break + + if found_separator and reasoning_lines: + reasoning_text = "\n".join(reasoning_lines).strip() + cleaned_content = "\n".join(content_lines).strip() + min_reasoning_length = 30 + if len(reasoning_text) > min_reasoning_length: # Minimum threshold + return ReasoningContent(reasoning_text=reasoning_text), cleaned_content + + # No reasoning detected + return None, response_text + + +def _convert_streaming_chunks_to_chat_message_with_reasoning(chunks: List[StreamingChunk]) -> ChatMessage: + """ + Convert streaming chunks to ChatMessage with reasoning extraction support. + + This is a custom version of the core utility function that adds reasoning content + extraction for Cohere responses. + """ + # Use the core utility to get the base ChatMessage + base_message = _convert_streaming_chunks_to_chat_message(chunks=chunks) + + # Extract text content to check for reasoning + if not base_message.text: + return base_message + + # Use the text property for reasoning extraction + combined_text = base_message.text + + # Extract reasoning if present + reasoning_content, cleaned_text = _extract_reasoning_from_response(combined_text) + + if reasoning_content is None: + # No reasoning found, return original message + return base_message + + # Create new message with reasoning support + new_message = ChatMessage.from_assistant( + text=cleaned_text, + reasoning=reasoning_content, + tool_calls=base_message.tool_calls, + meta=base_message.meta, + ) + + return new_message + + def _parse_streaming_response( response: Iterator[StreamedChatResponseV2], model: str, @@ -381,7 +510,7 @@ def _parse_streaming_response( chunks.append(streaming_chunk) streaming_callback(streaming_chunk) - return _convert_streaming_chunks_to_chat_message(chunks=chunks) + return _convert_streaming_chunks_to_chat_message_with_reasoning(chunks=chunks) async def _parse_async_streaming_response( @@ -409,7 +538,7 @@ async def _parse_async_streaming_response( chunks.append(streaming_chunk) await streaming_callback(streaming_chunk) - return _convert_streaming_chunks_to_chat_message(chunks=chunks) + return _convert_streaming_chunks_to_chat_message_with_reasoning(chunks=chunks) @component diff --git a/integrations/cohere/tests/test_chat_generator.py b/integrations/cohere/tests/test_chat_generator.py index 4378ee016..b90e41158 100644 --- a/integrations/cohere/tests/test_chat_generator.py +++ b/integrations/cohere/tests/test_chat_generator.py @@ -7,13 +7,14 @@ from haystack import Pipeline from haystack.components.generators.utils import print_streaming_chunk from haystack.components.tools import ToolInvoker -from haystack.dataclasses import ChatMessage, ChatRole, ImageContent, ToolCall +from haystack.dataclasses import ChatMessage, ChatRole, ImageContent, ReasoningContent, ToolCall from haystack.dataclasses.streaming_chunk import StreamingChunk from haystack.tools import Tool from haystack.utils import Secret from haystack_integrations.components.generators.cohere import CohereChatGenerator from haystack_integrations.components.generators.cohere.chat.chat_generator import ( + _extract_reasoning_from_response, _format_message, ) @@ -656,3 +657,254 @@ def test_live_run_multimodal(self): assert len(results["replies"]) == 1 assert isinstance(results["replies"][0], ChatMessage) assert len(results["replies"][0].text) > 0 + + +class TestReasoningExtraction: + """Test the reasoning extraction functionality.""" + + def test_extract_reasoning_with_thinking_tags(self): + """Test extraction of reasoning from tags.""" + response_text = """ +I need to calculate the area of a circle. +The formula is π * r². +Given radius is 5, so area = π * 25 = 78.54 + + +The area of a circle with radius 5 is approximately 78.54 square units.""" + + reasoning, cleaned = _extract_reasoning_from_response(response_text) + + assert reasoning is not None + assert isinstance(reasoning, ReasoningContent) + assert "calculate the area of a circle" in reasoning.reasoning_text + assert "formula is π * r²" in reasoning.reasoning_text + assert "area = π * 25 = 78.54" in reasoning.reasoning_text + assert cleaned.strip() == "The area of a circle with radius 5 is approximately 78.54 square units." + + def test_extract_reasoning_with_reasoning_tags(self): + """Test extraction of reasoning from tags.""" + response_text = """ +Let me think about this step by step: +1. First, I need to understand the problem +2. Then identify the key variables +3. Apply the appropriate formula + + +Based on my analysis, here's the solution.""" + + reasoning, cleaned = _extract_reasoning_from_response(response_text) + + assert reasoning is not None + assert isinstance(reasoning, ReasoningContent) + assert "step by step" in reasoning.reasoning_text + assert "understand the problem" in reasoning.reasoning_text + assert "key variables" in reasoning.reasoning_text + assert cleaned.strip() == "Based on my analysis, here's the solution." + + def test_extract_reasoning_with_step_by_step_headers(self): + """Test extraction of reasoning from step-by-step format.""" + response_text = """## My reasoning: +Step 1: Analyze the input data +Step 2: Identify patterns +Step 3: Apply the algorithm + +## Solution: +The final answer is 42.""" + + reasoning, cleaned = _extract_reasoning_from_response(response_text) + + assert reasoning is not None + assert isinstance(reasoning, ReasoningContent) + assert "Step 1: Analyze the input data" in reasoning.reasoning_text + assert "Step 2: Identify patterns" in reasoning.reasoning_text + assert "Step 3: Apply the algorithm" in reasoning.reasoning_text + assert cleaned.strip() == "## Solution:\nThe final answer is 42." + + def test_extract_reasoning_no_reasoning_present(self): + """Test that no reasoning is extracted when none is present.""" + response_text = "This is a simple response without any reasoning content." + + reasoning, cleaned = _extract_reasoning_from_response(response_text) + + assert reasoning is None + assert cleaned == response_text + + def test_extract_reasoning_short_reasoning_ignored(self): + """Test that very short reasoning content is ignored.""" + response_text = """ +OK + + +The answer is yes.""" + + reasoning, cleaned = _extract_reasoning_from_response(response_text) + + assert reasoning is None # Too short, should be ignored + assert cleaned.strip() == "The answer is yes." + + def test_extract_reasoning_with_let_me_think(self): + """Test extraction of reasoning starting with 'Let me think'.""" + response_text = """Let me think through this carefully: + +First, I need to consider the constraints of the problem. The user is asking about quantum +mechanics, which requires understanding wave-particle duality. + +Second, I should explain the fundamental principles clearly. + +Based on this analysis, quantum mechanics describes the behavior of matter and energy at the atomic scale.""" + + reasoning, cleaned = _extract_reasoning_from_response(response_text) + + assert reasoning is not None + assert isinstance(reasoning, ReasoningContent) + assert "think through this carefully" in reasoning.reasoning_text + assert "constraints of the problem" in reasoning.reasoning_text + assert "wave-particle duality" in reasoning.reasoning_text + assert cleaned.strip() == ( + "Based on this analysis, quantum mechanics describes the behavior of matter and energy at the atomic scale." + ) + + +class TestCohereChatGeneratorReasoning: + """Integration tests for reasoning functionality in CohereChatGenerator.""" + + @pytest.mark.skipif(not os.environ.get("COHERE_API_KEY"), reason="COHERE_API_KEY not set") + @pytest.mark.integration + def test_reasoning_with_command_a_reasoning_model(self): + """Test reasoning extraction with Command A Reasoning model.""" + generator = CohereChatGenerator( + model="command-a-reasoning-111b-2024-10-03", + generation_kwargs={"thinking": True}, # Enable reasoning + ) + + messages = [ + ChatMessage.from_user("Solve this math problem step by step: What is the area of a circle with radius 7?") + ] + + result = generator.run(messages=messages) + + assert "replies" in result + assert len(result["replies"]) == 1 + + reply = result["replies"][0] + assert isinstance(reply, ChatMessage) + assert reply.role == ChatRole.ASSISTANT + + # Check if reasoning was extracted + if reply.reasoning: + assert isinstance(reply.reasoning, ReasoningContent) + assert len(reply.reasoning.reasoning_text) > 50 # Should have substantial reasoning + + # The reasoning should contain mathematical thinking + reasoning_lower = reply.reasoning.reasoning_text.lower() + assert any(word in reasoning_lower for word in ["area", "circle", "radius", "formula", "π", "pi"]) + + # Check the main response content + assert len(reply.text) > 0 + response_lower = reply.text.lower() + assert any(word in response_lower for word in ["area", "153.94", "154", "square"]) + + def test_reasoning_with_mock_response(self): + """Test reasoning extraction with mocked Cohere response.""" + generator = CohereChatGenerator( + model="command-a-reasoning-111b-2024-10-03", api_key=Secret.from_token("fake-api-key") + ) + + # Mock the Cohere client response + mock_response = MagicMock() + mock_response.message.content = [ + MagicMock( + text=""" +I need to solve for the area of a circle. +The formula is A = πr² +With radius 7: A = π * 7² = π * 49 ≈ 153.94 + + +The area of a circle with radius 7 is approximately 153.94 square units.""" + ) + ] + mock_response.message.tool_calls = None + mock_response.message.citations = None + + generator.client.chat = MagicMock(return_value=mock_response) + + messages = [ChatMessage.from_user("What is the area of a circle with radius 7?")] + result = generator.run(messages=messages) + + assert "replies" in result + assert len(result["replies"]) == 1 + + reply = result["replies"][0] + assert isinstance(reply, ChatMessage) + assert reply.role == ChatRole.ASSISTANT + + # Check reasoning extraction + assert reply.reasoning is not None + assert isinstance(reply.reasoning, ReasoningContent) + assert "formula is A = πr²" in reply.reasoning.reasoning_text + assert "π * 49 ≈ 153.94" in reply.reasoning.reasoning_text + + # Check cleaned content + assert reply.text.strip() == "The area of a circle with radius 7 is approximately 153.94 square units." + + def test_reasoning_with_tool_calls_compatibility(self): + """Test that reasoning works with tool calls.""" + weather_tool = Tool( + name="weather", + description="Get weather for a city", + parameters={ + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + function=weather, + ) + + generator = CohereChatGenerator( + model="command-a-reasoning-111b-2024-10-03", tools=[weather_tool], api_key=Secret.from_token("fake-api-key") + ) + + # Mock response with both reasoning and tool calls + mock_response = MagicMock() + mock_response.message.content = [ + MagicMock( + text=""" +The user is asking about weather in Paris. I should use the weather tool to get accurate information. + + +I'll check the weather in Paris for you.""" + ) + ] + + # Mock tool call + mock_tool_call = MagicMock() + mock_tool_call.function.name = "weather" + mock_tool_call.function.arguments = '{"city": "Paris"}' + mock_tool_call.id = "call_123" + mock_response.message.tool_calls = [mock_tool_call] + mock_response.message.tool_plan = "I'll check the weather in Paris for you." + mock_response.message.citations = None + + generator.client.chat = MagicMock(return_value=mock_response) + + messages = [ChatMessage.from_user("What's the weather like in Paris?")] + result = generator.run(messages=messages) + + assert "replies" in result + assert len(result["replies"]) == 1 + + reply = result["replies"][0] + assert isinstance(reply, ChatMessage) + + # Check reasoning extraction + assert reply.reasoning is not None + assert isinstance(reply.reasoning, ReasoningContent) + assert "weather tool" in reply.reasoning.reasoning_text + + # Check tool calls are preserved + assert reply.tool_calls is not None + assert len(reply.tool_calls) == 1 + assert reply.tool_calls[0].tool_name == "weather" + + # Check cleaned content + assert "I'll check the weather in Paris" in reply.text