Skip to content

Commit fb99868

Browse files
authored
fix: Agent uses the first configured vector_db_id when documents are provided (#1276)
# What does this PR do? The agent API allows querying multiple DBs using the `vector_db_ids` argument of the `rag` tool: ```py toolgroups=[ { "name": "builtin::rag", "args": {"vector_db_ids": [vector_db_id]}, } ], ``` This means that multiple DBs can be used to compose an aggregated context by executing the query on each of them. When documents are passed to the next agent turn, there is no explicit way to configure the vector DB where the embeddings will be ingested. In such cases, we can assume that: - if any `vector_db_ids` are given, we use the first one (it probably makes sense to assume that it is the only one in the list; otherwise, we should loop over all the given DBs to achieve a consistent ingestion) - if no `vector_db_ids` are given, we can use the current logic to generate a default DB using the default provider. If multiple providers are defined, the API will fail as expected: the user has to provide details on where to ingest the documents. (Closes #1270) ## Test Plan The issue description details how to replicate the problem. [//]: # (## Documentation) --------- Signed-off-by: Daniele Martinoli <dmartino@redhat.com>
1 parent 78962be commit fb99868

File tree

3 files changed

+63
-51
lines changed

3 files changed

+63
-51
lines changed

docs/source/building_applications/rag.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ response = agent.create_turn(
122122
],
123123
documents=[
124124
{
125-
"content": "https://raw.githubusercontent.com/example/doc.rst",
125+
"content": "https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/memory_optimizations.rst",
126126
"mime_type": "text/plain",
127127
}
128128
],

llama_stack/distribution/routers/routing_tables.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -309,13 +309,14 @@ async def register_vector_db(
309309
if provider_vector_db_id is None:
310310
provider_vector_db_id = vector_db_id
311311
if provider_id is None:
312-
# If provider_id not specified, use the only provider if it supports this shield type
313-
if len(self.impls_by_provider_id) == 1:
312+
if len(self.impls_by_provider_id) > 0:
314313
provider_id = list(self.impls_by_provider_id.keys())[0]
314+
if len(self.impls_by_provider_id) > 1:
315+
logger.warning(
316+
f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
317+
)
315318
else:
316-
raise ValueError(
317-
"No provider specified and multiple providers available. Please specify a provider_id."
318-
)
319+
raise ValueError("No provider available. Please configure a vector_io provider.")
319320
model = await self.get_object_by_identifier("model", embedding_model)
320321
if model is None:
321322
raise ValueError(f"Model {embedding_model} not found")

llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py

Lines changed: 56 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -16,31 +16,35 @@
1616
AgentTurnResponseTurnCompletePayload,
1717
StepType,
1818
)
19-
from llama_stack.apis.common.content_types import URL
19+
from llama_stack.apis.common.content_types import URL, TextDelta
2020
from llama_stack.apis.inference import (
2121
ChatCompletionResponse,
2222
ChatCompletionResponseEvent,
23+
ChatCompletionResponseEventType,
2324
ChatCompletionResponseStreamChunk,
2425
CompletionMessage,
2526
LogProbConfig,
2627
Message,
2728
ResponseFormat,
2829
SamplingParams,
2930
ToolChoice,
31+
ToolConfig,
3032
ToolDefinition,
3133
ToolPromptFormat,
3234
UserMessage,
3335
)
3436
from llama_stack.apis.safety import RunShieldResponse
3537
from llama_stack.apis.tools import (
38+
ListToolGroupsResponse,
39+
ListToolsResponse,
3640
Tool,
3741
ToolDef,
3842
ToolGroup,
3943
ToolHost,
4044
ToolInvocationResult,
4145
)
4246
from llama_stack.apis.vector_io import QueryChunksResponse
43-
from llama_stack.models.llama.datatypes import BuiltinTool
47+
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason
4448
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
4549
MEMORY_QUERY_TOOL,
4650
)
@@ -54,36 +58,37 @@
5458
class MockInferenceAPI:
5559
async def chat_completion(
5660
self,
57-
model: str,
61+
model_id: str,
5862
messages: List[Message],
5963
sampling_params: Optional[SamplingParams] = SamplingParams(),
60-
response_format: Optional[ResponseFormat] = None,
6164
tools: Optional[List[ToolDefinition]] = None,
6265
tool_choice: Optional[ToolChoice] = None,
6366
tool_prompt_format: Optional[ToolPromptFormat] = None,
67+
response_format: Optional[ResponseFormat] = None,
6468
stream: Optional[bool] = False,
6569
logprobs: Optional[LogProbConfig] = None,
70+
tool_config: Optional[ToolConfig] = None,
6671
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
6772
async def stream_response():
6873
yield ChatCompletionResponseStreamChunk(
6974
event=ChatCompletionResponseEvent(
70-
event_type="start",
71-
delta="",
75+
event_type=ChatCompletionResponseEventType.start,
76+
delta=TextDelta(text=""),
7277
)
7378
)
7479

7580
yield ChatCompletionResponseStreamChunk(
7681
event=ChatCompletionResponseEvent(
77-
event_type="progress",
78-
delta="AI is a fascinating field...",
82+
event_type=ChatCompletionResponseEventType.progress,
83+
delta=TextDelta(text="AI is a fascinating field..."),
7984
)
8085
)
8186

8287
yield ChatCompletionResponseStreamChunk(
8388
event=ChatCompletionResponseEvent(
84-
event_type="complete",
85-
delta="",
86-
stop_reason="end_of_turn",
89+
event_type=ChatCompletionResponseEventType.complete,
90+
delta=TextDelta(text=""),
91+
stop_reason=StopReason.end_of_turn,
8792
)
8893
)
8994

@@ -133,35 +138,39 @@ async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
133138
provider_resource_id=toolgroup_id,
134139
)
135140

