Skip to content

Commit e765b80

Browse files
committed
[owl] Make RAG more robust: Upon error use fallback search query and skip reranking (#828)
Backend — owl (API server): Make RAG more robust. If query rewriting fails, fall back to the last user message as the search query; if reranking fails, skip it and keep the original result order. If `search_columns` is not provided, default to all string columns.
1 parent 161f452 commit e765b80

File tree

3 files changed

+40
-20
lines changed

3 files changed

+40
-20
lines changed

services/api/src/owl/db/gen_executor.py

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1224,16 +1224,28 @@ async def setup_rag(
12241224
project_id=project.id, table_id=kt_id, request_id=request_id
12251225
)
12261226
kt_cols = {c.column_id for c in kt.column_metadata if not c.is_state_column}
1227-
t0 = perf_counter()
1228-
fts_query, vs_query = await lm.generate_search_query(
1229-
messages=body.messages,
1230-
rag_params=body.rag_params,
1231-
**body.hyperparams,
1232-
)
1233-
cls._log(
1234-
f'Query rewrite using "{body.model}" took t={(perf_counter() - t0) * 1e3:,.2f} ms.',
1235-
request_id=request_id,
1236-
)
1227+
try:
1228+
t0 = perf_counter()
1229+
fts_query, vs_query = await lm.generate_search_query(
1230+
messages=body.messages,
1231+
rag_params=body.rag_params,
1232+
**body.hyperparams,
1233+
)
1234+
cls._log(
1235+
f'Query rewrite using "{body.model}" took t={(perf_counter() - t0) * 1e3:,.2f} ms.',
1236+
request_id=request_id,
1237+
)
1238+
except Exception as e:
1239+
cls._log(
1240+
f"Query rewrite failed with error: {repr(e)}. Using last user message as query.",
1241+
request_id=request_id,
1242+
)
1243+
# Fallback: use last user message
1244+
for msg in reversed(body.messages):
1245+
if msg.role == ChatRole.USER:
1246+
fts_query = msg.text_content
1247+
vs_query = msg.text_content
1248+
break
12371249
rows = await kt.hybrid_search(
12381250
fts_query=fts_query,
12391251
vs_query=vs_query,
@@ -1270,14 +1282,20 @@ async def setup_rag(
12701282
chunk.metadata["project_id"] = project.id
12711283
chunk.metadata["table_id"] = body.rag_params.table_id
12721284
if len(rows) > 0 and body.rag_params.reranking_model is not None:
1273-
order = (
1274-
await lm.rerank_documents(
1275-
model=body.rag_params.reranking_model,
1276-
query=vs_query,
1277-
documents=kt.rows_to_documents(rows),
1285+
try:
1286+
order = (
1287+
await lm.rerank_documents(
1288+
model=body.rag_params.reranking_model,
1289+
query=vs_query,
1290+
documents=kt.rows_to_documents(rows),
1291+
)
1292+
).results
1293+
chunks = [chunks[i.index] for i in order]
1294+
except Exception as e:
1295+
cls._log(
1296+
f"Reranking failed with error: {repr(e)}. Proceeding with original order.",
1297+
request_id=request_id,
12781298
)
1279-
).results
1280-
chunks = [chunks[i.index] for i in order]
12811299
chunks = chunks[: body.rag_params.k]
12821300
references = References(chunks=chunks, search_query=vs_query)
12831301
if body.messages[-1].role == ChatRole.USER:
@@ -1286,7 +1304,7 @@ async def setup_rag(
12861304
replacement_idx = -2
12871305
else:
12881306
raise BadInputError("The message list should end with user or assistant message.")
1289-
rag_prompt = await lm.generate_rag_prompt(
1307+
rag_prompt = await lm.make_rag_prompt(
12901308
messages=body.messages,
12911309
references=references,
12921310
inline_citations=body.rag_params.inline_citations,

services/api/src/owl/db/models/oss.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,9 @@ def _search_query_filter(
199199
search_columns: list[str] | None,
200200
) -> SelectBase:
201201
# Apply search filters
202-
if search_query and search_columns:
202+
if search_query:
203+
if not search_columns:
204+
search_columns = cls.str_cols()
203205
search_conditions = []
204206
for column_name in search_columns:
205207
if (column := getattr(cls, column_name, None)) is not None:

services/api/src/owl/utils/lm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1585,7 +1585,7 @@ async def generate_search_query(
15851585
queries[q_type] = generated_query
15861586
return queries["fts"], queries["vs"]
15871587

1588-
async def generate_rag_prompt(
1588+
async def make_rag_prompt(
15891589
self,
15901590
*,
15911591
messages: list[ChatEntry],

0 commit comments

Comments
 (0)