Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion core/quivr_core/llm/llm_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,8 @@ def from_config(cls, config: LLMEndpointConfig = LLMEndpointConfig()):
azure_endpoint=azure_endpoint,
max_tokens=config.max_output_tokens,
temperature=config.temperature,
timeout=30.0,
max_retries=3,
)
elif config.supplier == DefaultModelSuppliers.ANTHROPIC:
assert config.llm_api_key, "Can't load model config"
Expand All @@ -247,7 +249,8 @@ def from_config(cls, config: LLMEndpointConfig = LLMEndpointConfig()):
base_url=config.llm_base_url,
max_tokens_to_sample=config.max_output_tokens,
temperature=config.temperature,
timeout=None,
timeout=30.0,
max_retries=3,
stop=None,
)
elif config.supplier == DefaultModelSuppliers.OPENAI:
Expand All @@ -261,6 +264,8 @@ def from_config(cls, config: LLMEndpointConfig = LLMEndpointConfig()):
temperature=config.temperature
if not config.model.startswith("o")
else None,
timeout=30.0,
max_retries=3,
)
elif config.supplier == DefaultModelSuppliers.MISTRAL:
_llm = ChatMistralAI(
Expand All @@ -270,6 +275,8 @@ def from_config(cls, config: LLMEndpointConfig = LLMEndpointConfig()):
else None,
base_url=config.llm_base_url,
temperature=config.temperature,
timeout=30,
max_retries=3,
)
elif config.supplier == DefaultModelSuppliers.GEMINI:
_llm = ChatGoogleGenerativeAI(
Expand All @@ -280,6 +287,8 @@ def from_config(cls, config: LLMEndpointConfig = LLMEndpointConfig()):
base_url=config.llm_base_url,
max_tokens=config.max_output_tokens,
temperature=config.temperature,
timeout=30,
max_retries=3,
)
elif config.supplier == DefaultModelSuppliers.GROQ:
_llm = ChatGroq(
Expand All @@ -290,6 +299,8 @@ def from_config(cls, config: LLMEndpointConfig = LLMEndpointConfig()):
base_url=config.llm_base_url,
max_tokens=config.max_output_tokens,
temperature=config.temperature,
timeout=30,
max_retries=3,
)

else:
Expand All @@ -301,6 +312,8 @@ def from_config(cls, config: LLMEndpointConfig = LLMEndpointConfig()):
base_url=config.llm_base_url,
max_completion_tokens=config.max_output_tokens,
temperature=config.temperature,
timeout=30.0,
max_retries=3,
)
instance = cls(llm=_llm, llm_config=config)
cls._cache[hashed_config] = instance
Expand Down
6 changes: 3 additions & 3 deletions core/quivr_core/rag/entities/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,9 +360,9 @@ def set_api_key(self, force_reset: bool = False):
self.llm_api_key = os.getenv(self.env_variable_name)

