|
17 | 17 | from fai_ai_core.retrieval.filters import QueryFilters |
18 | 18 | from fai_ai_core.retrieval.interface import RetrievalQuery, RetrievalStrategy |
19 | 19 | from fai_ai_core.tools.documentation_search import create_documentation_search_tool |
| 20 | +from fai_ai_core.tools.models import Tool, ToolDefinition, ToolParameter |
20 | 21 | from slack_sdk.web.async_client import AsyncWebClient |
21 | 22 | from sqlalchemy import select |
22 | 23 |
|
|
45 | 46 | sync_slack_context_db_to_tpuf, |
46 | 47 | ) |
47 | 48 |
|
48 | | -SAVE_SLACK_CONTEXT_TOOL = { |
49 | | - "name": "save_slack_context", |
50 | | - "description": "Save a question and ideal response pair to the knowledge base for future reference.", |
51 | | - "input_schema": { |
52 | | - "type": "object", |
53 | | - "properties": { |
54 | | - "question": { |
55 | | - "type": "string", |
56 | | - "description": "The question that was asked or should be asked in the future", |
57 | | - }, |
58 | | - "ideal_response": { |
59 | | - "type": "string", |
60 | | - "description": "The ideal response to give when this question is asked", |
61 | | - }, |
62 | | - }, |
63 | | - "required": ["question", "ideal_response"], |
64 | | - }, |
65 | | -} |
| 49 | + |
| 50 | +@dataclass |
| 51 | +class SlackContextCapture: |
| 52 | + data: dict[str, str] | None = None |
| 53 | + |
| 54 | + |
| 55 | +def create_save_slack_context_tool(capture: SlackContextCapture) -> Tool: |
| 56 | + async def execute(arguments: dict[str, str]) -> str: |
| 57 | + capture.data = { |
| 58 | + "question": arguments.get("question", ""), |
| 59 | + "ideal_response": arguments.get("ideal_response", ""), |
| 60 | + } |
| 61 | + return "Context saved successfully." |
| 62 | + |
| 63 | + return Tool( |
| 64 | + definition=ToolDefinition( |
| 65 | + name="save_slack_context", |
| 66 | + description="Save a question and ideal response pair to the knowledge base for future reference.", |
| 67 | + parameters=[ |
| 68 | + ToolParameter( |
| 69 | + name="question", |
| 70 | + type="string", |
| 71 | + description="The question that was asked or should be asked in the future", |
| 72 | + required=True, |
| 73 | + ), |
| 74 | + ToolParameter( |
| 75 | + name="ideal_response", |
| 76 | + type="string", |
| 77 | + description="The ideal response to give when this question is asked", |
| 78 | + required=True, |
| 79 | + ), |
| 80 | + ], |
| 81 | + ), |
| 82 | + execute=execute, |
| 83 | + max_calls=1, |
| 84 | + ) |
66 | 85 |
|
67 | 86 |
|
68 | 87 | def _create_delimited_role_combinations(roleset: list[str], delimiter: str = "&") -> list[str]: |
@@ -303,20 +322,13 @@ async def _get_slack_index_response( |
303 | 322 | role = MessageRole.ASSISTANT if msg["role"] == "assistant" else MessageRole.USER |
304 | 323 | llm_messages.append(LLMMessage(role=role, content=msg["content"])) |
305 | 324 |
|
306 | | - provider = get_llm_provider(model=model, temperature=0.0, max_tokens=2000) |
307 | | - response = await provider.generate(llm_messages, tools=[SAVE_SLACK_CONTEXT_TOOL]) |
| 325 | + capture = SlackContextCapture() |
| 326 | + save_context_tool = create_save_slack_context_tool(capture) |
308 | 327 |
|
309 | | - context_data = None |
310 | | - if response.tool_calls: |
311 | | - for tool_call in response.tool_calls: |
312 | | - if tool_call.name == "save_slack_context": |
313 | | - context_data = { |
314 | | - "question": tool_call.arguments.get("question", ""), |
315 | | - "ideal_response": tool_call.arguments.get("ideal_response", ""), |
316 | | - } |
317 | | - break |
| 328 | + provider = get_llm_provider(model=model, temperature=0.0, max_tokens=2000) |
| 329 | + response = await provider.generate(llm_messages, tools=[save_context_tool]) |
318 | 330 |
|
319 | | - return response.content, context_data |
| 331 | + return response.content, capture.data |
320 | 332 |
|
321 | 333 |
|
322 | 334 | async def process_message( |
@@ -349,7 +361,7 @@ async def process_message( |
349 | 361 | roles_with_everyone.append("everyone") |
350 | 362 | exploded_roles = _create_delimited_role_combinations(roles_with_everyone) |
351 | 363 | LOGGER.info(f"Using exploded roles for filtering: {exploded_roles}") |
352 | | - filters = QueryFilters(roles=exploded_roles) |
| 364 | + filters = QueryFilters(exploded_roles=exploded_roles) |
353 | 365 |
|
354 | 366 | retriever = get_retriever() |
355 | 367 | retrieval_query = RetrievalQuery( |
|
0 commit comments