Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import json
import logging
import re

from ktem.llms.manager import llms
from ktem.reasoning.prompt_optimization.rewrite_question import RewriteQuestionPipeline
Expand All @@ -10,6 +12,16 @@
logger = logging.getLogger(__name__)


def _is_tool_not_supported_error(error: Exception) -> bool:
"""Check if the error indicates the model does not support tools."""
error_msg = str(error).lower()
return (
"does not support tools" in error_msg
or "tool use is not supported" in error_msg
or "tools are not supported" in error_msg
)


class SubQuery(BaseModel):
"""Search over a database of insurance rulebooks or financial reports"""

Expand Down Expand Up @@ -40,6 +52,20 @@ class DecomposeQuestionPipeline(RewriteQuestionPipeline):
"If there are acronyms or words you are not familiar with, "
"do not try to rephrase them."
)
# Fallback prompt for models that don't support tools
DECOMPOSE_FALLBACK_PROMPT_TEMPLATE = (
"You are an expert at converting user complex questions into sub questions. "
"Given a user question, break it down into the most specific sub"
" questions you can (at most 3) "
"which will help you answer the original question. "
"Each sub question should be about a single concept/fact/idea. "
"If there are acronyms or words you are not familiar with, "
"do not try to rephrase them.\n\n"
"Output your sub-questions as a JSON array of objects, where each object has "
'a "sub_query" field. Example:\n'
'[{"sub_query": "What is X?"}, {"sub_query": "How does Y work?"}]\n\n'
"Only output the JSON array, no other text."
)
prompt_template: str = DECOMPOSE_SYSTEM_PROMPT_TEMPLATE

def create_prompt(self, question):
Expand All @@ -62,9 +88,59 @@ def create_prompt(self, question):

return messages, llm_kwargs

def _run_with_fallback(self, question: str) -> list:
"""Fallback method for models that don't support tools.

Uses plain text prompting to decompose the question.
"""
messages = [
SystemMessage(content=self.DECOMPOSE_FALLBACK_PROMPT_TEMPLATE),
HumanMessage(content=question),
]

result = self.llm(messages)
text = result.text.strip()

# Try to parse the response as JSON
sub_queries = []
try:
# Try to extract JSON array from the response
# Handle cases where model adds extra text around the JSON
json_match = re.search(r"\[.*\]", text, re.DOTALL)
if json_match:
parsed = json.loads(json_match.group())
for item in parsed:
if isinstance(item, dict) and "sub_query" in item:
sub_queries.append(Document(content=item["sub_query"]))
except (json.JSONDecodeError, ValueError) as e:
logger.warning(f"Failed to parse fallback response as JSON: {e}")
# If JSON parsing fails, try to extract questions from the text
# by looking for numbered items or question marks
lines = text.split("\n")
for line in lines:
line = line.strip()
# Remove common prefixes like "1.", "- ", "* "
line = re.sub(r"^[\d]+[.)\]]\s*", "", line)
line = re.sub(r"^[-*•]\s*", "", line)
if line and ("?" in line or len(line) > 10):
sub_queries.append(Document(content=line))

return sub_queries[:3] # Limit to 3 sub-questions

def run(self, question: str) -> list: # type: ignore
messages, llm_kwargs = self.create_prompt(question)
result = self.llm(messages, **llm_kwargs)

try:
result = self.llm(messages, **llm_kwargs)
except Exception as e:
if _is_tool_not_supported_error(e):
logger.warning(
f"Model does not support tools, falling back to text-based "
f"decomposition: {e}"
)
return self._run_with_fallback(question)
raise

tool_calls = result.additional_kwargs.get("tool_calls", None)
sub_queries = []
if tool_calls:
Expand Down
59 changes: 59 additions & 0 deletions libs/ktem/ktem_tests/test_decompose_question.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""Tests for decompose_question functionality."""

from ktem.reasoning.prompt_optimization.decompose_question import (
DecomposeQuestionPipeline,
_is_tool_not_supported_error,
)


class TestIsToolNotSupportedError:
"""Test the _is_tool_not_supported_error helper function."""

def test_detects_ollama_error(self):
"""Test detection of Ollama's 'does not support tools' error."""
error = Exception(
"Error code: 400 - {'error': {'message': "
"'registry.ollama.ai/library/deepseek-r1:7b does not support tools'}}"
)
assert _is_tool_not_supported_error(error) is True

def test_detects_tool_use_not_supported(self):
"""Test detection of 'tool use is not supported' error."""
error = Exception("Tool use is not supported by this model")
assert _is_tool_not_supported_error(error) is True

def test_detects_tools_are_not_supported(self):
"""Test detection of 'tools are not supported' error."""
error = Exception("Tools are not supported for this model type")
assert _is_tool_not_supported_error(error) is True

def test_case_insensitive(self):
"""Test that detection is case insensitive."""
error = Exception("DOES NOT SUPPORT TOOLS")
assert _is_tool_not_supported_error(error) is True

def test_other_errors_not_detected(self):
"""Test that unrelated errors are not detected as tool support issues."""
error = Exception("Connection timeout")
assert _is_tool_not_supported_error(error) is False

error = Exception("Invalid API key")
assert _is_tool_not_supported_error(error) is False

error = Exception("Rate limit exceeded")
assert _is_tool_not_supported_error(error) is False


class TestDecomposeQuestionPipelineFallback:
"""Test the fallback behavior for models without tool support."""

def test_fallback_prompt_exists(self):
"""Test that fallback prompt template is defined."""
assert hasattr(DecomposeQuestionPipeline, "DECOMPOSE_FALLBACK_PROMPT_TEMPLATE")
assert (
"JSON array" in DecomposeQuestionPipeline.DECOMPOSE_FALLBACK_PROMPT_TEMPLATE
)

def test_run_with_fallback_method_exists(self):
"""Test that _run_with_fallback method is defined."""
assert hasattr(DecomposeQuestionPipeline, "_run_with_fallback")