136-
async def list_tool_groups(self) -> List[ToolGroup]:
137-
return []
138-
139-
async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
140-
if tool_group_id == MEMORY_TOOLGROUP:
141-
return [
142-
Tool(
143-
identifier=MEMORY_QUERY_TOOL,
144-
provider_resource_id=MEMORY_QUERY_TOOL,
145-
toolgroup_id=MEMORY_TOOLGROUP,
146-
tool_host=ToolHost.client,
147-
description="Mock tool",
148-
provider_id="builtin::rag",
149-
parameters=[],
150-
)
151-
]
152-
if tool_group_id == CODE_INTERPRETER_TOOLGROUP:
153-
return [
154-
Tool(
155-
identifier="code_interpreter",
156-
provider_resource_id="code_interpreter",
157-
toolgroup_id=CODE_INTERPRETER_TOOLGROUP,
158-
tool_host=ToolHost.client,
159-
description="Mock tool",
160-
provider_id="builtin::code_interpreter",
161-
parameters=[],
162-
)
163-
]
164-
return []
141+
async def list_tool_groups(self) -> ListToolGroupsResponse:
142+
return ListToolGroupsResponse(data=[])
143+
144+
async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:
145+
if toolgroup_id == MEMORY_TOOLGROUP:
146+
return ListToolsResponse(
147+
data=[
148+
Tool(
149+
identifier=MEMORY_QUERY_TOOL,
150+
provider_resource_id=MEMORY_QUERY_TOOL,
151+
toolgroup_id=MEMORY_TOOLGROUP,
152+
tool_host=ToolHost.client,
153+
description="Mock tool",
154+
provider_id="builtin::rag",
155+
parameters=[],
156+
)
157+
]
158+
)
159+
if toolgroup_id == CODE_INTERPRETER_TOOLGROUP:
160+
return ListToolsResponse(
161+
data=[
162+
Tool(
163+
identifier="code_interpreter",
164+
provider_resource_id="code_interpreter",
165+
toolgroup_id=CODE_INTERPRETER_TOOLGROUP,
166+
tool_host=ToolHost.client,
167+
description="Mock tool",
168+
provider_id="builtin::code_interpreter",
169+
parameters=[],
170+
)
171+
]
172+
)
173+
return ListToolsResponse(data=[])
165174

166175
async def get_tool(self, tool_name: str) -> Tool:
167176
return Tool(
@@ -174,7 +183,7 @@ async def get_tool(self, tool_name: str) -> Tool:
174183
parameters=[],
175184
)
176185

177-
async def unregister_tool_group(self, tool_group_id: str) -> None:
186+
async def unregister_tool_group(self, toolgroup_id: str) -> None:
178187
pass
179188

180189

@@ -382,10 +391,11 @@ async def test_chat_agent_tools(get_agents_impl, toolgroups, expected_memory, ex
382391
chat_agent = await impl.get_agent(response.agent_id)
383392

384393
tool_defs, _ = await chat_agent._get_tool_defs()
394+
tool_defs_names = [t.tool_name for t in tool_defs]
385395
if expected_memory:
386-
assert MEMORY_QUERY_TOOL in tool_defs
396+
assert MEMORY_QUERY_TOOL in tool_defs_names
387397
if expected_code_interpreter:
388-
assert BuiltinTool.code_interpreter in tool_defs
398+
assert BuiltinTool.code_interpreter in tool_defs_names
389399
if expected_memory and expected_code_interpreter:
390400
# override the tools for turn
391401
new_tool_defs, _ = await chat_agent._get_tool_defs(
@@ -396,5 +406,6 @@ async def test_chat_agent_tools(get_agents_impl, toolgroups, expected_memory, ex
396406
)
397407
]
398408
)
399-
assert MEMORY_QUERY_TOOL in new_tool_defs
400-
assert BuiltinTool.code_interpreter not in new_tool_defs
409+
new_tool_defs_names = [t.tool_name for t in new_tool_defs]
410+
assert MEMORY_QUERY_TOOL in new_tool_defs_names
411+
assert BuiltinTool.code_interpreter not in new_tool_defs_names

0 commit comments

Comments
 (0)