Skip to content

Commit edf7b88

Browse files
frabonifaceMNIKIEMAEdouardCallet123
authored
Merge clustering improvements code (#49)
* add datasets for loading from HF * add datasets for loading from HF * add KMeans clustering * Effort to make TFIDF works * use clustering config in bertopic * reorg import * additions_clustering test ed * add clustering helper functions * add label in result * add text field * cleaned clustering scipts, added cluster naming and join of results and input * fix deps * fix rebase quirks --------- Co-authored-by: MNIKIEMA <nikiemamahamadi01@gmail.com> Co-authored-by: Edouard Callet <callet.edouard@gmail.com>
1 parent 30a6621 commit edf7b88

17 files changed

+309844
-1405
lines changed

policy_analysis/cluster.png

163 KB
Loading

policy_analysis/old/notebooks/BERTopic.ipynb

Lines changed: 89 additions & 443 deletions
Large diffs are not rendered by default.
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import numpy as np
from datasets import load_dataset

# --- Configuration ---
# Hugging Face dataset id holding the policy documents (the "train" split is used).
HF_DATASET_ID = "EdouardCallet/wsl-policy-10k"
# Path to the .npy array of cluster labels produced by the clustering experiment.
LABELS_PATH = 'src/policy_analysis/policies_clustering/results/clustering_experiment/labels.npy'
# Local CSV the merged (documents + labels) table is written to.
OUTPUT_FILENAME = "wsl_policies_with_clusters.csv"
8+
9+
def main():
    """Attach cluster labels to the HF policy dataset and save the result as CSV.

    Loads the configured dataset split and the .npy label vector, checks that
    the two have the same length (a 1:1 row/label correspondence is required),
    then merges and writes OUTPUT_FILENAME locally.
    """
    print(f"Loading dataset: {HF_DATASET_ID}...")
    dataset = load_dataset(HF_DATASET_ID, split="train")

    print(f"Loading labels from: {LABELS_PATH}...")
    try:
        cluster_labels = np.load(LABELS_PATH, allow_pickle=True)
    except Exception as e:
        print(f"❌ Error loading .npy file: {e}")
        return

    # --- Verification: each document needs exactly one label ---
    print("\n--- Checking Dimensions ---")
    doc_count, label_count = len(dataset), len(cluster_labels)
    print(f"Dataset rows: {doc_count}")
    print(f"Labels count: {label_count}")

    if doc_count != label_count:
        print("\n⚠️ MISMATCH DETECTED ⚠️")
        print(f"You have {doc_count} documents but {label_count} labels.")
        print("Reason: Your clustering likely ran on 'chunks' (sentences) rather than full documents.")
        print("To fix this, we need to map the chunks back to the parent documents.")
        return

    # --- Merging: add the label vector as a new column ---
    print("\nMerging data...")
    merged = dataset.to_pandas()
    merged['cluster_label'] = cluster_labels

    # --- Inspecting Result ---
    print("\nTop 5 rows with new labels:")
    print(merged[['cluster_label']].head(5))

    # --- Saving ---
    print(f"\nSaving to {OUTPUT_FILENAME}...")
    merged.to_csv(OUTPUT_FILENAME, index=False)
    print("✅ Success! File saved locally.")

    # Optional: Push to Hub
    # print("Pushing to Hub...")
    # new_dataset = Dataset.from_pandas(df)
    # new_dataset.push_to_hub("EdouardCallet/wsl_10k_policy_and_taxonomy_clustered")
53+
# --- Script entry point ---
if __name__ == "__main__":
    print("Script started...")
    main()
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import pandas as pd
import openai
from tqdm import tqdm
import os
from dotenv import load_dotenv

load_dotenv()  # Load environment variables (expects OPENAI_API_KEY in .env or the env)

# --- Configuration ---
# Output of the linking script; must contain a 'cluster_label' column.
INPUT_CSV = "wsl_policies_with_clusters.csv"
# Where the table with the added 'cluster_name' column is written.
OUTPUT_CSV = "wsl_policies_clustered_and_named.csv"
# Chat model used to generate cluster names.
MODEL_NAME = "gpt-4o-mini"
# Max number of text excerpts sent to the LLM per cluster.
SAMPLES_PER_CLUSTER = 40

# NOTE: client is created at import time (module-level side effect); an
# unset OPENAI_API_KEY only surfaces later, when the first request is made.
client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
16+
17+
def get_cluster_name(cluster_id, texts):
    """Ask the LLM for a short, professional category name for one cluster.

    Parameters
    ----------
    cluster_id : hashable
        Cluster identifier; used in the prompt and in the fallback name.
    texts : list[str]
        Sample of policy excerpts that belong to this cluster.

    Returns
    -------
    str
        The model-proposed name, or ``"Cluster {cluster_id}"`` when ``texts``
        is empty or the API call fails (best-effort — never raises).
    """
    # Guard: nothing to summarize, so skip the API call entirely.
    if not texts:
        return f"Cluster {cluster_id}"

    # Render every excerpt as a "- " bullet. (The previous join missed the
    # bullet on the first item, yielding an inconsistent list in the prompt.)
    text_preview = "- " + "\n- ".join(texts)

    prompt = (
        f"You are a policy analyst. Below is a list of policy excerpts that belong to the same cluster (Cluster ID: {cluster_id}).\n\n"
        f"POLICIES:\n{text_preview}\n\n"
        f"TASK: Provide a short, specific, and professional name for this cluster (max 5-7 words). "
        f"Do not use quotes. Just the name."
    )

    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": "You are a helpful taxonomist for public policy."},
                {"role": "user", "content": prompt}
            ],
            temperature=0.0,  # deterministic names across reruns
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        # Best-effort: log and fall back so one failure doesn't abort the run.
        print(f"Error naming cluster {cluster_id}: {e}")
        return f"Cluster {cluster_id}"
