Skip to content

Commit 71c9fdc

Browse files
committed
fix: preserve graph evidence filtering and enable seed rerank by default
1 parent b00c52c commit 71c9fdc

File tree

4 files changed

+125
-21
lines changed

4 files changed

+125
-21
lines changed

lib/kg_hybrid_graph_rag.py

Lines changed: 52 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def build_item_map(items: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
203203
if has_recency:
204204
boost += 0.03
205205

206-
if generic_terms and is_topical and not topic_terms:
206+
if generic_terms and is_topical:
207207
item_terms = set(re.findall(r"\b[a-zA-Z][a-zA-Z0-9-]{2,}\b", item_text))
208208
if not (topic_terms & item_terms):
209209
boost -= 0.05
@@ -555,6 +555,7 @@ def _retrieve_seed_nodes(
555555
seed_k: int,
556556
enable_rerank: bool = True,
557557
rerank_model: str = "gemini-2.0-flash",
558+
rerank_top_n: int = 40,
558559
query_embedding: list[float] | None = None,
559560
) -> list[dict[str, Any]]:
560561
vector_candidates: list[dict[str, Any]] = []
@@ -669,7 +670,7 @@ def _retrieve_seed_nodes(
669670
candidates=fused_candidates,
670671
query=query,
671672
model=rerank_model,
672-
top_n=min(50, seed_k * 3),
673+
top_n=max(5, min(int(rerank_top_n), max(50, seed_k * 3))),
673674
)
674675
except Exception:
675676
pass
@@ -686,18 +687,33 @@ def _retrieve_edges_hops_1(
686687
if not seed_ids:
687688
return []
688689
placeholders = ",".join(["%s"] * len(seed_ids))
689-
rows = postgres.execute_query(
690-
f"""
691-
SELECT id, source_id, predicate, predicate_raw, target_id,
692-
youtube_video_id, earliest_timestamp_str, earliest_seconds,
693-
utterance_ids, evidence, speaker_ids, confidence
694-
FROM kg_edges
695-
WHERE source_id IN ({placeholders}) OR target_id IN ({placeholders})
696-
ORDER BY edge_rank_score DESC NULLS LAST, confidence DESC NULLS LAST, earliest_seconds ASC
697-
LIMIT %s
698-
""",
699-
tuple(seed_ids + seed_ids + [max_edges]),
700-
)
690+
params = tuple(seed_ids + seed_ids + [max_edges])
691+
try:
692+
rows = postgres.execute_query(
693+
f"""
694+
SELECT id, source_id, predicate, predicate_raw, target_id,
695+
youtube_video_id, earliest_timestamp_str, earliest_seconds,
696+
utterance_ids, evidence, speaker_ids, confidence, edge_rank_score
697+
FROM kg_edges
698+
WHERE source_id IN ({placeholders}) OR target_id IN ({placeholders})
699+
ORDER BY edge_rank_score DESC NULLS LAST, confidence DESC NULLS LAST, earliest_seconds ASC
700+
LIMIT %s
701+
""",
702+
params,
703+
)
704+
except Exception:
705+
rows = postgres.execute_query(
706+
f"""
707+
SELECT id, source_id, predicate, predicate_raw, target_id,
708+
youtube_video_id, earliest_timestamp_str, earliest_seconds,
709+
utterance_ids, evidence, speaker_ids, confidence
710+
FROM kg_edges
711+
WHERE source_id IN ({placeholders}) OR target_id IN ({placeholders})
712+
ORDER BY confidence DESC NULLS LAST, earliest_seconds ASC
713+
LIMIT %s
714+
""",
715+
params,
716+
)
701717
out: list[dict[str, Any]] = []
702718
for row in rows:
703719
out.append(
@@ -714,6 +730,9 @@ def _retrieve_edges_hops_1(
714730
"evidence": row[9],
715731
"speaker_ids": row[10] or [],
716732
"confidence": float(row[11]) if row[11] is not None else None,
733+
"edge_rank_score": float(row[12])
734+
if len(row) > 12 and row[12] is not None
735+
else None,
717736
}
718737
)
719738
return out
@@ -921,9 +940,11 @@ def kg_hybrid_graph_rag(
921940

922941
enable_rerank = getattr(config, "enable_seed_rerank", False)
923942
rerank_model = getattr(config, "seed_rerank_model", "gemini-2.0-flash")
943+
rerank_top_n = getattr(config, "seed_rerank_top_n", 40)
924944
except Exception:
925945
enable_rerank = False
926946
rerank_model = "gemini-2.0-flash"
947+
rerank_top_n = 40
927948

928949
seeds = _retrieve_seed_nodes(
929950
postgres=postgres,
@@ -932,6 +953,7 @@ def kg_hybrid_graph_rag(
932953
seed_k=seed_k,
933954
enable_rerank=enable_rerank,
934955
rerank_model=rerank_model,
956+
rerank_top_n=int(rerank_top_n),
935957
query_embedding=query_embedding,
936958
)
937959
seed_ids = [s["id"] for s in seeds]
@@ -989,12 +1011,22 @@ def kg_hybrid_graph_rag(
9891011
utterance_ids.append(uid)
9901012

9911013
edges_filtered: int = 0
1014+
edge_rank_filter_skipped_no_scores = False
9921015
if edge_rank_threshold is not None:
993-
edges_before_filter = len(edges)
994-
edges = [e for e in edges if e.get("edge_rank_score", 0.0) >= edge_rank_threshold]
995-
edges_filtered = edges_before_filter - len(edges)
996-
if edges_filtered > 0:
997-
edges = edges[:max_edges]
1016+
has_rank_scores = any(e.get("edge_rank_score") is not None for e in edges)
1017+
if has_rank_scores:
1018+
edges_before_filter = len(edges)
1019+
edges = [
1020+
e
1021+
for e in edges
1022+
if e.get("edge_rank_score") is not None
1023+
and float(e.get("edge_rank_score") or 0.0) >= edge_rank_threshold
1024+
]
1025+
edges_filtered = edges_before_filter - len(edges)
1026+
if edges_filtered > 0:
1027+
edges = edges[:max_edges]
1028+
else:
1029+
edge_rank_filter_skipped_no_scores = True
9981030

9991031
citations = _hydrate_citations(
10001032
postgres=postgres,
@@ -1013,6 +1045,7 @@ def kg_hybrid_graph_rag(
10131045
**debug_info,
10141046
"edge_rank_threshold": float(edge_rank_threshold),
10151047
"edges_filtered_by_threshold": edges_filtered,
1048+
"edge_rank_filter_skipped_no_scores": edge_rank_filter_skipped_no_scores,
10161049
}
10171050

10181051
return {

lib/utils/config.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,11 @@ class AppConfig:
7575
"on",
7676
}
7777
gemini_model: str = os.getenv("GEMINI_MODEL", "gemini-3-flash-preview")
78-
enable_seed_rerank: bool = os.getenv("ENABLE_SEED_RERANK", "").lower() in {"1", "true", "on"}
78+
enable_seed_rerank: bool = os.getenv("ENABLE_SEED_RERANK", "1").lower() in {
79+
"1",
80+
"true",
81+
"on",
82+
}
7983
seed_rerank_model: str = os.getenv("SEED_RERANK_MODEL", "gemini-2.0-flash")
8084
seed_rerank_top_n: int = int(os.getenv("SEED_RERANK_TOP_N", "40"))
8185

tests/test_config_defaults_unit.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from __future__ import annotations
2+
3+
import importlib
4+
5+
6+
def test_enable_seed_rerank_should_default_to_true(monkeypatch) -> None:
    """Check that seed rerank is enabled when ENABLE_SEED_RERANK is unset.

    The config module is reloaded with the variable removed so that its
    module-level ``config`` object is rebuilt from the stripped environment.
    Afterwards the environment is restored and the module is reloaded once
    more, so the state built for this test does not leak into later tests.
    """
    monkeypatch.delenv("ENABLE_SEED_RERANK", raising=False)

    import lib.utils.config as config_module

    try:
        reloaded = importlib.reload(config_module)
        assert reloaded.config.enable_seed_rerank is True
    finally:
        # Put the real environment back first, then rebuild the module from
        # it; otherwise the module stays in the "variable unset" state even
        # after monkeypatch restores the variable at teardown.
        monkeypatch.undo()
        importlib.reload(config_module)

tests/test_kg_hybrid_graph_rag_unit.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def execute_query(self, sql: str, params: tuple[Any, ...] | None = None):
4040
if "FROM kg_edges" in sql:
4141
# (id, source_id, predicate, predicate_raw, target_id,
4242
# youtube_video_id, earliest_timestamp_str, earliest_seconds,
43-
# utterance_ids, evidence, speaker_ids, confidence)
43+
# utterance_ids, evidence, speaker_ids, confidence, edge_rank_score)
4444
return [
4545
(
4646
"kge_1",
@@ -55,6 +55,7 @@ def execute_query(self, sql: str, params: tuple[Any, ...] | None = None):
5555
"They discussed water management policy.",
5656
["s_test_1"],
5757
0.77,
58+
0.12,
5859
)
5960
]
6061

@@ -318,3 +319,57 @@ def test_kg_hybrid_graph_rag_with_bills_should_include_page_fragment_and_match_t
318319
assert bill["page_number"] == 12
319320
assert bill["source_url"].endswith("#page=12")
320321
assert "water" in bill["matched_terms"]
322+
323+
324+
def test_fuse_candidates_rrf_should_penalize_generic_governance_when_query_is_topical() -> None:
    """Fuse one generic and one topical candidate and check their ordering.

    Only the vector channel supplies candidates; the full-text and alias
    channels are empty. The test expects ``kg_water`` to rank first and the
    ``kg_ministers`` entry to carry a negative ``boost``.
    """
    from lib.kg_hybrid_graph_rag import _extract_query_intent, _fuse_candidates_rrf

    generic_node = {
        "id": "kg_ministers",
        "type": "foaf:Group",
        "label": "ministers",
        "aliases": ["minister"],
    }
    topical_node = {
        "id": "kg_water",
        "type": "skos:Concept",
        "label": "water management",
        "aliases": ["water"],
    }

    question = "What did ministers say about water management recently"
    parsed_intent = _extract_query_intent(question)

    ranked = _fuse_candidates_rrf(
        vector_candidates=[generic_node, topical_node],
        fulltext_candidates=[],
        alias_candidates=[],
        query=question,
        intent=parsed_intent,
    )

    top_entry = ranked[0]
    assert top_entry["id"] == "kg_water"
    ministers = next(entry for entry in ranked if entry["id"] == "kg_ministers")
    assert ministers["boost"] < 0.0
359+
360+
def test_kg_hybrid_graph_rag_should_keep_edges_when_threshold_applies() -> None:
    """Run the pipeline with an edge-rank threshold and expect the edge to survive.

    The call passes ``edge_rank_threshold=0.05`` and asserts that exactly one
    edge, ``kge_1``, is returned. NOTE(review): this presumably relies on the
    fake Postgres edge row carrying a rank score above the threshold; confirm
    against the ``_FakePostgres`` fixture.
    """
    from lib.kg_hybrid_graph_rag import kg_hybrid_graph_rag

    result = kg_hybrid_graph_rag(
        postgres=_FakePostgres(),
        embedding_client=_FakeEmbedding(),
        query="water management",
        hops=1,
        seed_k=5,
        max_edges=20,
        max_citations=5,
        edge_rank_threshold=0.05,
    )

    surviving_edges = result["edges"]
    assert len(surviving_edges) == 1
    assert surviving_edges[0]["id"] == "kge_1"

0 commit comments

Comments
 (0)