diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index 3ebb2e62b..1bcd37385 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -58,6 +58,8 @@ # Default completion tokens # Reduce premature termination, especially when streaming structured responses MAX_COMPLETION_TOKENS = 16000 +# Groq API has a lower max_completion_tokens limit +GROQ_MAX_COMPLETION_TOKENS = 8192 def _extract_text_for_instructions(content: Union[str, List, Dict, None]) -> str: @@ -157,6 +159,11 @@ def completion_with_backoff( add_qwen_no_think_tag(formatted_messages) elif is_groq_api(api_base_url): model_kwargs["service_tier"] = "auto" + # Groq API has a lower max_completion_tokens limit + model_kwargs["max_completion_tokens"] = min( + model_kwargs.get("max_completion_tokens", GROQ_MAX_COMPLETION_TOKENS), + GROQ_MAX_COMPLETION_TOKENS, + ) read_timeout = 300 if is_local_api(api_base_url) else 60 if os.getenv("KHOJ_LLM_SEED"): @@ -359,6 +366,11 @@ async def chat_completion_with_backoff( add_qwen_no_think_tag(formatted_messages) elif is_groq_api(api_base_url): model_kwargs["service_tier"] = "auto" + # Groq API has a lower max_completion_tokens limit + model_kwargs["max_completion_tokens"] = min( + model_kwargs.get("max_completion_tokens", GROQ_MAX_COMPLETION_TOKENS), + GROQ_MAX_COMPLETION_TOKENS, + ) read_timeout = 300 if is_local_api(api_base_url) else 60 if os.getenv("KHOJ_LLM_SEED"): diff --git a/tests/test_openai_utils.py b/tests/test_openai_utils.py new file mode 100644 index 000000000..dc073c3c9 --- /dev/null +++ b/tests/test_openai_utils.py @@ -0,0 +1,51 @@ +import pytest + + +@pytest.fixture(autouse=True) +def _no_db(monkeypatch): + """Skip database access for these pure unit tests.""" + pass + + +class TestIsGroqApi: + def test_groq_api_url(self): + from khoj.processor.conversation.openai.utils import is_groq_api + + assert is_groq_api("https://api.groq.com/openai/v1") is True + + def test_groq_api_base_url(self): + from khoj.processor.conversation.openai.utils import is_groq_api + + assert is_groq_api("https://api.groq.com") is True + + def test_non_groq_api_url(self): + from khoj.processor.conversation.openai.utils import is_groq_api + + assert is_groq_api("https://api.openai.com/v1") is False + + def test_none_url(self): + from khoj.processor.conversation.openai.utils import is_groq_api + + assert is_groq_api(None) is False + + def test_empty_url(self): + from khoj.processor.conversation.openai.utils import is_groq_api + + assert is_groq_api("") is False + + +class TestMaxCompletionTokensConstants: + def test_groq_max_tokens_less_than_default(self): + """Groq API max_completion_tokens should be less than the default""" + from khoj.processor.conversation.openai.utils import ( + GROQ_MAX_COMPLETION_TOKENS, + MAX_COMPLETION_TOKENS, + ) + + assert GROQ_MAX_COMPLETION_TOKENS < MAX_COMPLETION_TOKENS + + def test_groq_max_tokens_is_8192(self): + """Groq API max_completion_tokens should be 8192 based on API limits""" + from khoj.processor.conversation.openai.utils import GROQ_MAX_COMPLETION_TOKENS + + assert GROQ_MAX_COMPLETION_TOKENS == 8192