Skip to content

Commit f4dbcab

Browse files
tsbhanguclaude
andauthored
fix(ask-fern): fix type errors not originally fixed in fai integration with fai_ai_core (#5799)
Co-authored-by: Claude <[email protected]>
1 parent 5732aca commit f4dbcab

File tree

3 files changed

+47
-33
lines changed

3 files changed

+47
-33
lines changed

servers/fai/mypy.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
python_executable = .venv/bin/python
33
plugins = pydantic.mypy
44
files = src
5-
mypy_path = .
5+
mypy_path = .:../python-libs/fai_ai_core
66
explicit_package_bases = True
77
ignore_missing_imports = True
88
no_site_packages = True

servers/fai/src/fai/routes/chat.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ async def post_chat_completion(
7979
query=sq,
8080
domain=domain,
8181
strategy=RetrievalStrategy.HYBRID,
82+
top_k=TOP_K,
8283
)
8384
for sq in sub_queries
8485
]
@@ -90,8 +91,9 @@ async def post_chat_completion(
9091
query=query_content,
9192
domain=domain,
9293
strategy=RetrievalStrategy.HYBRID,
94+
top_k=TOP_K,
9395
)
94-
result = await retriever.retrieve(retrieval_query, top_k=TOP_K)
96+
result = await retriever.retrieve(retrieval_query)
9597
retrieved_documents = result.documents
9698

9799
model: ModelId = DEFAULT_MODEL

servers/fai/src/fai/utils/slack/message_handler.py

Lines changed: 43 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from fai_ai_core.retrieval.filters import QueryFilters
1818
from fai_ai_core.retrieval.interface import RetrievalQuery, RetrievalStrategy
1919
from fai_ai_core.tools.documentation_search import create_documentation_search_tool
20+
from fai_ai_core.tools.models import Tool, ToolDefinition, ToolParameter
2021
from slack_sdk.web.async_client import AsyncWebClient
2122
from sqlalchemy import select
2223

@@ -45,24 +46,42 @@
4546
sync_slack_context_db_to_tpuf,
4647
)
4748

48-
SAVE_SLACK_CONTEXT_TOOL = {
49-
"name": "save_slack_context",
50-
"description": "Save a question and ideal response pair to the knowledge base for future reference.",
51-
"input_schema": {
52-
"type": "object",
53-
"properties": {
54-
"question": {
55-
"type": "string",
56-
"description": "The question that was asked or should be asked in the future",
57-
},
58-
"ideal_response": {
59-
"type": "string",
60-
"description": "The ideal response to give when this question is asked",
61-
},
62-
},
63-
"required": ["question", "ideal_response"],
64-
},
65-
}
49+
50+
@dataclass
51+
class SlackContextCapture:
52+
data: dict[str, str] | None = None
53+
54+
55+
def create_save_slack_context_tool(capture: SlackContextCapture) -> Tool:
56+
async def execute(arguments: dict[str, str]) -> str:
57+
capture.data = {
58+
"question": arguments.get("question", ""),
59+
"ideal_response": arguments.get("ideal_response", ""),
60+
}
61+
return "Context saved successfully."
62+
63+
return Tool(
64+
definition=ToolDefinition(
65+
name="save_slack_context",
66+
description="Save a question and ideal response pair to the knowledge base for future reference.",
67+
parameters=[
68+
ToolParameter(
69+
name="question",
70+
type="string",
71+
description="The question that was asked or should be asked in the future",
72+
required=True,
73+
),
74+
ToolParameter(
75+
name="ideal_response",
76+
type="string",
77+
description="The ideal response to give when this question is asked",
78+
required=True,
79+
),
80+
],
81+
),
82+
execute=execute,
83+
max_calls=1,
84+
)
6685

6786

6887
def _create_delimited_role_combinations(roleset: list[str], delimiter: str = "&") -> list[str]:
@@ -303,20 +322,13 @@ async def _get_slack_index_response(
303322
role = MessageRole.ASSISTANT if msg["role"] == "assistant" else MessageRole.USER
304323
llm_messages.append(LLMMessage(role=role, content=msg["content"]))
305324

306-
provider = get_llm_provider(model=model, temperature=0.0, max_tokens=2000)
307-
response = await provider.generate(llm_messages, tools=[SAVE_SLACK_CONTEXT_TOOL])
325+
capture = SlackContextCapture()
326+
save_context_tool = create_save_slack_context_tool(capture)
308327

309-
context_data = None
310-
if response.tool_calls:
311-
for tool_call in response.tool_calls:
312-
if tool_call.name == "save_slack_context":
313-
context_data = {
314-
"question": tool_call.arguments.get("question", ""),
315-
"ideal_response": tool_call.arguments.get("ideal_response", ""),
316-
}
317-
break
328+
provider = get_llm_provider(model=model, temperature=0.0, max_tokens=2000)
329+
response = await provider.generate(llm_messages, tools=[save_context_tool])
318330

319-
return response.content, context_data
331+
return response.content, capture.data
320332

321333

322334
async def process_message(
@@ -349,7 +361,7 @@ async def process_message(
349361
roles_with_everyone.append("everyone")
350362
exploded_roles = _create_delimited_role_combinations(roles_with_everyone)
351363
LOGGER.info(f"Using exploded roles for filtering: {exploded_roles}")
352-
filters = QueryFilters(roles=exploded_roles)
364+
filters = QueryFilters(exploded_roles=exploded_roles)
353365

354366
retriever = get_retriever()
355367
retrieval_query = RetrievalQuery(

0 commit comments

Comments
 (0)