|
import json
import sqlite3
from dataclasses import dataclass
from datetime import datetime, timezone
from hashlib import sha256
from pathlib import Path
from typing import Iterable

from alert_historian.ingestion.normalize import topic_slug
from alert_historian.ingestion.schema import CanonicalAlertPayload
| 11 | + |
| 12 | + |
| 13 | +TERMINAL_STATUSES = {"synced", "duplicate", "permanent_failed"} |
| 14 | + |
| 15 | + |
| 16 | +@dataclass |
| 17 | +class PendingSyncItem: |
| 18 | + item_key: str |
| 19 | + message_key: str |
| 20 | + topic: str |
| 21 | + day: str |
| 22 | + url: str |
| 23 | + url_normalized: str |
| 24 | + title: str |
| 25 | + snippet: str |
| 26 | + source_domain: str |
| 27 | + source_message_id: str |
| 28 | + |
| 29 | + |
| 30 | +def make_message_key(source_account: str, source_message_id: str) -> str: |
| 31 | + return sha256(f"{source_account}|{source_message_id}".encode("utf-8")).hexdigest() |
| 32 | + |
| 33 | + |
| 34 | +def make_item_key(url_normalized: str, topic: str) -> str: |
| 35 | + return sha256(f"{url_normalized}|{topic_slug(topic)}".encode("utf-8")).hexdigest() |
| 36 | + |
| 37 | + |
| 38 | +class StateStore: |
| 39 | + def __init__(self, db_path: Path): |
| 40 | + self.db_path = db_path |
| 41 | + self.db_path.parent.mkdir(parents=True, exist_ok=True) |
| 42 | + self.conn = sqlite3.connect(str(self.db_path)) |
| 43 | + self.conn.row_factory = sqlite3.Row |
| 44 | + self._init_schema() |
| 45 | + |
| 46 | + def close(self) -> None: |
| 47 | + self.conn.close() |
| 48 | + |
| 49 | + def _init_schema(self) -> None: |
| 50 | + cur = self.conn.cursor() |
| 51 | + cur.execute(""" |
| 52 | + CREATE TABLE IF NOT EXISTS sync_checkpoint ( |
| 53 | + mailbox TEXT PRIMARY KEY, |
| 54 | + last_uid INTEGER NOT NULL, |
| 55 | + updated_at TEXT NOT NULL |
| 56 | + ) |
| 57 | + """) |
| 58 | + cur.execute(""" |
| 59 | + CREATE TABLE IF NOT EXISTS seen_messages ( |
| 60 | + msg_key TEXT PRIMARY KEY, |
| 61 | + source_message_id TEXT NOT NULL, |
| 62 | + source_account TEXT NOT NULL, |
| 63 | + received_at TEXT NOT NULL, |
| 64 | + max_uid INTEGER |
| 65 | + ) |
| 66 | + """) |
| 67 | + cur.execute(""" |
| 68 | + CREATE TABLE IF NOT EXISTS items ( |
| 69 | + item_key TEXT PRIMARY KEY, |
| 70 | + message_key TEXT NOT NULL, |
| 71 | + topic TEXT NOT NULL, |
| 72 | + day TEXT NOT NULL, |
| 73 | + payload_json TEXT NOT NULL, |
| 74 | + first_seen_at TEXT NOT NULL, |
| 75 | + last_seen_at TEXT NOT NULL |
| 76 | + ) |
| 77 | + """) |
| 78 | + cur.execute(""" |
| 79 | + CREATE TABLE IF NOT EXISTS sync_attempts ( |
| 80 | + id INTEGER PRIMARY KEY AUTOINCREMENT, |
| 81 | + item_key TEXT NOT NULL, |
| 82 | + run_id TEXT NOT NULL, |
| 83 | + status TEXT NOT NULL, |
| 84 | + attempts INTEGER NOT NULL, |
| 85 | + last_error TEXT, |
| 86 | + findfirst_bookmark_id INTEGER, |
| 87 | + updated_at TEXT NOT NULL |
| 88 | + ) |
| 89 | + """) |
| 90 | + self.conn.commit() |
| 91 | + |
| 92 | + def get_checkpoint(self, mailbox: str) -> int: |
| 93 | + cur = self.conn.execute("SELECT last_uid FROM sync_checkpoint WHERE mailbox = ?", (mailbox,)) |
| 94 | + row = cur.fetchone() |
| 95 | + return int(row["last_uid"]) if row else 0 |
| 96 | + |
| 97 | + def set_checkpoint(self, mailbox: str, last_uid: int) -> None: |
| 98 | + now = datetime.utcnow().isoformat() |
| 99 | + self.conn.execute( |
| 100 | + """ |
| 101 | + INSERT INTO sync_checkpoint(mailbox, last_uid, updated_at) |
| 102 | + VALUES(?, ?, ?) |
| 103 | + ON CONFLICT(mailbox) DO UPDATE SET last_uid=excluded.last_uid, updated_at=excluded.updated_at |
| 104 | + """, |
| 105 | + (mailbox, last_uid, now)) |
| 106 | + self.conn.commit() |
| 107 | + |
| 108 | + def is_message_seen(self, msg_key: str) -> bool: |
| 109 | + cur = self.conn.execute("SELECT 1 FROM seen_messages WHERE msg_key = ?", (msg_key,)) |
| 110 | + return cur.fetchone() is not None |
| 111 | + |
| 112 | + def record_message(self, msg_key: str, source_message_id: str, source_account: str, max_uid: int | None) -> None: |
| 113 | + self.conn.execute( |
| 114 | + """ |
| 115 | + INSERT OR IGNORE INTO seen_messages(msg_key, source_message_id, source_account, received_at, max_uid) |
| 116 | + VALUES (?, ?, ?, ?, ?) |
| 117 | + """, |
| 118 | + (msg_key, source_message_id, source_account, datetime.utcnow().isoformat(), max_uid)) |
| 119 | + self.conn.commit() |
| 120 | + |
| 121 | + def save_payloads(self, payloads: Iterable[CanonicalAlertPayload]) -> int: |
| 122 | + created = 0 |
| 123 | + for payload in payloads: |
| 124 | + msg_key = make_message_key(payload.source_account, payload.source_message_id) |
| 125 | + if self.is_message_seen(msg_key): |
| 126 | + continue |
| 127 | + max_uid = payload.raw_ref.uid if payload.raw_ref.uid is not None else None |
| 128 | + self.record_message(msg_key, payload.source_message_id, payload.source_account, max_uid) |
| 129 | + for item in payload.items: |
| 130 | + item_key = make_item_key(item.url_normalized, payload.alert_topic) |
| 131 | + now = datetime.utcnow().isoformat() |
| 132 | + day = payload.received_at.date().isoformat() |
| 133 | + payload_json = json.dumps({ |
| 134 | + "topic": payload.alert_topic, |
| 135 | + "source_message_id": payload.source_message_id, |
| 136 | + "url": item.url, |
| 137 | + "url_normalized": item.url_normalized, |
| 138 | + "title": item.title, |
| 139 | + "snippet": item.snippet, |
| 140 | + "source_domain": item.source_domain, |
| 141 | + "day": day, |
| 142 | + }) |
| 143 | + cur = self.conn.execute( |
| 144 | + """ |
| 145 | + INSERT OR IGNORE INTO items(item_key, message_key, topic, day, payload_json, first_seen_at, last_seen_at) |
| 146 | + VALUES (?, ?, ?, ?, ?, ?, ?) |
| 147 | + """, |
| 148 | + (item_key, msg_key, payload.alert_topic, day, payload_json, now, now)) |
| 149 | + if cur.rowcount: |
| 150 | + created += 1 |
| 151 | + else: |
| 152 | + self.conn.execute("UPDATE items SET last_seen_at=? WHERE item_key=?", (now, item_key)) |
| 153 | + self.conn.commit() |
| 154 | + return created |
| 155 | + |
| 156 | + def get_pending_items(self, run_id: str) -> list[PendingSyncItem]: |
| 157 | + cur = self.conn.execute(""" |
| 158 | + SELECT i.item_key, i.message_key, i.topic, i.day, i.payload_json |
| 159 | + FROM items i |
| 160 | + LEFT JOIN ( |
| 161 | + SELECT item_key, status |
| 162 | + FROM sync_attempts |
| 163 | + WHERE id IN (SELECT MAX(id) FROM sync_attempts GROUP BY item_key) |
| 164 | + ) sa ON sa.item_key = i.item_key |
| 165 | + WHERE sa.status IS NULL OR sa.status NOT IN ('synced', 'duplicate', 'permanent_failed') |
| 166 | + ORDER BY i.first_seen_at ASC |
| 167 | + """) |
| 168 | + items: list[PendingSyncItem] = [] |
| 169 | + for row in cur.fetchall(): |
| 170 | + payload = json.loads(row["payload_json"]) |
| 171 | + items.append(PendingSyncItem( |
| 172 | + item_key=row["item_key"], |
| 173 | + message_key=row["message_key"], |
| 174 | + topic=row["topic"], |
| 175 | + day=row["day"], |
| 176 | + url=payload["url"], |
| 177 | + url_normalized=payload["url_normalized"], |
| 178 | + title=payload["title"], |
| 179 | + snippet=payload["snippet"], |
| 180 | + source_domain=payload["source_domain"], |
| 181 | + source_message_id=payload["source_message_id"], |
| 182 | + )) |
| 183 | + return items |
| 184 | + |
| 185 | + def record_sync_attempt(self, item_key: str, run_id: str, status: str, attempts: int, last_error: str | None = None, |
| 186 | + bookmark_id: int | None = None) -> None: |
| 187 | + self.conn.execute( |
| 188 | + """ |
| 189 | + INSERT INTO sync_attempts(item_key, run_id, status, attempts, last_error, findfirst_bookmark_id, updated_at) |
| 190 | + VALUES (?, ?, ?, ?, ?, ?, ?) |
| 191 | + """, |
| 192 | + (item_key, run_id, status, attempts, last_error, bookmark_id, datetime.utcnow().isoformat())) |
| 193 | + self.conn.commit() |
| 194 | + |
| 195 | + def get_attempt_count(self, item_key: str) -> int: |
| 196 | + cur = self.conn.execute("SELECT MAX(attempts) as attempts FROM sync_attempts WHERE item_key = ?", (item_key,)) |
| 197 | + row = cur.fetchone() |
| 198 | + return int(row["attempts"]) if row and row["attempts"] is not None else 0 |
| 199 | + |
| 200 | + def checkpoint_if_terminal(self, mailbox: str) -> bool: |
| 201 | + cur = self.conn.execute(""" |
| 202 | + SELECT COUNT(*) AS pending_cnt |
| 203 | + FROM items i |
| 204 | + LEFT JOIN ( |
| 205 | + SELECT item_key, status |
| 206 | + FROM sync_attempts |
| 207 | + WHERE id IN (SELECT MAX(id) FROM sync_attempts GROUP BY item_key) |
| 208 | + ) sa ON sa.item_key = i.item_key |
| 209 | + WHERE sa.status IS NULL OR sa.status NOT IN ('synced', 'duplicate', 'permanent_failed') |
| 210 | + """) |
| 211 | + if int(cur.fetchone()["pending_cnt"]) > 0: |
| 212 | + return False |
| 213 | + cur = self.conn.execute("SELECT COALESCE(MAX(max_uid), 0) as max_uid FROM seen_messages") |
| 214 | + max_uid = int(cur.fetchone()["max_uid"]) |
| 215 | + self.set_checkpoint(mailbox, max_uid) |
| 216 | + return True |
| 217 | + |
| 218 | + def run_stats(self, run_id: str) -> dict[str, int]: |
| 219 | + cur = self.conn.execute(""" |
| 220 | + SELECT status, COUNT(*) AS cnt |
| 221 | + FROM sync_attempts |
| 222 | + WHERE run_id = ? |
| 223 | + GROUP BY status |
| 224 | + """, (run_id,)) |
| 225 | + stats = {row["status"]: int(row["cnt"]) for row in cur.fetchall()} |
| 226 | + total = sum(stats.values()) |
| 227 | + stats["total"] = total |
| 228 | + return stats |
| 229 | + |
| 230 | + def topic_links(self) -> dict[str, list[str]]: |
| 231 | + cur = self.conn.execute("SELECT topic, payload_json FROM items ORDER BY first_seen_at ASC") |
| 232 | + out: dict[str, list[str]] = {} |
| 233 | + for row in cur.fetchall(): |
| 234 | + payload = json.loads(row["payload_json"]) |
| 235 | + topic = row["topic"] |
| 236 | + if topic not in out: |
| 237 | + out[topic] = [] |
| 238 | + out[topic].append(payload["url"]) |
| 239 | + return out |
0 commit comments