Skip to content

Commit a3a0077

Browse files
committed
fix: prioritize named speakers in KG retrieval
1 parent ab88eea commit a3a0077

File tree

4 files changed

+361
-38
lines changed

4 files changed

+361
-38
lines changed

lib/kg_agent_loop.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from google import genai
1212
from google.genai import types
1313

14+
from lib.id_generators import normalize_label
1415
from lib.kg_hybrid_graph_rag import kg_hybrid_graph_rag_with_bills as kg_hybrid_graph_rag
1516
from lib.utils.config import config
1617

@@ -83,6 +84,47 @@ def _format_tool_result_summary(result: dict[str, Any]) -> str:
8384
return ", ".join(out)
8485

8586

87+
def _augment_query_with_speakers(
88+
*, postgres: Any, query: str, user_message: str, max_speakers: int = 2
89+
) -> str:
90+
base = (query or "").strip()
91+
if not base:
92+
return base
93+
94+
message_norm = normalize_label(user_message)
95+
if not message_norm:
96+
return base
97+
98+
rows = postgres.execute_query(
99+
"""
100+
SELECT full_name, normalized_name
101+
FROM speakers
102+
WHERE %s LIKE '%' || normalized_name || '%'
103+
ORDER BY length(normalized_name) DESC
104+
LIMIT %s
105+
""",
106+
(message_norm, int(max_speakers)),
107+
)
108+
109+
if not rows:
110+
return base
111+
112+
query_norm = normalize_label(base)
113+
additions: list[str] = []
114+
for full_name, normalized_name in rows:
115+
candidate = (full_name or normalized_name or "").strip()
116+
if not candidate:
117+
continue
118+
if normalize_label(candidate) in query_norm:
119+
continue
120+
additions.append(candidate)
121+
122+
if not additions:
123+
return base
124+
125+
return f"{base} {' '.join(additions)}".strip()
126+
127+
86128
def _truncate_text(text: str, max_len: int = 300) -> str:
87129
"""Truncate text to max_len with ellipsis."""
88130
if not text or len(text) <= max_len:
@@ -794,10 +836,16 @@ async def run(self, *, user_message: str, history: list[dict[str, str]]) -> dict
794836
self.progress_callback(
795837
"searching", "Finding relevant debates (graph + citations)..."
796838
)
839+
base_query = str(fc.args.get("query", ""))
840+
resolved_query = _augment_query_with_speakers(
841+
postgres=self.postgres,
842+
query=base_query,
843+
user_message=user_message,
844+
)
797845
tool_result = kg_hybrid_graph_rag(
798846
postgres=self.postgres,
799847
embedding_client=self.embedding_client,
800-
query=str(fc.args.get("query", "")),
848+
query=resolved_query,
801849
hops=int(fc.args.get("hops", 1)),
802850
seed_k=int(fc.args.get("seed_k", 12)),
803851
max_edges=int(fc.args.get("max_edges", 90)),

lib/kg_hybrid_graph_rag.py

Lines changed: 152 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,24 @@ def _retrieve_seed_nodes(
676676
except Exception:
677677
pass
678678

679+
query_lower = query.lower()
680+
query_terms = _query_terms(query_lower)
681+
if len(query_terms) >= 2:
682+
683+
def _candidate_in_query(candidate: dict[str, Any]) -> bool:
684+
label = str(candidate.get("label") or "").lower().strip()
685+
if label and label in query_lower:
686+
return True
687+
for alias in candidate.get("aliases") or []:
688+
alias_str = str(alias or "").lower().strip()
689+
if alias_str and alias_str in query_lower:
690+
return True
691+
return False
692+
693+
matched = [c for c in fused_candidates if _candidate_in_query(c)]
694+
if matched:
695+
return matched[:seed_k]
696+
679697
return fused_candidates[:seed_k]
680698

681699

@@ -693,8 +711,9 @@ def _retrieve_edges_hops_1(
693711
rows = postgres.execute_query(
694712
f"""
695713
SELECT id, source_id, predicate, predicate_raw, target_id,
696-
youtube_video_id, earliest_timestamp_str, earliest_seconds,
697-
utterance_ids, evidence, speaker_ids, confidence, edge_rank_score
714+
source_kind, source_ref_id, youtube_video_id,
715+
earliest_timestamp_str, earliest_seconds,
716+
evidence_ids, utterance_ids, evidence, speaker_ids, confidence, edge_rank_score
698717
FROM kg_edges
699718
WHERE source_id IN ({placeholders}) OR target_id IN ({placeholders})
700719
ORDER BY edge_rank_score DESC NULLS LAST, confidence DESC NULLS LAST, earliest_seconds ASC
@@ -706,8 +725,9 @@ def _retrieve_edges_hops_1(
706725
rows = postgres.execute_query(
707726
f"""
708727
SELECT id, source_id, predicate, predicate_raw, target_id,
709-
youtube_video_id, earliest_timestamp_str, earliest_seconds,
710-
utterance_ids, evidence, speaker_ids, confidence
728+
source_kind, source_ref_id, youtube_video_id,
729+
earliest_timestamp_str, earliest_seconds,
730+
evidence_ids, utterance_ids, evidence, speaker_ids, confidence
711731
FROM kg_edges
712732
WHERE source_id IN ({placeholders}) OR target_id IN ({placeholders})
713733
ORDER BY confidence DESC NULLS LAST, earliest_seconds ASC
@@ -717,23 +737,50 @@ def _retrieve_edges_hops_1(
717737
)
718738
out: list[dict[str, Any]] = []
719739
for row in rows:
740+
if len(row) >= 16 and str(row[5]) in {"transcript", "bill"}:
741+
source_kind = str(row[5])
742+
source_ref_id = str(row[6] or "")
743+
youtube_video_id = row[7]
744+
earliest_timestamp_str = row[8]
745+
earliest_seconds = row[9]
746+
evidence_ids = row[10] or []
747+
legacy_utterance_ids = row[11] or []
748+
evidence = row[12]
749+
speaker_ids = row[13] or []
750+
confidence = row[14]
751+
edge_rank_score = row[15]
752+
else:
753+
# Legacy row shape before provenance cutover.
754+
source_kind = "transcript"
755+
source_ref_id = str(row[5] or "")
756+
youtube_video_id = row[5]
757+
earliest_timestamp_str = row[6]
758+
earliest_seconds = row[7]
759+
legacy_utterance_ids = row[8] or []
760+
evidence_ids = legacy_utterance_ids
761+
evidence = row[9]
762+
speaker_ids = row[10] or []
763+
confidence = row[11] if len(row) > 11 else None
764+
edge_rank_score = row[12] if len(row) > 12 else None
765+
720766
out.append(
721767
{
722768
"id": row[0],
723769
"source_id": row[1],
724770
"predicate": row[2],
725771
"predicate_raw": row[3],
726772
"target_id": row[4],
727-
"youtube_video_id": row[5],
728-
"earliest_timestamp_str": row[6],
729-
"earliest_seconds": int(row[7] or 0),
730-
"utterance_ids": row[8] or [],
731-
"evidence": row[9],
732-
"speaker_ids": row[10] or [],
733-
"confidence": float(row[11]) if row[11] is not None else None,
734-
"edge_rank_score": float(row[12])
735-
if len(row) > 12 and row[12] is not None
736-
else None,
773+
"source_kind": str(source_kind),
774+
"source_ref_id": str(source_ref_id or ""),
775+
"youtube_video_id": youtube_video_id,
776+
"earliest_timestamp_str": earliest_timestamp_str,
777+
"earliest_seconds": int(earliest_seconds or 0),
778+
"evidence_ids": evidence_ids,
779+
"utterance_ids": legacy_utterance_ids or [],
780+
"evidence": evidence,
781+
"speaker_ids": speaker_ids or [],
782+
"confidence": float(confidence) if confidence is not None else None,
783+
"edge_rank_score": float(edge_rank_score) if edge_rank_score is not None else None,
737784
}
738785
)
739786
return out
@@ -758,16 +805,78 @@ def _hydrate_nodes(
758805
return [{"id": r[0], "label": r[1], "type": r[2]} for r in rows]
759806

760807

808+
def _hydrate_bill_citations_from_ids(
809+
*,
810+
postgres: Any,
811+
bill_citation_ids: list[str],
812+
) -> list[dict[str, Any]]:
813+
if not bill_citation_ids:
814+
return []
815+
816+
out: list[dict[str, Any]] = []
817+
seen: set[str] = set()
818+
for cid in bill_citation_ids:
819+
if cid in seen:
820+
continue
821+
seen.add(cid)
822+
823+
parts = cid.split(":")
824+
if len(parts) != 3 or parts[0] != "bill":
825+
continue
826+
bill_id = parts[1]
827+
try:
828+
chunk_index = int(parts[2])
829+
except Exception:
830+
continue
831+
832+
rows = postgres.execute_query(
833+
"""
834+
SELECT b.id, b.bill_number, b.title, be.text, be.source_url, be.chunk_index, be.page_number
835+
FROM bill_excerpts be
836+
JOIN bills b ON b.id = be.bill_id
837+
WHERE be.bill_id = %s AND be.chunk_index = %s
838+
LIMIT 1
839+
""",
840+
(bill_id, chunk_index),
841+
)
842+
if not rows:
843+
continue
844+
845+
row = rows[0]
846+
page_number = int(row[6]) if row[6] is not None else None
847+
source_url = _url_with_page_fragment(str(row[4] or ""), page_number)
848+
out.append(
849+
{
850+
"citation_id": cid,
851+
"bill_id": str(row[0] or ""),
852+
"bill_number": row[1] or "",
853+
"bill_title": row[2] or "",
854+
"excerpt": row[3] or "",
855+
"source_url": source_url,
856+
"chunk_index": int(row[5] or 0),
857+
"page_number": page_number,
858+
"matched_terms": [],
859+
"score": 1.0,
860+
}
861+
)
862+
return out
863+
864+
761865
def _hydrate_citations(
762866
*,
763867
postgres: Any,
764-
utterance_ids: list[str],
868+
evidence_ids: list[str],
765869
max_citations: int,
766870
) -> list[dict[str, Any]]:
767-
if not utterance_ids:
871+
if not evidence_ids:
768872
return []
769-
utterance_ids = utterance_ids[:max_citations]
770-
placeholders = ",".join(["%s"] * len(utterance_ids))
873+
874+
transcript_ids = [eid for eid in evidence_ids if not str(eid).startswith("bill:")]
875+
transcript_ids = transcript_ids[:max_citations]
876+
if not transcript_ids:
877+
return []
878+
879+
placeholders = ",".join(["%s"] * len(transcript_ids))
771880
rows = postgres.execute_query(
772881
f"""
773882
SELECT s.id, s.text, s.seconds_since_start, s.timestamp_str,
@@ -800,7 +909,7 @@ def _hydrate_citations(
800909
LEFT JOIN speakers sp ON s.speaker_id = sp.id
801910
WHERE s.id IN ({placeholders})
802911
""",
803-
tuple(utterance_ids),
912+
tuple(transcript_ids),
804913
)
805914

806915
order_paper_idx = _load_order_paper_speaker_index(postgres=postgres)
@@ -1009,11 +1118,15 @@ def kg_hybrid_graph_rag(
10091118
e["target_label"] = target.get("label")
10101119
e["target_type"] = target.get("type")
10111120

1012-
utterance_ids: list[str] = []
1121+
evidence_ids: list[str] = []
1122+
bill_evidence_ids: list[str] = []
10131123
for e in edges:
1014-
for uid in e.get("utterance_ids", []) or []:
1015-
if uid not in utterance_ids:
1016-
utterance_ids.append(uid)
1124+
edge_evidence_ids = e.get("evidence_ids") or e.get("utterance_ids") or []
1125+
for evidence_id in edge_evidence_ids:
1126+
if evidence_id not in evidence_ids:
1127+
evidence_ids.append(evidence_id)
1128+
if str(evidence_id).startswith("bill:") and evidence_id not in bill_evidence_ids:
1129+
bill_evidence_ids.append(evidence_id)
10171130

10181131
edges_filtered: int = 0
10191132
edge_rank_filter_skipped_no_scores = False
@@ -1035,9 +1148,13 @@ def kg_hybrid_graph_rag(
10351148

10361149
citations = _hydrate_citations(
10371150
postgres=postgres,
1038-
utterance_ids=utterance_ids,
1151+
evidence_ids=evidence_ids,
10391152
max_citations=max_citations,
10401153
)
1154+
bill_citations_from_edges = _hydrate_bill_citations_from_ids(
1155+
postgres=postgres,
1156+
bill_citation_ids=bill_evidence_ids,
1157+
)
10411158

10421159
debug_info: dict[str, Any] = {
10431160
"seed_count": len(seeds),
@@ -1060,6 +1177,7 @@ def kg_hybrid_graph_rag(
10601177
"nodes": nodes,
10611178
"edges": edges,
10621179
"citations": citations,
1180+
"bill_citations_from_edges": bill_citations_from_edges,
10631181
"debug": debug_info,
10641182
}
10651183

@@ -1231,6 +1349,16 @@ def kg_hybrid_graph_rag_with_bills(
12311349
query_embedding=query_embedding,
12321350
)
12331351

1352+
for edge_citation in result.get("bill_citations_from_edges", []):
1353+
cid = str(edge_citation.get("citation_id") or "")
1354+
if not cid:
1355+
continue
1356+
exists = any(str(c.get("citation_id") or "") == cid for c in bill_citations)
1357+
if not exists:
1358+
bill_citations.append(edge_citation)
1359+
1360+
bill_citations = bill_citations[:max_bill_citations]
1361+
12341362
result["bill_citations"] = bill_citations
12351363
result["debug"]["bill_citation_count"] = len(bill_citations)
12361364

tests/test_kg_agent_loop_unit.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,11 @@ def execute_update(self, _sql: str, _params: Any = None):
6363
return None
6464

6565

66+
class _FakePostgresSpeakerMatch(_FakePostgres):
67+
def execute_query(self, _sql: str, _params: Any = None):
68+
return [("Tamaisha Eytle Harvey", "tamaisha eytle harvey")]
69+
70+
6671
class _FakeEmbedding:
6772
def generate_query_embedding(self, _query: str) -> list[float]:
6873
return [0.0] * 768
@@ -86,6 +91,22 @@ def test_system_prompt_includes_current_date_and_recency_guidance() -> None:
8691
assert "When the user asks for recent" in prompt
8792

8893

94+
def test_augment_query_with_speakers_appends_name() -> None:
95+
from lib.kg_agent_loop import _augment_query_with_speakers
96+
97+
postgres = _FakePostgresSpeakerMatch()
98+
query = "Future Barbados health tech"
99+
user_message = "What did Tamaisha Eytle Harvey say about Future Barbados?"
100+
101+
augmented = _augment_query_with_speakers(
102+
postgres=postgres,
103+
query=query,
104+
user_message=user_message,
105+
)
106+
107+
assert "Tamaisha Eytle Harvey" in augmented
108+
109+
89110
def test_agent_loop_runs_tool_then_answers():
90111
from lib.kg_agent_loop import KGAgentLoop
91112

0 commit comments

Comments
 (0)