Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/khoj/processor/conversation/openai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@
# Default completion tokens
# Reduce premature termination, especially when streaming structured responses
MAX_COMPLETION_TOKENS = 16000
# Groq API has a lower max_completion_tokens limit
GROQ_MAX_COMPLETION_TOKENS = 8192


def _extract_text_for_instructions(content: Union[str, List, Dict, None]) -> str:
Expand Down Expand Up @@ -157,6 +159,11 @@ def completion_with_backoff(
add_qwen_no_think_tag(formatted_messages)
elif is_groq_api(api_base_url):
model_kwargs["service_tier"] = "auto"
# Groq API has a lower max_completion_tokens limit
model_kwargs["max_completion_tokens"] = min(
model_kwargs.get("max_completion_tokens", GROQ_MAX_COMPLETION_TOKENS),
GROQ_MAX_COMPLETION_TOKENS,
)

read_timeout = 300 if is_local_api(api_base_url) else 60
if os.getenv("KHOJ_LLM_SEED"):
Expand Down Expand Up @@ -359,6 +366,11 @@ async def chat_completion_with_backoff(
add_qwen_no_think_tag(formatted_messages)
elif is_groq_api(api_base_url):
model_kwargs["service_tier"] = "auto"
# Groq API has a lower max_completion_tokens limit
model_kwargs["max_completion_tokens"] = min(
model_kwargs.get("max_completion_tokens", GROQ_MAX_COMPLETION_TOKENS),
GROQ_MAX_COMPLETION_TOKENS,
)

read_timeout = 300 if is_local_api(api_base_url) else 60
if os.getenv("KHOJ_LLM_SEED"):
Expand Down
51 changes: 51 additions & 0 deletions tests/test_openai_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import pytest


@pytest.fixture(autouse=True)
def _no_db():
    """No-op autouse fixture kept as a guard point for these pure unit tests.

    The tests below import from ``khoj.processor.conversation.openai.utils``
    but exercise only URL string checks and module constants, so no database
    patching is currently required. The previously requested ``monkeypatch``
    fixture was never used and has been dropped.

    NOTE(review): if database access ever leaks into these tests, add the
    actual patching here rather than in each test.
    """
    pass


class TestIsGroqApi:
    """Unit tests for the ``is_groq_api`` URL predicate.

    The import is performed inside each test (rather than at module top)
    so collection does not fail before the fixture environment is set up.
    """

    def test_groq_api_url(self):
        """A full Groq OpenAI-compatible endpoint URL is recognized."""
        from khoj.processor.conversation.openai.utils import is_groq_api

        result = is_groq_api("https://api.groq.com/openai/v1")
        assert result is True

    def test_groq_api_base_url(self):
        """The bare Groq host (no path) is also recognized."""
        from khoj.processor.conversation.openai.utils import is_groq_api

        result = is_groq_api("https://api.groq.com")
        assert result is True

    def test_non_groq_api_url(self):
        """An OpenAI endpoint must not be classified as Groq."""
        from khoj.processor.conversation.openai.utils import is_groq_api

        result = is_groq_api("https://api.openai.com/v1")
        assert result is False

    def test_none_url(self):
        """``None`` (no api_base_url configured) is handled gracefully."""
        from khoj.processor.conversation.openai.utils import is_groq_api

        result = is_groq_api(None)
        assert result is False

    def test_empty_url(self):
        """An empty string is treated the same as no URL."""
        from khoj.processor.conversation.openai.utils import is_groq_api

        result = is_groq_api("")
        assert result is False


class TestMaxCompletionTokensConstants:
    """Sanity checks on the completion-token limit constants."""

    def test_groq_max_tokens_less_than_default(self):
        """The Groq-specific cap must be strictly below the general default."""
        from khoj.processor.conversation.openai.utils import (
            GROQ_MAX_COMPLETION_TOKENS,
            MAX_COMPLETION_TOKENS,
        )

        assert GROQ_MAX_COMPLETION_TOKENS < MAX_COMPLETION_TOKENS

    def test_groq_max_tokens_is_8192(self):
        """Pin the Groq cap to 8192, matching the Groq API limit."""
        from khoj.processor.conversation.openai.utils import GROQ_MAX_COMPLETION_TOKENS

        expected_limit = 8192
        assert GROQ_MAX_COMPLETION_TOKENS == expected_limit