41+
42+
def main():
    """Generate an LLM name for every cluster in INPUT_CSV and save OUTPUT_CSV.

    Reads the linked CSV, samples up to SAMPLES_PER_CLUSTER texts per cluster,
    asks the LLM for a name via get_cluster_name, maps the names back onto the
    full table as a 'cluster_name' column, and writes the result.
    """
    if not os.path.exists(INPUT_CSV):
        print(f"❌ Could not find {INPUT_CSV}. Run the linking script first!")
        return

    print(f"Loading {INPUT_CSV}...")
    df = pd.read_csv(INPUT_CSV)

    # The linking script must have produced this column; fail with a clear
    # message instead of an uncaught KeyError further down.
    if 'cluster_label' not in df.columns:
        print("❌ Column 'cluster_label' not found in the input CSV. Run the linking script first!")
        return

    # Pick the column holding the policy text; fall back to the first
    # string-typed column when the expected one is absent.
    text_col = 'single_policy_item'
    if text_col not in df.columns:
        possible_cols = [c for c in df.columns if df[c].dtype == 'object']
        if not possible_cols:
            # Previously this crashed with IndexError when no text-like column existed.
            print(f"❌ Column '{text_col}' not found and no text-like column is available.")
            return
        print(f"⚠️ Column '{text_col}' not found. Using '{possible_cols[0]}' instead.")
        text_col = possible_cols[0]

    unique_clusters = df['cluster_label'].unique()
    unique_clusters = sorted([c for c in unique_clusters if pd.notna(c)])

    print(f"Found {len(unique_clusters)} unique clusters.")
    print("Generating names (this may take a moment)...")

    cluster_names_map = {}

    for cluster_id in tqdm(unique_clusters, desc="Naming Clusters"):
        # Filter for current cluster
        cluster_data = df[df['cluster_label'] == cluster_id]

        # Get valid texts only
        valid_texts = cluster_data[text_col].dropna()

        # Skip if no text available
        if len(valid_texts) == 0:
            cluster_names_map[cluster_id] = f"Cluster {cluster_id}"
            continue

        # Calculate safe sample size based on VALID texts count
        n_samples = min(SAMPLES_PER_CLUSTER, len(valid_texts))

        # Fixed seed so reruns sample the same excerpts per cluster.
        sample_texts = valid_texts.sample(n=n_samples, random_state=42).tolist()

        # Call LLM
        name = get_cluster_name(cluster_id, sample_texts)
        cluster_names_map[cluster_id] = name

    print("\nApplying names to dataset...")
    df['cluster_name'] = df['cluster_label'].map(cluster_names_map)

    print("\n--- Sample of Generated Names ---")
    print(df[['cluster_label', 'cluster_name']].drop_duplicates().head(10))

    print(f"\nSaving to {OUTPUT_CSV}...")
    df.to_csv(OUTPUT_CSV, index=False)
    print("✅ Done!")
95+
96+
# Script entry point.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)