if not self.llm_api_key:
logger.warning(f"The API key for supplier '{self.supplier}' is not set. ")
logger.warning(
f"Please set the environment variable: '{self.env_variable_name}'. "
raise ValueError(
f"The API key for supplier '{self.supplier}' is not set. "
f"Please set the environment variable: '{self.env_variable_name}'."
)

def set_llm_model_config(self):
Expand Down
6 changes: 3 additions & 3 deletions core/quivr_core/rag/quivr_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,21 +158,21 @@ def build_chain(self, files: str):

return loaded_memory | standalone_question | retrieved_documents | answer

def answer(
async def answer(
self,
question: str,
history: ChatHistory,
list_files: list[QuivrKnowledge],
metadata: dict[str, str] = {},
) -> ParsedRAGResponse:
"""
Answers a question using the QuivrQA RAG synchronously.
Answers a question using the QuivrQA RAG asynchronously.
"""
concat_list_files = format_file_list(
list_files, self.retrieval_config.max_files
)
conversational_qa_chain = self.build_chain(concat_list_files)
raw_llm_response = conversational_qa_chain.invoke(
raw_llm_response = await conversational_qa_chain.ainvoke(
{
"question": question,
"chat_history": history,
Expand Down
124 changes: 102 additions & 22 deletions core/quivr_core/rag/quivr_rag_langgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import datetime
import logging
from collections import OrderedDict

from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from typing import (
Annotated,
Any,
Expand All @@ -22,6 +24,7 @@
from langchain_community.document_compressors import JinaRerank
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_core.messages.ai import AIMessageChunk
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
Expand Down Expand Up @@ -273,9 +276,32 @@ def __init__(
self.retrieval_config = retrieval_config
self.vector_store = vector_store
self.llm_endpoint = llm
self._lightweight_llm_endpoint: LLMEndpoint | None = None
self._structured_output_cache: OrderedDict[int, Any] = OrderedDict()
self._structured_output_cache_max_size: int = 128

self.graph = None

@property
def lightweight_llm(self) -> BaseChatModel:
    """Lazily-built small model (gpt-4o-mini) for cheap routing/rephrasing steps.

    The endpoint is constructed on first access and memoized on the
    instance; if construction fails for any reason, the main LLM is
    returned instead so callers never need to handle the error.
    """
    endpoint = self._lightweight_llm_endpoint
    if endpoint is not None:
        return endpoint._llm
    try:
        # Imported locally — presumably to avoid a module-level import
        # cycle with the config module (TODO confirm).
        from quivr_core.rag.entities.config import (
            DefaultModelSuppliers,
            LLMEndpointConfig,
        )

        cfg = LLMEndpointConfig(
            supplier=DefaultModelSuppliers.OPENAI,
            model="gpt-4o-mini",
            temperature=0.1,
            max_output_tokens=4096,
        )
        self._lightweight_llm_endpoint = LLMEndpoint.from_config(cfg)
        return self._lightweight_llm_endpoint._llm
    except Exception:
        # Best-effort fallback; creation is re-attempted on the next access.
        logger.warning("Could not create lightweight LLM, falling back to main LLM")
        return self.llm_endpoint._llm

def get_reranker(self, **kwargs):
# Extract the reranker configuration from self
config = self.retrieval_config.reranker_config
Expand Down Expand Up @@ -313,7 +339,7 @@ def get_retriever(self, **kwargs):

return retriever

def routing(self, state: AgentState) -> List[Send]:
async def routing(self, state: AgentState) -> List[Send]:
"""
The routing function for the RAG model.

Expand All @@ -334,13 +360,13 @@ def routing(self, state: AgentState) -> List[Send]:
structured_llm = self.llm_endpoint._llm.with_structured_output(
SplittedInput, method="json_schema"
)
response = structured_llm.invoke(msg)
response = await structured_llm.ainvoke(msg)

except openai.BadRequestError:
structured_llm = self.llm_endpoint._llm.with_structured_output(
SplittedInput
)
response = structured_llm.invoke(msg)
response = await structured_llm.ainvoke(msg)

send_list: List[Send] = []

Expand Down Expand Up @@ -500,20 +526,25 @@ async def rewrite(self, state: AgentState) -> AgentState:
task=tasks(task_id).definition,
)

model = self.llm_endpoint._llm
# Asynchronously invoke the model for each question
model = self.lightweight_llm
# Asynchronously invoke the lightweight model for each question
async_jobs.append((model.ainvoke(msg), task_id))

# Gather all the responses asynchronously
responses = (
await asyncio.gather(*(jobs[0] for jobs in async_jobs))
await asyncio.gather(
*(jobs[0] for jobs in async_jobs), return_exceptions=True
)
if async_jobs
else []
)
task_ids = [jobs[1] for jobs in async_jobs] if async_jobs else []

# Replace each question with its condensed version
for response, task_id in zip(responses, task_ids, strict=False):
if isinstance(response, Exception):
logger.error(f"Task rephrasing failed for {task_id}: {response}")
continue
tasks.set_definition(task_id, response.content)

return {**state, "tasks": tasks}
Expand Down Expand Up @@ -559,17 +590,22 @@ async def tool_routing(self, state: AgentState):

msg = custom_prompts[TemplatePromptName.TOOL_ROUTING_PROMPT].format(**input)
async_jobs.append(
(self.ainvoke_structured_output(msg, TasksCompletion), task_id)
(self.ainvoke_structured_output(msg, TasksCompletion, use_lightweight=True), task_id)
)

responses: List[TasksCompletion] = (
await asyncio.gather(*(jobs[0] for jobs in async_jobs))
await asyncio.gather(
*(jobs[0] for jobs in async_jobs), return_exceptions=True
)
if async_jobs
else []
)
task_ids = [jobs[1] for jobs in async_jobs] if async_jobs else []

for response, task_id in zip(responses, task_ids, strict=False):
if isinstance(response, Exception):
logger.error(f"Tool routing failed for {task_id}: {response}")
continue
tasks.set_completion(task_id, response.is_task_completable)
if not response.is_task_completable and response.tool:
tasks.set_tool(task_id, response.tool)
Expand Down Expand Up @@ -920,7 +956,7 @@ def bind_tools_to_llm(self, node_name: str):
return self.llm_endpoint._llm.bind_tools(tools, tool_choice="any")
return self.llm_endpoint._llm

def generate_zendesk_rag(self, state: AgentState) -> AgentState:
async def generate_zendesk_rag(self, state: AgentState) -> AgentState:
tasks = state["tasks"]
docs: List[Document] = tasks.docs if tasks else []
messages = state["messages"]
Expand Down Expand Up @@ -949,23 +985,23 @@ def generate_zendesk_rag(self, state: AgentState) -> AgentState:
msg = prompt_template.format_prompt(**inputs)
llm = self.bind_tools_to_llm(self.generate_zendesk_rag.__name__)

response = llm.invoke(msg)
response = await llm.ainvoke(msg)

return {**state, "messages": [response]}

async def generate_rag(self, state: AgentState) -> AgentState:
    """Generate the final RAG answer for the current agent state.

    Builds the RAG prompt inputs from the documents retrieved for the
    state's tasks, trims the context to fit the model window, then
    asynchronously invokes the tool-bound LLM.

    Args:
        state: Current agent state; ``state["tasks"].docs`` holds the
            retrieved documents (an empty list when there are no tasks).

    Returns:
        The state with the model's response appended to ``messages``.
    """
    # NOTE(review): the pasted diff contained both the old sync header
    # (`def` + `llm.invoke`) and the new async lines; only the async
    # post-change version is kept here.
    tasks = state["tasks"]
    docs = tasks.docs if tasks else []
    inputs = self._build_rag_prompt_inputs(state, docs)
    prompt = custom_prompts[TemplatePromptName.RAG_ANSWER_PROMPT]
    # Shrink history/context first so the formatted prompt fits the model.
    state, inputs = self.reduce_rag_context(state, inputs, prompt)
    msg = prompt.format(**inputs)
    llm = self.bind_tools_to_llm(self.generate_rag.__name__)
    response = await llm.ainvoke(msg)

    return {**state, "messages": [response]}

def generate_chat_llm(self, state: AgentState) -> AgentState:
async def generate_chat_llm(self, state: AgentState) -> AgentState:
"""
Generate answer

Expand Down Expand Up @@ -1014,10 +1050,10 @@ def generate_chat_llm(self, state: AgentState) -> AgentState:
]
)
# Run
chat_llm_prompt = CHAT_LLM_PROMPT.invoke(
chat_llm_prompt = await CHAT_LLM_PROMPT.ainvoke(
{"chat_history": final_inputs["chat_history"]}
)
response = llm.invoke(chat_llm_prompt)
response = await llm.ainvoke(chat_llm_prompt)
return {**state, "messages": [response]}

def build_chain(self):
Expand Down Expand Up @@ -1172,29 +1208,73 @@ def _extract_node_name(self, event: StreamEvent) -> str:
return node.name
return ""

def _get_cache_key(self, prompt: str, output_class: Type[BaseModel]) -> int:
    """Derive the structured-output cache key from the prompt text and the
    requested output class's name."""
    key_material = (prompt, output_class.__name__)
    return hash(key_material)

def _cache_get(self, key: int) -> Any | None:
    """Return the cached value for *key*, or None on a miss.

    A hit also marks the entry most-recently-used so LRU eviction in
    ``_cache_put`` keeps hot entries alive.
    """
    cache = self._structured_output_cache
    try:
        value = cache[key]
    except KeyError:
        return None
    cache.move_to_end(key)
    return value

def _cache_put(self, key: int, value: Any) -> None:
    """Insert or refresh *key* and enforce the LRU size cap."""
    cache = self._structured_output_cache
    cache[key] = value
    cache.move_to_end(key)
    # Evict least-recently-used entries until we are back under the cap.
    while len(cache) > self._structured_output_cache_max_size:
        cache.popitem(last=False)

@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=1, max=10),
    retry=retry_if_exception_type(
        (
            openai.APITimeoutError,
            openai.APIConnectionError,
            openai.RateLimitError,
            openai.InternalServerError,
        )
    ),
    reraise=True,
)
async def ainvoke_structured_output(
    self, prompt: str, output_class: Type[BaseModel], use_lightweight: bool = False
) -> Any:
    """Asynchronously get a structured (Pydantic) response for *prompt*.

    Results are cached per (prompt, output_class) pair; transient OpenAI
    errors are retried with exponential backoff via the decorator.

    Args:
        prompt: Fully formatted prompt string.
        output_class: Pydantic model the response must conform to.
        use_lightweight: When True, route through the cheaper
            ``lightweight_llm`` instead of the main endpoint.

    Returns:
        An instance of *output_class* (possibly served from the cache).
    """
    # NOTE(review): the pasted diff held both pre- and post-change lines
    # (duplicate `structured_llm = …` assignments and an early `return`
    # that made the cache write unreachable); only the post-change
    # version is kept here.
    cache_key = self._get_cache_key(prompt, output_class)
    cached = self._cache_get(cache_key)
    if cached is not None:
        return cached

    llm = self.lightweight_llm if use_lightweight else self.llm_endpoint._llm
    try:
        structured_llm = llm.with_structured_output(
            output_class, method="json_schema"
        )
        result = await structured_llm.ainvoke(prompt)
    except openai.BadRequestError:
        # Some models reject the json_schema method; fall back to the
        # provider's default structured-output mechanism.
        structured_llm = llm.with_structured_output(output_class)
        result = await structured_llm.ainvoke(prompt)

    self._cache_put(cache_key, result)
    return result

@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=1, max=10),
    retry=retry_if_exception_type(
        (
            openai.APITimeoutError,
            openai.APIConnectionError,
            openai.RateLimitError,
            openai.InternalServerError,
        )
    ),
    reraise=True,
)
def invoke_structured_output(
    self, prompt: str, output_class: Type[BaseModel]
) -> Any:
    """Synchronously get a structured (Pydantic) response for *prompt*.

    Sync counterpart of ``ainvoke_structured_output``: same caching on
    (prompt, output_class), same retry-on-transient-error policy, same
    json_schema → default-method fallback on ``openai.BadRequestError``.

    Returns:
        An instance of *output_class* (possibly served from the cache).
    """
    # NOTE(review): merged -/+ diff lines (duplicate assignment and an
    # early `return` bypassing the cache write) removed; this is the
    # post-change version.
    cache_key = self._get_cache_key(prompt, output_class)
    cached = self._cache_get(cache_key)
    if cached is not None:
        return cached

    try:
        structured_llm = self.llm_endpoint._llm.with_structured_output(
            output_class, method="json_schema"
        )
        result = structured_llm.invoke(prompt)
    except openai.BadRequestError:
        # Fallback when the model rejects the json_schema method.
        structured_llm = self.llm_endpoint._llm.with_structured_output(output_class)
        result = structured_llm.invoke(prompt)

    self._cache_put(cache_key, result)
    return result

def _build_rag_prompt_inputs(
self, state: AgentState, docs: List[Document] | None
Expand Down