Skip to content

Commit 2e3c971

Browse files
authored
backend: Filter web search domains for Tavily (#713)
* wip * Add domain filtering for Tavily search * pr review fixes
1 parent f1ca125 commit 2e3c971

File tree

4 files changed

+112
-46
lines changed

4 files changed

+112
-46
lines changed

src/backend/config/settings.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ class CompassSettings(BaseSettings, BaseModel):
146146
)
147147

148148

149-
class WebSearchSettings(BaseSettings, BaseModel):
149+
class TavilySearchSettings(BaseSettings, BaseModel):
150150
model_config = SETTINGS_CONFIG
151151
api_key: Optional[str] = Field(
152152
default=None, validation_alias=AliasChoices("TAVILY_API_KEY", "api_key")
@@ -185,7 +185,7 @@ class ToolSettings(BaseSettings, BaseModel):
185185
python_interpreter: Optional[PythonToolSettings] = Field(
186186
default=PythonToolSettings()
187187
)
188-
web_search: Optional[WebSearchSettings] = Field(default=WebSearchSettings())
188+
tavily: Optional[TavilySearchSettings] = Field(default=TavilySearchSettings())
189189
wolfram_alpha: Optional[WolframAlphaSettings] = Field(
190190
default=WolframAlphaSettings()
191191
)

src/backend/crud/agent_tool_metadata.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def create_agent_tool_metadata(
2525

2626
def get_agent_tool_metadata_by_id(
2727
db: Session, agent_tool_metadata_id: str
28-
) -> AgentToolMetadata:
28+
) -> AgentToolMetadata | None:
2929
"""
3030
Get a agent tool metadata by its ID.
3131
@@ -61,6 +61,35 @@ def get_all_agent_tool_metadata_by_agent_id(
6161
)
6262

6363

64+
def get_agent_tool_metadata(
65+
db: Session,
66+
agent_id: str,
67+
tool_name: str,
68+
user_id: str,
69+
) -> AgentToolMetadata | None:
70+
"""
71+
Get a agent tool metadata.
72+
73+
Args:
74+
db (Session): Database session.
75+
agent_id (str): Agent ID.
76+
tool_name (str): Tool name.
77+
user_id (str): User ID.
78+
79+
Returns:
80+
AgentToolMetadata: Agent tool metadata.
81+
"""
82+
return (
83+
db.query(AgentToolMetadata)
84+
.filter(
85+
AgentToolMetadata.agent_id == agent_id,
86+
AgentToolMetadata.tool_name == tool_name,
87+
AgentToolMetadata.user_id == user_id,
88+
)
89+
.first()
90+
)
91+
92+
6493
def update_agent_tool_metadata(
6594
db: Session,
6695
agent_tool_metadata: AgentToolMetadata,

src/backend/routers/agent.py

Lines changed: 38 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,43 @@ async def update_agent(
384384
return agent
385385

386386

387+
@router.delete("/{agent_id}", response_model=DeleteAgent)
388+
async def delete_agent(
389+
agent_id: str,
390+
session: DBSessionDep,
391+
ctx: Context = Depends(get_context),
392+
) -> DeleteAgent:
393+
"""
394+
Delete an agent by ID.
395+
396+
Args:
397+
agent_id (str): Agent ID.
398+
session (DBSessionDep): Database session.
399+
ctx (Context): Context object.
400+
401+
Returns:
402+
DeleteAgent: Empty response.
403+
404+
Raises:
405+
HTTPException: If the agent with the given ID is not found.
406+
"""
407+
user_id = ctx.get_user_id()
408+
ctx.with_event_type(MetricsMessageType.ASSISTANT_DELETED)
409+
agent = validate_agent_exists(session, agent_id, user_id)
410+
agent_schema = Agent.model_validate(agent)
411+
ctx.with_agent(agent_schema)
412+
ctx.with_metrics_agent(agent_to_metrics_agent(agent))
413+
414+
deleted = agent_crud.delete_agent(session, agent_id, user_id)
415+
if not deleted:
416+
raise HTTPException(status_code=401, detail="Could not delete Agent.")
417+
418+
return DeleteAgent()
419+
420+
421+
# Agent Tool Metadata endpoints
422+
423+
387424
async def handle_tool_metadata_update(
388425
agent: Agent,
389426
new_agent: Agent,
@@ -453,43 +490,6 @@ async def update_or_create_tool_metadata(
453490
create_agent_tool_metadata(session, agent.id, create_metadata_req, ctx)
454491

455492

456-
@router.delete("/{agent_id}", response_model=DeleteAgent)
457-
async def delete_agent(
458-
agent_id: str,
459-
session: DBSessionDep,
460-
ctx: Context = Depends(get_context),
461-
) -> DeleteAgent:
462-
"""
463-
Delete an agent by ID.
464-
465-
Args:
466-
agent_id (str): Agent ID.
467-
session (DBSessionDep): Database session.
468-
ctx (Context): Context object.
469-
470-
Returns:
471-
DeleteAgent: Empty response.
472-
473-
Raises:
474-
HTTPException: If the agent with the given ID is not found.
475-
"""
476-
user_id = ctx.get_user_id()
477-
ctx.with_event_type(MetricsMessageType.ASSISTANT_DELETED)
478-
agent = validate_agent_exists(session, agent_id, user_id)
479-
agent_schema = Agent.model_validate(agent)
480-
ctx.with_agent(agent_schema)
481-
ctx.with_metrics_agent(agent_to_metrics_agent(agent))
482-
483-
deleted = agent_crud.delete_agent(session, agent_id, user_id)
484-
if not deleted:
485-
raise HTTPException(status_code=401, detail="Could not delete Agent.")
486-
487-
return DeleteAgent()
488-
489-
490-
# Tool Metadata Endpoints
491-
492-
493493
@router.get("/{agent_id}/tool-metadata", response_model=list[AgentToolMetadataPublic])
494494
async def list_agent_tool_metadata(
495495
agent_id: str, session: DBSessionDep, ctx: Context = Depends(get_context)
@@ -539,7 +539,7 @@ def create_agent_tool_metadata(
539539
ctx (Context): Context object.
540540
541541
Returns:
542-
AgentToolMetadata: Created agent tool metadata.
542+
AgentToolMetadataPublic: Created agent tool metadata.
543543
544544
Raises:
545545
HTTPException: If the agent tool metadata creation fails.

src/backend/tools/tavily.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,28 +4,65 @@
44
from tavily import TavilyClient
55

66
from backend.config.settings import Settings
7+
from backend.crud import agent_tool_metadata as agent_tool_metadata_crud
8+
from backend.database_models.database import DBSessionDep
79
from backend.model_deployments.base import BaseDeployment
10+
from backend.schemas.context import Context
811
from backend.tools.base import BaseTool
912

1013

1114
class TavilyInternetSearch(BaseTool):
1215
NAME = "web_search"
13-
TAVILY_API_KEY = Settings().tools.web_search.api_key
16+
TAVILY_API_KEY = Settings().tools.tavily.api_key
17+
POST_RERANK_MAX_RESULTS = 6
1418

1519
def __init__(self):
1620
self.client = TavilyClient(api_key=self.TAVILY_API_KEY)
17-
self.num_results = 6
1821

1922
@classmethod
2023
def is_available(cls) -> bool:
2124
return cls.TAVILY_API_KEY is not None
2225

26+
def get_filtered_domains(self, session: DBSessionDep, ctx: Context) -> list[str]:
27+
agent_id = ctx.get_agent_id()
28+
user_id = ctx.get_user_id()
29+
30+
if not agent_id or not user_id:
31+
# Default for Tavily is None
32+
return []
33+
34+
agent_tool_metadata = agent_tool_metadata_crud.get_agent_tool_metadata(
35+
db=session,
36+
agent_id=agent_id,
37+
tool_name=self.NAME,
38+
user_id=user_id,
39+
)
40+
41+
if not agent_tool_metadata:
42+
# Default for Tavily is None
43+
return []
44+
45+
return [
46+
artifact["domain"]
47+
for artifact in agent_tool_metadata.artifacts
48+
if "domain" in artifact
49+
]
50+
2351
async def call(
24-
self, parameters: dict, ctx: Any, **kwargs: Any
52+
self, parameters: dict, ctx: Any, session: DBSessionDep, **kwargs: Any
2553
) -> List[Dict[str, Any]]:
54+
# Gather search parameters
2655
query = parameters.get("query", "")
56+
# Get domains set on Agent tool metadata artifacts
57+
filter_domains = None
58+
domains = self.get_filtered_domains(session, ctx)
59+
60+
if domains:
61+
filter_domains = domains
62+
63+
# Do search
2764
result = self.client.search(
28-
query=query, search_depth="advanced", include_raw_content=True
65+
query=query, search_depth="advanced", include_raw_content=True, include_domains=filter_domains
2966
)
3067

3168
if "results" not in result:
@@ -96,7 +133,7 @@ async def rerank_page_snippets(
96133
seen_urls.append(result["url"])
97134
reranked.append(result)
98135

99-
return reranked[: self.num_results]
136+
return reranked[:self.POST_RERANK_MAX_RESULTS]
100137

101138
def to_langchain_tool(self) -> TavilySearchResults:
102139
internet_search = TavilySearchResults()

0 commit comments

Comments
 (0)