Commit 52635e8

chore: apply formatting and hash updates

1 parent 5ef6580 commit 52635e8
32 files changed, +131 -252 lines changed

lib/db/postgres_client.py
Lines changed: 2 additions & 1 deletion

@@ -3,7 +3,8 @@
 from __future__ import annotations
 
 from contextlib import contextmanager
-from typing import Any, Iterable
+from typing import Any
+from collections.abc import Iterable
 
 from psycopg_pool import ConnectionPool
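
Note: since Python 3.9 the container ABCs in typing (typing.Iterable and
friends) are deprecated aliases for their collections.abc counterparts, so
the split import above is the forward-compatible spelling. A minimal,
self-contained sketch of the convention:

    # Prefer collections.abc for runtime/generic ABCs; keep typing for
    # typing-only constructs such as Any.
    from collections.abc import Iterable
    from typing import Any

    def first_or_default(items: Iterable[Any], default: Any = None) -> Any:
        # Works for any iterable: list, generator, dict keys, ...
        return next(iter(items), default)

    print(first_or_default([1, 2, 3]))    # 1
    print(first_or_default([], "empty"))  # empty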

lib/embeddings/google_client.py
Lines changed: 5 additions & 16 deletions

@@ -32,10 +32,7 @@ def __init__(self):
             raise ValueError("GOOGLE_API_KEY not set in environment")
             self.client = genai.Client(api_key=config.embedding.api_key)
         else:
-            if (
-                not config.embedding.vertex_project
-                or not config.embedding.vertex_location
-            ):
+            if not config.embedding.vertex_project or not config.embedding.vertex_location:
                 raise ValueError(
                     "VERTEX_PROJECT and VERTEX_LOCATION must be set when EMBEDDING_PROVIDER=vertex_ai"
                 )
@@ -68,9 +65,7 @@ def __init__(self):
 
         # De-dupe while preserving order.
         seen: set[str] = set()
-        self._model_candidates = [
-            m for m in candidates if not (m in seen or seen.add(m))
-        ]
+        self._model_candidates = [m for m in candidates if not (m in seen or seen.add(m))]
         self.model = self._model_candidates[0]
         self.dimensions = config.embedding.dimensions
         self.batch_size = config.embedding.batch_size
@@ -83,19 +78,15 @@ def _embed(self, *, text: str, task_type: str) -> Any:
             output_dimensionality=int(self.dimensions) if self.dimensions else None,
         )
 
-        return self.client.models.embed_content(
-            model=self.model, contents=text, config=cfg
-        )
+        return self.client.models.embed_content(model=self.model, contents=text, config=cfg)
 
     @retry(
         stop=stop_after_attempt(5),
         wait=wait_exponential(multiplier=1, min=2, max=10),
         retry=retry_if_exception_type(Exception),
         reraise=True,
     )
-    def generate_embedding(
-        self, text: str, task_type: str = "RETRIEVAL_DOCUMENT"
-    ) -> list[float]:
+    def generate_embedding(self, text: str, task_type: str = "RETRIEVAL_DOCUMENT") -> list[float]:
         """Generate embedding for a single text."""
 
         last_err: Exception | None = None
@@ -143,9 +134,7 @@ def generate_embeddings_batch(
                 f"({len(batch)} texts)..."
             )
 
-            batch_embeddings = [
-                self.generate_embedding(text, task_type) for text in batch
-            ]
+            batch_embeddings = [self.generate_embedding(text, task_type) for text in batch]
             all_embeddings.extend(batch_embeddings)
 
         return all_embeddings
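
The collapsed comprehension relies on a well-known order-preserving de-dupe
idiom: set.add returns None, so "m in seen or seen.add(m)" is falsy exactly
once per distinct element, recording it as a side effect. A standalone
sketch (the candidate names are illustrative):

    candidates = ["model-a", "model-b", "model-a", "model-c", "model-b"]

    # First sighting: "m in seen" is False, so seen.add(m) runs (returning
    # None, i.e. falsy) and the element is kept; duplicates are dropped.
    seen: set[str] = set()
    deduped = [m for m in candidates if not (m in seen or seen.add(m))]
    print(deduped)  # ['model-a', 'model-b', 'model-c']

    # Equivalent and arguably clearer: dicts preserve insertion order.
    print(list(dict.fromkeys(candidates)))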

lib/gemini_finish_reason.py
Lines changed: 1 addition & 3 deletions

@@ -45,6 +45,4 @@ def raise_if_retryable_finish_reason(response: Any) -> None:
         return
 
     finish_reason_name = normalize_finish_reason_name(finish_reason)
-    raise RetryableFinishReasonError(
-        f"Retryable finish_reason encountered: {finish_reason_name}"
-    )
+    raise RetryableFinishReasonError(f"Retryable finish_reason encountered: {finish_reason_name}")
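
This raise pairs with tenacity retry decorators like the one visible in
lib/embeddings/google_client.py above: classifying a transient stop as a
dedicated exception type lets retry_if_exception_type target it precisely.
A minimal sketch under that assumption (the flaky call is simulated):

    from tenacity import (
        retry,
        retry_if_exception_type,
        stop_after_attempt,
        wait_exponential,
    )

    class RetryableFinishReasonError(Exception):
        """The model stopped for a transient, retry-worthy reason."""

    attempts = 0

    @retry(
        stop=stop_after_attempt(5),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        retry=retry_if_exception_type(RetryableFinishReasonError),
        reraise=True,
    )
    def generate() -> str:
        global attempts
        attempts += 1
        if attempts < 3:  # simulate two transient failures
            raise RetryableFinishReasonError("Retryable finish_reason encountered: RECITATION")
        return "ok"

    print(generate(), "after", attempts, "attempts")  # ok after 3 attempts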

lib/google_client.py
Lines changed: 1 addition & 3 deletions

@@ -85,6 +85,4 @@ def _safe_json_parse(self, response_text: str, context: str = "") -> dict[str, A
             return json.loads(response_text)
         except json.JSONDecodeError:
             preview = response_text[:200] if response_text else ""
-            raise ValueError(
-                f"Failed to parse JSON response ({context}). Preview: {preview}..."
-            )
+            raise ValueError(f"Failed to parse JSON response ({context}). Preview: {preview}...")

lib/knowledge_graph/base_kg_seeder.py
Lines changed: 2 additions & 6 deletions

@@ -256,9 +256,7 @@ def _generate_embeddings_for_nodes(self, node_ids: Any) -> None:
         if not node_id_list:
             return
 
-        check_query = (
-            "SELECT id, label FROM kg_nodes WHERE id = ANY(%s) AND embedding IS NULL"
-        )
+        check_query = "SELECT id, label FROM kg_nodes WHERE id = ANY(%s) AND embedding IS NULL"
         rows = self.postgres.execute_query(check_query, (node_id_list,))
 
         labels_to_embed = {row[0]: row[1] for row in rows}
@@ -277,9 +275,7 @@ def _generate_embeddings_for_nodes(self, node_ids: Any) -> None:
             print(f"Error generating embeddings batch: {e}")
             return
 
-        update_rows = [
-            (vector_literal(vec), nid) for nid, vec in zip(node_ids, embeddings)
-        ]
+        update_rows = [(vector_literal(vec), nid) for nid, vec in zip(node_ids, embeddings)]
 
         update_query = """
             UPDATE kg_nodes
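
The update path builds one (vector_text, node_id) tuple per row and hands
the batch to the Postgres client. A hedged sketch of the shape, assuming
vector_literal renders a float list as a pgvector-style text literal and
execute_batch is an executemany-style wrapper (both are assumptions about
this repo's helpers):

    from collections.abc import Sequence

    def vector_literal(vec: Sequence[float]) -> str:
        # Assumed format: pgvector accepts "[0.1,0.2,0.3]" as a text
        # literal that can be cast with ::vector.
        return "[" + ",".join(f"{x:g}" for x in vec) + "]"

    node_ids = ["node_a", "node_b"]
    embeddings = [[0.1, 0.2], [0.3, 0.4]]
    update_rows = [(vector_literal(vec), nid) for nid, vec in zip(node_ids, embeddings)]
    print(update_rows)  # [('[0.1,0.2]', 'node_a'), ('[0.3,0.4]', 'node_b')]

    # Assumed execution, one UPDATE per tuple:
    # postgres.execute_batch(
    #     "UPDATE kg_nodes SET embedding = %s::vector WHERE id = %s",
    #     update_rows,
    # )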

lib/knowledge_graph/kg_extractor.py
Lines changed: 7 additions & 18 deletions

@@ -227,8 +227,7 @@ def _parse_edges_from_llm_data(
                 target_ref=target_ref,
                 evidence=evidence,
                 utterance_ids=utterance_ids,
-                earliest_timestamp=earliest_timestamp_str
-                or window.earliest_timestamp,
+                earliest_timestamp=earliest_timestamp_str or window.earliest_timestamp,
                 earliest_seconds=earliest_seconds or window.earliest_seconds,
                 confidence=float(edge_data.get("confidence", 0.5)),
             )
@@ -295,9 +294,7 @@ def canonicalize_and_store(
     ) -> dict[str, Any]:
         """Canonicalize nodes and edges and store them in Postgres."""
 
-        def _normalize_speaker_ref(
-            ref: str, window_speaker_ids: list[str]
-        ) -> str | None:
+        def _normalize_speaker_ref(ref: str, window_speaker_ids: list[str]) -> str | None:
             ref = (ref or "").strip()
             if not ref:
                 return None
@@ -353,9 +350,7 @@ def _normalize_speaker_ref(
         speaker_nodes_data = []
         for speaker_id in speaker_ids_seen:
             meta = speaker_meta.get(speaker_id, {})
-            label = (
-                meta.get("full_name") or meta.get("normalized_name") or speaker_id
-            )
+            label = meta.get("full_name") or meta.get("normalized_name") or speaker_id
             aliases = []
             for candidate in (
                 meta.get("full_name"),
@@ -435,13 +430,11 @@ def _normalize_speaker_ref(
             target_id = temp_to_canonical.get(target_ref, target_ref)
 
             if not (
-                edge.source_ref.startswith("speaker_")
-                or edge.source_ref in temp_to_canonical
+                edge.source_ref.startswith("speaker_") or edge.source_ref in temp_to_canonical
             ):
                 stats["links_to_known"] += 1
             if not (
-                edge.target_ref.startswith("speaker_")
-                or edge.target_ref in temp_to_canonical
+                edge.target_ref.startswith("speaker_") or edge.target_ref in temp_to_canonical
             ):
                 stats["links_to_known"] += 1
 
@@ -554,13 +547,9 @@ def _embed_new_nodes(self, node_ids: list[str]) -> None:
 
         ids = [x[0] for x in to_embed]
         texts = [x[1] for x in to_embed]
-        embeddings = self.embedding.generate_embeddings_batch(
-            texts, task_type="RETRIEVAL_DOCUMENT"
-        )
+        embeddings = self.embedding.generate_embeddings_batch(texts, task_type="RETRIEVAL_DOCUMENT")
 
-        update_rows = [
-            (vector_literal(vec), node_id) for node_id, vec in zip(ids, embeddings)
-        ]
+        update_rows = [(vector_literal(vec), node_id) for node_id, vec in zip(ids, embeddings)]
         self.postgres.execute_batch(
             """
             UPDATE kg_nodes
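
The stats bookkeeping in the fourth hunk counts an endpoint as a "link to
known" when it is neither a speaker_ node nor one of this window's freshly
extracted temp ids, i.e. it must refer to a pre-existing canonical node.
A tiny illustration of that rule (data invented):

    temp_to_canonical = {"t1": "node_concept_x"}

    stats = {"links_to_known": 0}
    for ref in ["speaker_1", "t1", "node_existing_y"]:
        resolved = temp_to_canonical.get(ref, ref)
        if not (ref.startswith("speaker_") or ref in temp_to_canonical):
            stats["links_to_known"] += 1
        print(ref, "->", resolved)

    print(stats)  # {'links_to_known': 1}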

lib/knowledge_graph/kg_store.py
Lines changed: 6 additions & 16 deletions

@@ -19,9 +19,7 @@ def canonicalize_and_store(
     *,
     postgres: PostgresClient,
     embedding: GoogleEmbeddingClient,
-    results: list[
-        tuple[Window, list[dict[str, Any]], list[dict[str, Any]], str, bool, str | None]
-    ],
+    results: list[tuple[Window, list[dict[str, Any]], list[dict[str, Any]], str, bool, str | None]],
     youtube_video_id: str,
     kg_run_id: str,
     extractor_model: str,
@@ -187,13 +185,11 @@ def _normalize_speaker_ref(ref: str, window_speaker_ids: list[str]) -> str | Non
         target_id = temp_to_canonical.get(target_ref, target_ref)
 
         if not (
-            edge["source_ref"].startswith("speaker_")
-            or edge["source_ref"] in temp_to_canonical
+            edge["source_ref"].startswith("speaker_") or edge["source_ref"] in temp_to_canonical
         ):
             stats["links_to_known"] += 1
         if not (
-            edge["target_ref"].startswith("speaker_")
-            or edge["target_ref"] in temp_to_canonical
+            edge["target_ref"].startswith("speaker_") or edge["target_ref"] in temp_to_canonical
         ):
             stats["links_to_known"] += 1
 
@@ -274,9 +270,7 @@ def _normalize_speaker_ref(ref: str, window_speaker_ids: list[str]) -> str | Non
     )
    existing_ids = {row[0] for row in existing_rows}
 
-    filtered_edges = [
-        e for e in edges_data if e[1] in existing_ids and e[3] in existing_ids
-    ]
+    filtered_edges = [e for e in edges_data if e[1] in existing_ids and e[3] in existing_ids]
     stats["edges_skipped_missing_nodes"] = len(edges_data) - len(filtered_edges)
     stats["edges"] = len(filtered_edges)
 
@@ -323,13 +317,9 @@ def _embed_new_nodes(
 
     ids = [x[0] for x in to_embed]
     texts = [x[1] for x in to_embed]
-    embeddings = embedding.generate_embeddings_batch(
-        texts, task_type="RETRIEVAL_DOCUMENT"
-    )
+    embeddings = embedding.generate_embeddings_batch(texts, task_type="RETRIEVAL_DOCUMENT")
 
-    update_rows = [
-        (vector_literal(vec), node_id) for node_id, vec in zip(ids, embeddings)
-    ]
+    update_rows = [(vector_literal(vec), node_id) for node_id, vec in zip(ids, embeddings)]
     postgres.execute_batch(
         """
         UPDATE kg_nodes
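
The third hunk pre-filters edges so that both endpoints exist in kg_nodes
before insertion, counting the drops. A small self-contained illustration
of that referential-integrity filter (the tuple layout, with e[1] as the
source id and e[3] as the target id, is read off the comprehension; the
sample data is invented):

    edges_data = [
        ("e1", "node_a", "mentions", "node_b"),
        ("e2", "node_a", "mentions", "node_missing"),
    ]
    existing_ids = {"node_a", "node_b"}

    filtered_edges = [e for e in edges_data if e[1] in existing_ids and e[3] in existing_ids]
    skipped = len(edges_data) - len(filtered_edges)
    print(filtered_edges)  # [('e1', 'node_a', 'mentions', 'node_b')]
    print(skipped)         # 1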

lib/knowledge_graph/model_compare.py
Lines changed: 3 additions & 9 deletions

@@ -134,12 +134,8 @@ def canonicalize_edges(
     out: list[dict[str, Any]] = []
 
     for e in edges:
-        source_ref = normalize_speaker_ref(
-            str(e.get("source_ref", "")), window_speaker_ids
-        )
-        target_ref = normalize_speaker_ref(
-            str(e.get("target_ref", "")), window_speaker_ids
-        )
+        source_ref = normalize_speaker_ref(str(e.get("source_ref", "")), window_speaker_ids)
+        target_ref = normalize_speaker_ref(str(e.get("target_ref", "")), window_speaker_ids)
         if source_ref is None or target_ref is None:
             continue
 
@@ -269,9 +265,7 @@ def collect_signatures(
                 window_speaker_ids=r.window_speaker_ids,
            )
             for e in canon_edges:
-                sigs.add(
-                    edge_signature_strict(e) if strict else edge_signature_loose(e)
-                )
+                sigs.add(edge_signature_strict(e) if strict else edge_signature_loose(e))
         return sigs
 
     sigs_loose: dict[str, set[tuple]] = {}
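
edge_signature_strict and edge_signature_loose are not shown in this
commit; a hypothetical sketch of the strict/loose split the call site
implies (the tuple fields below are guesses, not this repo's definitions):

    from typing import Any

    # Hypothetical: "strict" keys on more fields than "loose", so strict
    # overlap between two models' edge sets is the harder match.
    def edge_signature_strict(e: dict[str, Any]) -> tuple:
        return (e["source_ref"], e["relation"], e["target_ref"],
                tuple(e.get("utterance_ids", ())))

    def edge_signature_loose(e: dict[str, Any]) -> tuple:
        return (e["source_ref"], e["relation"], e["target_ref"])

    edge = {
        "source_ref": "speaker_1",
        "relation": "works_at",
        "target_ref": "acme_corp",
        "utterance_ids": ["vid:1851"],
    }
    sigs: set[tuple] = set()
    strict = True
    sigs.add(edge_signature_strict(edge) if strict else edge_signature_loose(edge))
    print(sigs)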

lib/knowledge_graph/oss_kg_extractor.py
Lines changed: 4 additions & 11 deletions

@@ -261,9 +261,7 @@ def extract_from_concept_window(
             data_pass1 = self._parse_json_response(raw_response_pass1)
 
             # Normalize pass1 output
-            normalize_utterance_ids_in_data(
-                data_pass1, youtube_video_id=youtube_video_id
-            )
+            normalize_utterance_ids_in_data(data_pass1, youtube_video_id=youtube_video_id)
             normalize_evidence_in_data(data_pass1, window_text=window.text)
 
             pass1_parse_success = True
@@ -331,9 +329,7 @@ def extract_from_concept_window(
                 data_pass2 = self._parse_json_response(raw_response_pass2)
 
                 # Normalize pass2 output
-                normalize_utterance_ids_in_data(
-                    data_pass2, youtube_video_id=youtube_video_id
-                )
+                normalize_utterance_ids_in_data(data_pass2, youtube_video_id=youtube_video_id)
                 normalize_evidence_in_data(data_pass2, window_text=window.text)
 
                 # Merge additions
@@ -356,9 +352,7 @@ def extract_from_concept_window(
                 final_data = self._parse_json_response(raw_response_pass2)
 
                 # Normalize pass2 output
-                normalize_utterance_ids_in_data(
-                    final_data, youtube_video_id=youtube_video_id
-                )
+                normalize_utterance_ids_in_data(final_data, youtube_video_id=youtube_video_id)
                 normalize_evidence_in_data(final_data, window_text=window.text)
 
                 pass2_error = None
@@ -406,8 +400,7 @@ def extract_from_concept_window(
                 "target_ref": target_ref,
                 "evidence": evidence,
                 "utterance_ids": utterance_ids,
-                "earliest_timestamp": earliest_timestamp_str
-                or window.earliest_timestamp,
+                "earliest_timestamp": earliest_timestamp_str or window.earliest_timestamp,
                 "earliest_seconds": earliest_seconds or window.earliest_seconds,
                 "confidence": float(edge_data.get("confidence", 0.5)),
             }
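
The pass2 branches above are driven by a mode decision that lives in
lib/knowledge_graph/oss_two_pass.py (see its hunks below). A condensed,
self-contained sketch of that decision, mirroring the return values
visible in the diff (the enum and signature are spelled locally for the
example and are not the repo's exact definitions):

    from enum import Enum

    class TwoPassMode(Enum):
        ALWAYS = "always"
        ON_FAIL = "on_fail"
        ON_LOW_EDGES = "on_low_edges"

    def should_run_second_pass(
        mode: TwoPassMode, *, pass1_parse_success: bool, edge_count: int, min_edges: int
    ) -> tuple[bool, str | None]:
        # Returns (run_pass2, reason); reason is None when pass2 is skipped.
        if mode == TwoPassMode.ALWAYS:
            return True, "always"
        if mode == TwoPassMode.ON_FAIL:
            return (not pass1_parse_success), "parse_fail" if not pass1_parse_success else None
        if mode == TwoPassMode.ON_LOW_EDGES:
            return (pass1_parse_success and edge_count < min_edges), (
                "low_edges" if pass1_parse_success and edge_count < min_edges else None
            )
        return False, None

    print(should_run_second_pass(TwoPassMode.ON_LOW_EDGES,
                                 pass1_parse_success=True, edge_count=1, min_edges=3))
    # (True, 'low_edges')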

lib/knowledge_graph/oss_two_pass.py
Lines changed: 6 additions & 20 deletions

@@ -13,9 +13,7 @@
 from lib.knowledge_graph.model_compare import normalize_speaker_ref
 
 
-def normalize_utterance_ids_in_data(
-    data: dict[str, Any], *, youtube_video_id: str
-) -> None:
+def normalize_utterance_ids_in_data(data: dict[str, Any], *, youtube_video_id: str) -> None:
     """Normalize utterance_ids to full "{youtube_video_id}:<seconds>" strings.
 
     Some models sometimes output bare seconds like "1851". The transcript windows
@@ -176,15 +174,11 @@ def validate_kg_llm_data(
     edges = data.get("edges")
     if not isinstance(nodes, list):
         issues.append(
-            ValidationIssue(
-                code="nodes_new_not_list", message="nodes_new must be a list"
-            )
+            ValidationIssue(code="nodes_new_not_list", message="nodes_new must be a list")
         )
         nodes = []
     if not isinstance(edges, list):
-        issues.append(
-            ValidationIssue(code="edges_not_list", message="edges must be a list")
-        )
+        issues.append(ValidationIssue(code="edges_not_list", message="edges must be a list"))
         edges = []
 
     for i, n in enumerate(nodes):
@@ -293,11 +287,7 @@ def validate_kg_llm_data(
                     )
                 )
             else:
-                bad = [
-                    str(uid)
-                    for uid in utterance_ids
-                    if str(uid) not in window_utterance_ids
-                ]
+                bad = [str(uid) for uid in utterance_ids if str(uid) not in window_utterance_ids]
                 if bad:
                     issues.append(
                         ValidationIssue(
@@ -343,9 +333,7 @@ def should_run_second_pass(
     if mode == TwoPassMode.ALWAYS:
         return True, "always"
     if mode == TwoPassMode.ON_FAIL:
-        return (
-            not pass1_parse_success
-        ), "parse_fail" if not pass1_parse_success else None
+        return (not pass1_parse_success), "parse_fail" if not pass1_parse_success else None
     if mode == TwoPassMode.ON_LOW_EDGES:
         return (pass1_parse_success and edge_count < min_edges), (
             "low_edges" if pass1_parse_success and edge_count < min_edges else None
@@ -602,9 +590,7 @@ def merge_oss_additions(
     del_edges: list[Any] = del_edges_any if isinstance(del_edges_any, list) else []
 
     existing_ids = {
-        str(n.get("temp_id"))
-        for n in base_nodes
-        if isinstance(n, dict) and n.get("temp_id")
+        str(n.get("temp_id")) for n in base_nodes if isinstance(n, dict) and n.get("temp_id")
     }
 
     # Build remap for added nodes.
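
The docstring in this file pins down what utterance-id normalization means:
bare seconds like "1851" become "{youtube_video_id}:1851". A minimal sketch
of that rule for a single id (the helper name and the pass-through
heuristic are illustrative, not the repo's exact implementation):

    def normalize_utterance_id(uid: str, *, youtube_video_id: str) -> str:
        # Bare seconds get the video prefix; ids that already contain a
        # ":" are assumed to be full "<video>:<seconds>" references.
        uid = str(uid).strip()
        if ":" not in uid:
            return f"{youtube_video_id}:{uid}"
        return uid

    print(normalize_utterance_id("1851", youtube_video_id="abc123"))
    # abc123:1851
    print(normalize_utterance_id("abc123:90", youtube_video_id="abc123"))
    # abc123:90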
