Skip to content

Commit 4cdc969

Browse files
basic survey interface
1 parent da3b175 commit 4cdc969

File tree

17 files changed

+2115
-271
lines changed

17 files changed

+2115
-271
lines changed

package-lock.json

Lines changed: 53 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
"@radix-ui/react-select": "^2.2.6",
2020
"@radix-ui/react-slot": "^1.2.4",
2121
"@radix-ui/react-switch": "^1.2.6",
22+
"@radix-ui/react-tooltip": "^1.2.8",
2223
"class-variance-authority": "^0.7.1",
2324
"clsx": "^2.1.1",
2425
"cmdk": "^1.1.1",

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,12 @@ description = "UN Website Boilerplate"
55
readme = "README.md"
66
requires-python = ">=3.13"
77
dependencies = [
8+
"aiolimiter>=1.2.1",
9+
"backoff>=2.2.1",
810
"ipykernel>=7.1.0",
911
"joblib>=1.5.3",
12+
"numpy>=2.3.4",
13+
"openai>=2.15.0",
1014
"openpyxl>=3.1.5",
1115
"pandas>=2.3.3",
1216
"psycopg2-binary>=2.9.10",

python/generate_embeddings.py

Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,277 @@
1+
#!/usr/bin/env python3
"""
Generate vector embeddings for reports using Azure OpenAI.

This script:
1. Fetches reports that don't have embeddings yet
2. Generates embeddings using text-embedding-3-large (1024 dimensions)
3. Stores embeddings in the reports table for similarity search

Usage:
    uv run python python/generate_embeddings.py
"""

import asyncio
import os
from typing import Literal

import aiolimiter
import backoff
import numpy as np
import psycopg2
from dotenv import load_dotenv
from openai import APITimeoutError, AsyncAzureOpenAI, RateLimitError
from psycopg2.extras import execute_values
from tqdm.asyncio import tqdm_asyncio

# Pull configuration from .env before any environment lookups below.
load_dotenv()

# --- Database configuration ---------------------------------------------
DATABASE_URL = os.environ.get("DATABASE_URL")
if not DATABASE_URL:
    raise ValueError("DATABASE_URL environment variable is required")

DB_SCHEMA = os.environ.get("DB_SCHEMA", "sg_reports_survey")

# --- Azure OpenAI configuration -----------------------------------------
AZURE_OPENAI_ENDPOINT = os.environ.get("AZURE_OPENAI_ENDPOINT")
AZURE_OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY")
AZURE_OPENAI_API_VERSION = os.environ.get("AZURE_OPENAI_API_VERSION", "2025-03-01-preview")

if not AZURE_OPENAI_ENDPOINT or not AZURE_OPENAI_API_KEY:
    raise ValueError("AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_API_KEY are required")

# One shared async client for every embedding request in this run.
async_client = AsyncAzureOpenAI(
    azure_endpoint=AZURE_OPENAI_ENDPOINT,
    api_key=AZURE_OPENAI_API_KEY,
    api_version=AZURE_OPENAI_API_VERSION,
)

# Throttle: at most 100 requests per 60-second window, to stay safely
# below the service quota.
rate_limiter = aiolimiter.AsyncLimiter(100, 60)

# --- Embedding configuration --------------------------------------------
EMBEDDING_MODEL = "text-embedding-3-large"
EMBEDDING_DIMENSIONS = 1024
BATCH_SIZE = 64  # Number of texts sent per embeddings API call.
58+
59+
60+
@backoff.on_exception(
    backoff.expo,
    (RateLimitError, APITimeoutError),
    max_tries=5,
    max_time=300,
    jitter=backoff.random_jitter,
)
async def embeddings_async(
    input_text: list[str],
    model: str = EMBEDDING_MODEL,
    encoding_format: Literal["float", "base64"] = "float",
    dimensions: int = EMBEDDING_DIMENSIONS,
):
    """Embed one batch of texts via Azure OpenAI (async, rate-limited).

    Retries with exponential backoff (plus jitter) on rate-limit and
    timeout errors, up to 5 attempts or 300 seconds total.

    Args:
        input_text: Texts to embed in a single API call.
        model: Embedding model to use.
        encoding_format: Wire format for the returned vectors.
        dimensions: Requested embedding dimensionality.

    Returns:
        tuple: (list of embedding vectors, token-usage object)
    """
    # The limiter gates the API call; because backoff wraps the whole
    # coroutine, each retry re-acquires a rate-limit slot.
    async with rate_limiter:
        response = await async_client.embeddings.create(
            input=input_text,
            model=model,
            encoding_format=encoding_format,
            dimensions=dimensions,
        )
    vectors = [item.embedding for item in response.data]
    return vectors, response.usage
93+
94+
95+
async def get_embeddings(texts: list[str]) -> np.ndarray:
    """Embed all *texts*, splitting them into API-sized batches.

    Batches are issued concurrently (throughput bounded by the module
    rate limiter) and the vectors are concatenated back in input order.
    """
    chunks = [texts[start : start + BATCH_SIZE] for start in range(0, len(texts), BATCH_SIZE)]
    tasks = [
        embeddings_async(
            input_text=chunk,
            model=EMBEDDING_MODEL,
            encoding_format="float",
            dimensions=EMBEDDING_DIMENSIONS,
        )
        for chunk in chunks
    ]
    results = await tqdm_asyncio.gather(*tasks, desc="Getting embeddings")
    # Each result is (vectors, usage); token usage is not needed here.
    vectors = [vec for batch_vectors, _usage in results for vec in batch_vectors]
    return np.array(vectors)
112+
113+
114+
def prepare_text_for_embedding(row: dict) -> str:
    """Build the text to embed for a single report row.

    Joins labeled sections (title, symbol, subject terms, and the report
    body truncated to 6000 characters) with newlines. Sections whose
    source field is missing or empty are skipped entirely.
    """
    sections: list[str] = []

    title = row.get("proper_title")
    if title:
        sections.append(f"Title: {title}")

    symbol = row.get("symbol")
    if symbol:
        sections.append(f"Symbol: {symbol}")

    subjects = row.get("subject_terms")
    if subjects:
        sections.append(f"Subjects: {', '.join(subjects)}")

    # Truncate to ~6000 chars so the combined text stays within the
    # embedding model's token limits.
    body = row.get("text")
    if body:
        sections.append(f"Content: {body[:6000]}")

    return "\n".join(sections)
141+
142+
143+
def fetch_reports_without_embeddings(conn, limit: int | None = None) -> list[dict]:
    """Fetch reports that don't have embeddings yet.

    Args:
        conn: Open psycopg2 connection.
        limit: Optional cap on the number of rows returned.

    Returns:
        List of row dicts with keys: id, symbol, proper_title,
        subject_terms, text (ordered by id).
    """
    print("Fetching reports without embeddings...")

    # NOTE(review): DB_SCHEMA comes from the environment and is
    # interpolated into the statement; identifiers cannot be bind
    # parameters, so this value must be trusted.
    query = f"""
        SELECT id, symbol, proper_title, subject_terms, text
        FROM {DB_SCHEMA}.reports
        WHERE embedding IS NULL
          AND (proper_title IS NOT NULL OR text IS NOT NULL)
        ORDER BY id
    """
    params: tuple = ()
    if limit:
        # Bind the limit as a parameter instead of formatting it into
        # the SQL string.
        query += " LIMIT %s"
        params = (limit,)

    # Context manager guarantees the cursor is closed even if execute
    # raises (the original leaked the cursor on error).
    with conn.cursor() as cur:
        cur.execute(query, params)
        columns = [desc[0] for desc in cur.description]
        rows = [dict(zip(columns, row)) for row in cur.fetchall()]

    print(f" Found {len(rows)} reports without embeddings")
    return rows
166+
167+
168+
def update_embeddings(conn, updates: list[tuple[int, list[float]]]):
    """Write computed embeddings back to the reports table.

    Each update is an (id, vector) pair; vectors are serialized to the
    pgvector text literal form ("[x1,x2,...]") and cast server-side.
    Commits once after all chunks are sent.
    """
    print(f"Updating {len(updates)} embeddings in database...")

    cur = conn.cursor()

    # Send the VALUES list in chunks of 100 rows to keep each statement
    # small; everything is committed in a single transaction at the end.
    chunk = 100
    for start in range(0, len(updates), chunk):
        values = [
            (report_id, f"[{','.join(map(str, vector))}]")
            for report_id, vector in updates[start : start + chunk]
        ]
        execute_values(
            cur,
            f"""
            UPDATE {DB_SCHEMA}.reports AS r
            SET embedding = v.embedding::vector, updated_at = NOW()
            FROM (VALUES %s) AS v(id, embedding)
            WHERE r.id = v.id
            """,
            values,
            template="(%s, %s)"
        )

    conn.commit()
    cur.close()
    print(" Done updating embeddings")
193+
194+
195+
async def main(limit: int | None = None, batch_process_size: int = 500):
    """Embed every report that is still missing an embedding.

    Repeatedly fetches up to *batch_process_size* reports without
    embeddings, embeds them, and writes the vectors back, until no
    unembedded reports remain. When *limit* is given, at most *limit*
    reports are fetched and only one batch is processed (for testing).

    Args:
        limit: Optional cap on reports processed in a single test run.
        batch_process_size: Reports fetched per database round-trip.
    """
    print("=" * 60)
    print("Generating embeddings for reports")
    print("=" * 60)
    print(f"Model: {EMBEDDING_MODEL}")
    print(f"Dimensions: {EMBEDDING_DIMENSIONS}")
    print(f"Batch size: {BATCH_SIZE}")

    # Connect to database
    print("\nConnecting to database...")
    conn = psycopg2.connect(DATABASE_URL)

    try:
        while True:
            # Fetch the next slice of reports without embeddings. The
            # original fetched a full batch even when --limit was
            # smaller; honor the documented cap here.
            fetch_limit = min(limit, batch_process_size) if limit else batch_process_size
            reports = fetch_reports_without_embeddings(conn, limit=fetch_limit)

            if not reports:
                print("\nNo more reports to process!")
                break

            # Prepare texts
            print("\nPreparing texts for embedding...")
            texts = [prepare_text_for_embedding(r) for r in reports]
            ids = [r["id"] for r in reports]

            # Skip rows whose prepared text is empty/whitespace-only.
            valid_data = [(id_, text) for id_, text in zip(ids, texts) if text.strip()]
            if not valid_data:
                print("No valid texts to embed")
                break

            valid_ids, valid_texts = zip(*valid_data)
            print(f" {len(valid_texts)} texts ready for embedding")

            # Generate embeddings
            print("\nGenerating embeddings...")
            embeddings = await get_embeddings(list(valid_texts))

            # Pair each report id with its vector for the DB write.
            updates = list(zip(valid_ids, embeddings.tolist()))
            update_embeddings(conn, updates)

            # If limit was set, only process one batch
            if limit:
                break

            print(f"\nProcessed {len(updates)} reports, checking for more...")

        # Print final stats
        cur = conn.cursor()
        cur.execute(f"""
            SELECT
                COUNT(*) as total,
                COUNT(embedding) as with_embedding
            FROM {DB_SCHEMA}.reports
        """)
        stats = cur.fetchone()
        cur.close()

        print("\n" + "=" * 60)
        print("Final Stats:")
        print(f" Total reports: {stats[0]}")
        print(f" With embeddings: {stats[1]}")
        # Guard the ratio: an empty table previously raised
        # ZeroDivisionError here.
        coverage = 100 * stats[1] / stats[0] if stats[0] else 0.0
        print(f" Coverage: {coverage:.1f}%")
        print("=" * 60)

    finally:
        conn.close()
267+
268+
269+
if __name__ == "__main__":
    import argparse

    # CLI entry point: parse options, then drive the async pipeline.
    arg_parser = argparse.ArgumentParser(description="Generate embeddings for reports")
    arg_parser.add_argument(
        "--limit",
        type=int,
        help="Limit number of reports to process (for testing)",
    )
    arg_parser.add_argument(
        "--batch-size",
        type=int,
        default=500,
        help="Number of reports to process per database batch",
    )
    cli_args = arg_parser.parse_args()

    asyncio.run(main(limit=cli_args.limit, batch_process_size=cli_args.batch_size))

0 commit comments

Comments (0)