diff --git a/libs/ktem/ktem/reasoning/prompt_optimization/decompose_question.py b/libs/ktem/ktem/reasoning/prompt_optimization/decompose_question.py index 04ec6a856..32dbe8399 100644 --- a/libs/ktem/ktem/reasoning/prompt_optimization/decompose_question.py +++ b/libs/ktem/ktem/reasoning/prompt_optimization/decompose_question.py @@ -1,4 +1,6 @@ +import json import logging +import re from ktem.llms.manager import llms from ktem.reasoning.prompt_optimization.rewrite_question import RewriteQuestionPipeline @@ -10,6 +12,16 @@ logger = logging.getLogger(__name__) +def _is_tool_not_supported_error(error: Exception) -> bool: + """Check if the error indicates the model does not support tools.""" + error_msg = str(error).lower() + return ( + "does not support tools" in error_msg + or "tool use is not supported" in error_msg + or "tools are not supported" in error_msg + ) + + class SubQuery(BaseModel): """Search over a database of insurance rulebooks or financial reports""" @@ -40,6 +52,20 @@ class DecomposeQuestionPipeline(RewriteQuestionPipeline): "If there are acronyms or words you are not familiar with, " "do not try to rephrase them." ) + # Fallback prompt for models that don't support tools + DECOMPOSE_FALLBACK_PROMPT_TEMPLATE = ( + "You are an expert at converting user complex questions into sub questions. " + "Given a user question, break it down into the most specific sub" + " questions you can (at most 3) " + "which will help you answer the original question. " + "Each sub question should be about a single concept/fact/idea. " + "If there are acronyms or words you are not familiar with, " + "do not try to rephrase them.\n\n" + "Output your sub-questions as a JSON array of objects, where each object has " + 'a "sub_query" field. Example:\n' + '[{"sub_query": "What is X?"}, {"sub_query": "How does Y work?"}]\n\n' + "Only output the JSON array, no other text." 
def _run_with_fallback(self, question: str) -> list:
    """Decompose ``question`` with plain-text prompting instead of tools.

    Sends ``DECOMPOSE_FALLBACK_PROMPT_TEMPLATE`` as the system message and
    ``question`` as the human message, then extracts sub-questions from the
    reply in two stages:

    1. Look for a JSON array in the reply. Each element may be an object
       with a ``"sub_query"`` field or a bare string.
    2. If no valid JSON array is found (no ``[...]`` span at all, or the
       span is not valid JSON), fall back to treating each line of the
       reply as a candidate question, after stripping list markers.

    Args:
        question: The user question to break down.

    Returns:
        A list of at most 3 ``Document`` objects, one per sub-question.
        The list is empty when nothing usable could be extracted.
    """
    messages = [
        SystemMessage(content=self.DECOMPOSE_FALLBACK_PROMPT_TEMPLATE),
        HumanMessage(content=question),
    ]

    result = self.llm(messages)
    # Guard against a missing text payload so parsing never crashes on None.
    text = (result.text or "").strip()

    sub_queries = []

    # Stage 1: the JSON array the prompt asks for. The greedy, DOTALL match
    # tolerates extra prose before or after the array.
    json_array = None
    json_match = re.search(r"\[.*\]", text, re.DOTALL)
    if json_match:
        try:
            json_array = json.loads(json_match.group())
        except ValueError as e:  # json.JSONDecodeError is a ValueError
            logger.warning("Failed to parse fallback response as JSON: %s", e)

    if json_array is not None:
        for item in json_array:
            # Accept both {"sub_query": "..."} objects and bare strings.
            candidate = item.get("sub_query") if isinstance(item, dict) else item
            if isinstance(candidate, str) and candidate.strip():
                sub_queries.append(Document(content=candidate.strip()))
    else:
        # Stage 2: no valid JSON array. This must also run when the reply
        # contains no brackets at all (e.g. a plain numbered list), which
        # raises no exception and would otherwise yield an empty result.
        for line in text.split("\n"):
            line = line.strip()
            # Remove common prefixes like "1.", "- ", "* "
            line = re.sub(r"^[\d]+[.)\]]\s*", "", line)
            line = re.sub(r"^[-*•]\s*", "", line)
            if line and ("?" in line or len(line) > 10):
                sub_queries.append(Document(content=line))

    return sub_queries[:3]  # Limit to 3 sub-questions
class TestDecomposeQuestionPipelineFallback:
    """Check that the no-tools fallback hooks exist on the pipeline class."""

    def test_fallback_prompt_exists(self):
        """The fallback prompt attribute is defined and mentions a JSON array."""
        pipeline_cls = DecomposeQuestionPipeline
        attr_name = "DECOMPOSE_FALLBACK_PROMPT_TEMPLATE"
        assert hasattr(pipeline_cls, attr_name)
        fallback_prompt = pipeline_cls.DECOMPOSE_FALLBACK_PROMPT_TEMPLATE
        assert "JSON array" in fallback_prompt

    def test_run_with_fallback_method_exists(self):
        """The ``_run_with_fallback`` attribute is present on the class."""
        pipeline_cls = DecomposeQuestionPipeline
        assert hasattr(pipeline_cls, "_run_with_fallback")