Skip to content

Commit d599d1e

Browse files
authored
feat: use default rerank (#1200)
1 parent 63bb81e commit d599d1e

File tree

5 files changed

+218
-95
lines changed

5 files changed

+218
-95
lines changed

aperag/flow/runners/rerank.py

Lines changed: 133 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import logging
16-
from typing import List, Tuple
16+
from typing import List, Optional, Tuple
1717

1818
from pydantic import BaseModel, Field
1919

@@ -31,9 +31,12 @@
3131

3232

3333
class RerankInput(BaseModel):
34-
model: str = Field(..., description="Rerank model name")
35-
model_service_provider: str = Field(..., description="Model service provider")
36-
custom_llm_provider: str = Field(..., description="Custom LLM provider (e.g., 'jina_ai', 'openai')")
34+
use_rerank_service: bool = Field(default=True, description="Whether to use rerank service or fallback strategy")
35+
model: Optional[str] = Field(default=None, description="Rerank model name")
36+
model_service_provider: Optional[str] = Field(default=None, description="Model service provider")
37+
custom_llm_provider: Optional[str] = Field(
38+
default=None, description="Custom LLM provider (e.g., 'jina_ai', 'openai')"
39+
)
3740
docs: List[DocumentWithScore]
3841

3942

@@ -49,88 +52,145 @@ class RerankOutput(BaseModel):
4952
class RerankNodeRunner(BaseNodeRunner):
5053
async def run(self, ui: RerankInput, si: SystemInput) -> Tuple[RerankOutput, dict]:
5154
"""
52-
Run rerank node. ui: user input; si: system input (SystemInput).
53-
Returns (output, system_output)
55+
Smart rerank node:
56+
- use_rerank_service=False: directly use fallback strategy
57+
- use_rerank_service=True: try rerank service, fallback on failure
5458
"""
55-
query = si.query
5659
docs = ui.docs
57-
result = []
5860

5961
if not docs:
6062
logger.info("No documents to rerank, returning empty result")
63+
return RerankOutput(docs=[]), {}
64+
65+
# Strategy 1: If not using rerank service, directly use fallback strategy
66+
if not ui.use_rerank_service:
67+
logger.info("Rerank service disabled, using fallback strategy")
68+
result = self._apply_fallback_strategy(docs)
6169
return RerankOutput(docs=result), {}
6270

71+
# Strategy 2: Try to use rerank service
6372
try:
64-
# Validate input configuration
65-
if not ui.model_service_provider:
66-
raise InvalidConfigurationError(
67-
"model_service_provider", ui.model_service_provider, "Model service provider cannot be empty"
68-
)
69-
70-
if not ui.model:
71-
raise InvalidConfigurationError("model", ui.model, "Model name cannot be empty")
72-
73-
if not ui.custom_llm_provider:
74-
raise InvalidConfigurationError(
75-
"custom_llm_provider", ui.custom_llm_provider, "Custom LLM provider cannot be empty"
76-
)
77-
78-
# Get API key and base URL from user's model service provider settings
79-
api_key = await async_db_ops.query_provider_api_key(ui.model_service_provider, si.user)
80-
if not api_key:
81-
raise InvalidConfigurationError(
82-
"api_key", api_key, f"API KEY not found for LLM Provider:{ui.model_service_provider}"
83-
)
84-
85-
# Get base_url from LLMProvider
86-
try:
87-
llm_provider = await async_db_ops.query_llm_provider_by_name(ui.model_service_provider)
88-
if not llm_provider:
89-
raise ProviderNotFoundError(ui.model_service_provider, "Rerank")
90-
base_url = llm_provider.base_url
91-
except Exception as e:
92-
logger.error(f"Failed to query LLM provider '{ui.model_service_provider}': {str(e)}")
93-
raise ProviderNotFoundError(ui.model_service_provider, "Rerank") from e
94-
95-
if not base_url:
96-
raise InvalidConfigurationError(
97-
"base_url", base_url, f"Base URL not configured for provider '{ui.model_service_provider}'"
98-
)
99-
100-
# Create rerank service with configuration from model service provider
101-
rerank_service = RerankService(
102-
rerank_provider=ui.custom_llm_provider,
103-
rerank_model=ui.model,
104-
rerank_service_url=base_url,
105-
rerank_service_api_key=api_key,
73+
# Check configuration completeness
74+
if not self._is_rerank_config_valid(ui):
75+
logger.info("Rerank service configuration incomplete, using fallback strategy")
76+
result = self._apply_fallback_strategy(docs)
77+
return RerankOutput(docs=result), {}
78+
79+
# Execute actual rerank
80+
result = await self._perform_actual_rerank(ui, si)
81+
logger.info(f"Successfully reranked {len(result)} documents using rerank service")
82+
return RerankOutput(docs=result), {}
83+
84+
except (InvalidConfigurationError, ProviderNotFoundError) as e:
85+
logger.warning(f"Rerank service configuration error, using fallback strategy: {str(e)}")
86+
result = self._apply_fallback_strategy(docs)
87+
return RerankOutput(docs=result), {}
88+
89+
except RerankError as e:
90+
logger.warning(f"Rerank service operation failed, using fallback strategy: {str(e)}")
91+
result = self._apply_fallback_strategy(docs)
92+
return RerankOutput(docs=result), {}
93+
94+
except Exception as e:
95+
logger.error(f"Unexpected error during rerank service, using fallback strategy: {str(e)}")
96+
result = self._apply_fallback_strategy(docs)
97+
return RerankOutput(docs=result), {}
98+
99+
def _is_rerank_config_valid(self, ui: RerankInput) -> bool:
100+
"""Check if rerank service configuration is valid"""
101+
return (
102+
ui.model
103+
and ui.model.strip()
104+
and ui.model_service_provider
105+
and ui.model_service_provider.strip()
106+
and ui.custom_llm_provider
107+
and ui.custom_llm_provider.strip()
108+
)
109+
110+
async def _perform_actual_rerank(self, ui: RerankInput, si: SystemInput) -> List[DocumentWithScore]:
111+
"""Execute actual rerank operation"""
112+
query = si.query
113+
docs = ui.docs
114+
115+
# Validate configuration
116+
if not ui.model_service_provider:
117+
raise InvalidConfigurationError(
118+
"model_service_provider", ui.model_service_provider, "Model service provider cannot be empty"
106119
)
107120

108-
# Validate the service configuration
109-
rerank_service.validate_configuration()
121+
if not ui.model:
122+
raise InvalidConfigurationError("model", ui.model, "Model name cannot be empty")
110123

111-
logger.info(
112-
f"Using rerank service with provider: {ui.model_service_provider}, "
113-
f"model: {ui.model}, url: {base_url}, max_docs: {rerank_service.max_documents}"
124+
if not ui.custom_llm_provider:
125+
raise InvalidConfigurationError(
126+
"custom_llm_provider", ui.custom_llm_provider, "Custom LLM provider cannot be empty"
114127
)
115128

116-
# Perform reranking
117-
result = await rerank_service.async_rerank(query, docs)
118-
logger.info(f"Successfully reranked {len(result)} documents")
129+
# Get API key and base_url
130+
api_key = await async_db_ops.query_provider_api_key(ui.model_service_provider, si.user)
131+
if not api_key:
132+
raise InvalidConfigurationError(
133+
"api_key", api_key, f"API KEY not found for LLM Provider:{ui.model_service_provider}"
134+
)
119135

120-
except (InvalidConfigurationError, ProviderNotFoundError) as e:
121-
# Configuration errors - log and return empty result to gracefully degrade
122-
logger.error(f"Rerank configuration error: {str(e)}")
123-
# For flow execution, we gracefully degrade instead of failing the entire flow
124-
result = docs # Return original documents without reranking
125-
except RerankError as e:
126-
# Rerank-specific errors - log and return original documents
127-
logger.error(f"Rerank operation failed: {str(e)}")
128-
# For flow execution, we gracefully degrade instead of failing the entire flow
129-
result = docs # Return original documents without reranking
136+
try:
137+
llm_provider = await async_db_ops.query_llm_provider_by_name(ui.model_service_provider)
138+
if not llm_provider:
139+
raise ProviderNotFoundError(ui.model_service_provider, "Rerank")
140+
base_url = llm_provider.base_url
130141
except Exception as e:
131-
# Unexpected errors - log and return original documents
132-
logger.error(f"Unexpected error during rerank: {str(e)}")
133-
# For flow execution, we gracefully degrade instead of failing the entire flow
134-
result = docs # Return original documents without reranking
142+
logger.error(f"Failed to query LLM provider '{ui.model_service_provider}': {str(e)}")
143+
raise ProviderNotFoundError(ui.model_service_provider, "Rerank") from e
144+
145+
if not base_url:
146+
raise InvalidConfigurationError(
147+
"base_url", base_url, f"Base URL not configured for provider '{ui.model_service_provider}'"
148+
)
149+
150+
# Create and execute rerank service
151+
rerank_service = RerankService(
152+
rerank_provider=ui.custom_llm_provider,
153+
rerank_model=ui.model,
154+
rerank_service_url=base_url,
155+
rerank_service_api_key=api_key,
156+
)
157+
158+
rerank_service.validate_configuration()
159+
160+
logger.info(
161+
f"Using rerank service with provider: {ui.model_service_provider}, "
162+
f"model: {ui.model}, url: {base_url}, max_docs: {rerank_service.max_documents}"
163+
)
164+
165+
return await rerank_service.async_rerank(query, docs)
166+
167+
def _apply_fallback_strategy(self, docs: List[DocumentWithScore]) -> List[DocumentWithScore]:
168+
"""
169+
Apply fallback rerank strategy:
170+
1. Graph search results first (better quality, typically 1 result)
171+
2. Sort remaining vector and fulltext results by score in descending order
172+
"""
173+
if not docs:
174+
return docs
175+
176+
graph_results = []
177+
other_results = []
178+
179+
for doc in docs:
180+
recall_type = doc.metadata.get("recall_type", "")
181+
if recall_type == "graph_search":
182+
graph_results.append(doc)
183+
else:
184+
other_results.append(doc)
185+
186+
# Sort other results by score in descending order
187+
other_results.sort(key=lambda x: x.score if x.score is not None else 0.0, reverse=True)
188+
189+
result = graph_results + other_results
190+
191+
logger.info(
192+
f"Applied fallback rerank strategy: {len(graph_results)} graph results, "
193+
f"{len(other_results)} other results sorted by score"
194+
)
135195

136-
return RerankOutput(docs=result), {}
196+
return result

aperag/mcp/server.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ async def search_collection(
7272
use_vector_index: bool = True,
7373
use_fulltext_index: bool = False,
7474
use_graph_index: bool = True,
75+
rerank: bool = True,
7576
topk: int = 5,
7677
query_keywords: list[str] = None,
7778
) -> Dict[str, Any]:
@@ -84,7 +85,8 @@ async def search_collection(
8485
use_vector_index: Whether to use vector/semantic search (default: True)
8586
use_fulltext_index: Whether to use full-text keyword search (default: False)
8687
use_graph_index: Whether to use knowledge graph search (default: True)
87-
topk: Maximum number of results to return per search type (default: 10)
88+
rerank: Whether to enable reranking of search results for better relevance (default: True)
89+
topk: Maximum number of results to return per search type (default: 5)
8890
8991
Returns:
9092
Search results with relevant documents and metadata (SearchResult format)
@@ -146,7 +148,7 @@ class SearchResult(BaseModel):
146148
api_key = get_api_key()
147149

148150
# Build search request based on enabled search types
149-
search_data = {"query": query}
151+
search_data = {"query": query, "rerank": rerank}
150152

151153
# Add search configurations for enabled types
152154
if use_vector_index:
@@ -346,17 +348,18 @@ async def aperag_usage_guide() -> str:
346348
1. First, get available collections with essential information: `list_collections()`
347349
2. Choose a collection from the list
348350
3. Search the collection: `search_collection(collection_id="abc123", query="your question")`
349-
(By default, vector and graph search are enabled for optimal performance)
351+
(By default, vector search, graph search, and reranking are enabled for optimal performance)
350352
351353
## Search Types:
352354
You can enable/disable any combination of search methods:
353355
- **Vector search** (use_vector_index): Semantic similarity search using embeddings (default: True)
354356
- **Full-text search** (use_fulltext_index): Traditional keyword-based text search (default: False)
355357
- **Graph search** (use_graph_index): Knowledge graph-based search (default: True)
358+
- **Reranking** (rerank): AI-powered reranking for improved result relevance (default: True)
356359
357360
⚠️ **Important**: Full-text search can return large amounts of text content which may cause context window overflow with smaller LLM models. Use with caution and consider reducing topk when enabling fulltext search.
358361
359-
By default, vector and graph search are enabled for optimal balance of quality and context size.
362+
By default, vector search, graph search, and reranking are enabled for optimal balance of quality and context size.
360363
361364
## Example Workflow:
362365
```
@@ -367,13 +370,14 @@ async def aperag_usage_guide() -> str:
367370
# (collections.items contains collection ID, title, and description)
368371
collection_id = collections.items[0].id
369372
370-
# Step 3: Search with default methods (vector + graph)
373+
# Step 3: Search with default methods (vector + graph + rerank)
371374
results = search_collection(
372375
collection_id=collection_id,
373376
query="How to deploy applications?",
374377
use_vector_index=True,
375378
use_fulltext_index=False,
376379
use_graph_index=True,
380+
rerank=True,
377381
topk=5
378382
)
379383
@@ -384,6 +388,7 @@ async def aperag_usage_guide() -> str:
384388
use_vector_index=True,
385389
use_fulltext_index=False,
386390
use_graph_index=False,
391+
rerank=True, # Rerank still enabled for better results
387392
topk=10
388393
)
389394
@@ -394,6 +399,7 @@ async def aperag_usage_guide() -> str:
394399
use_vector_index=True,
395400
use_fulltext_index=True, # Enable with caution
396401
use_graph_index=True,
402+
rerank=True, # Rerank for optimal result ordering
397403
topk=3 # Use smaller topk to manage context size
398404
)
399405
```
@@ -485,6 +491,7 @@ async def aperag_usage_guide() -> str:
485491
internal_results = search_collection(
486492
collection_id=collections.items[0].id,
487493
query="AI developments",
494+
rerank=True, # Default rerank for better results
488495
topk=5
489496
)
490497

0 commit comments

Comments (0)