17 | 17 | from jinja2 import Template, StrictUndefined |
18 | 18 | from sklearn.cluster import KMeans |
19 | 19 | from sklearn.metrics import silhouette_score |
| 20 | +from sklearn.metrics.pairwise import cosine_similarity |
20 | 21 |
21 | 22 | from consts.const import LANGUAGE |
22 | 23 |
@@ -77,7 +78,15 @@ def get_documents_from_es(index_name: str, es_core, sample_doc_count: int = 200) |
77 | 78 | "query": { |
78 | 79 | "term": {"path_or_url": path_or_url} |
79 | 80 | }, |
80 | | - "size": chunk_count # Get all chunks |
| 81 | + "size": chunk_count, # Get all chunks |
| 82 | + "sort": [ |
| 83 | + { |
| 84 | + "create_time": { |
| 85 | + "order": "asc", |
| 86 | + "missing": "_last" # Put documents without create_time at the end |
| 87 | + } |
| 88 | + } |
| 89 | + ] |
81 | 90 | } |
82 | 91 |
83 | 92 | chunks_response = es_core.client.search(index=index_name, body=chunks_query) |
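Note on the sort clause added above: ordering by `create_time` keeps chunks in ingestion order, and `"missing": "_last"` tolerates legacy chunks that lack the field. Below is a minimal sketch of the same query run directly through an elasticsearch-py client; the client URL, index name, path value, and size are placeholders, not taken from this repo.

```python
# Sketch only: exercises the sorted chunk query in isolation.
# "docs_index", the path value, and size=100 are hypothetical placeholders.
from elasticsearch import Elasticsearch

es = Elasticsearch("http://localhost:9200")

chunks_query = {
    "query": {"term": {"path_or_url": "s3://bucket/report.pdf"}},
    "size": 100,
    "sort": [
        {"create_time": {"order": "asc", "missing": "_last"}}  # undated chunks sort last
    ],
}

resp = es.search(index="docs_index", body=chunks_query)
chunks = [hit["_source"] for hit in resp["hits"]["hits"]]
```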
@@ -124,10 +133,9 @@ def calculate_document_embedding(doc_chunks: List[Dict], use_weighted: bool = Tr |
124 | 133 | embeddings.append(np.array(chunk_embedding)) |
125 | 134 |
126 | 135 | if use_weighted: |
127 | | - # Weight by content length |
| 136 | + # Weight by content length only (removed position-based weight to reduce order dependency) |
128 | 137 | content_length = len(chunk.get('content', '')) |
129 | | - position_weight = 1.5 if len(embeddings) == 1 else 1.0 # First chunk has higher weight |
130 | | - weight = position_weight * content_length |
| 138 | + weight = content_length |
131 | 139 | weights.append(weight) |
132 | 140 |
133 | 141 | if not embeddings: |
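The weighting change above makes the document embedding a pure length-weighted mean of its chunk embeddings, so reordering the chunks no longer changes the result. A standalone illustration of that averaging with toy chunks (values are made up, not from the repo):

```python
# Illustrative only: length-weighted mean of chunk embeddings, mirroring the
# weighting change above. Chunk contents and vectors are toy data.
import numpy as np

chunks = [
    {"content": "short chunk", "embedding": [0.1, 0.9]},
    {"content": "a much longer chunk with considerably more content", "embedding": [0.8, 0.2]},
]

embeddings = np.array([c["embedding"] for c in chunks], dtype=float)
weights = np.array([len(c["content"]) for c in chunks], dtype=float)

# np.average normalizes the weights, so longer chunks pull the mean harder,
# independent of the order in which chunks arrive.
doc_embedding = np.average(embeddings, axis=0, weights=weights)
print(doc_embedding)
```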
@@ -217,6 +225,162 @@ def auto_determine_k(embeddings: np.ndarray, min_k: int = 3, max_k: int = 15) -> |
217 | 225 | return heuristic_k |
218 | 226 |
219 | 227 |
| 228 | +def merge_duplicate_documents_in_clusters(clusters: Dict[int, List[str]], doc_embeddings: Dict[str, np.ndarray], similarity_threshold: float = 0.98) -> Dict[int, List[str]]: |
| 229 | + """ |
| 230 | + Post-process clusters to merge duplicate documents (same content but different path_or_url) |
| 231 | + that were incorrectly split into different clusters. |
| 232 | + |
| 233 | + Args: |
| 234 | + clusters: Dictionary mapping cluster IDs to lists of document IDs |
| 235 | + doc_embeddings: Dictionary mapping document IDs to their embeddings |
| 236 | + similarity_threshold: Cosine similarity threshold to consider documents as duplicates (default: 0.98) |
| 237 | + |
| 238 | + Returns: |
| 239 | + Updated clusters dictionary with duplicate documents merged |
| 240 | + """ |
| 241 | + try: |
| 242 | + if not clusters or not doc_embeddings: |
| 243 | + return clusters |
| 244 | + |
| 245 | + # Skip merging if there's only one cluster (nothing to merge) |
| 246 | + if len(clusters) <= 1: |
| 247 | + return clusters |
| 248 | + |
| 249 | + # Build a mapping from doc_id to its current cluster |
| 250 | + doc_to_cluster = {} |
| 251 | + for cluster_id, doc_ids in clusters.items(): |
| 252 | + for doc_id in doc_ids: |
| 253 | + doc_to_cluster[doc_id] = cluster_id |
| 254 | + |
| 255 | + # Find duplicate pairs with high similarity |
| 256 | + doc_ids_list = list(doc_embeddings.keys()) |
| 257 | + merged_pairs = [] |
| 258 | + |
| 259 | + for i, doc_id1 in enumerate(doc_ids_list): |
| 260 | + if doc_id1 not in doc_embeddings: |
| 261 | + continue |
| 262 | + |
| 263 | + embedding1 = doc_embeddings[doc_id1] |
| 264 | + |
| 265 | + for j, doc_id2 in enumerate(doc_ids_list[i+1:], start=i+1): |
| 266 | + if doc_id2 not in doc_embeddings: |
| 267 | + continue |
| 268 | + |
| 269 | + embedding2 = doc_embeddings[doc_id2] |
| 270 | + |
| 271 | + # Calculate cosine similarity |
| 272 | + similarity = cosine_similarity( |
| 273 | + embedding1.reshape(1, -1), |
| 274 | + embedding2.reshape(1, -1) |
| 275 | + )[0][0] |
| 276 | + |
| 277 | + # If similarity is very high, they are likely duplicates |
| 278 | + if similarity >= similarity_threshold: |
| 279 | + cluster1 = doc_to_cluster.get(doc_id1) |
| 280 | + cluster2 = doc_to_cluster.get(doc_id2) |
| 281 | + |
| 282 | + # Only merge if they are in different clusters AND truly duplicates |
| 283 | + # Check both cosine similarity AND Euclidean distance to prevent false positives |
| 284 | + if cluster1 is not None and cluster2 is not None and cluster1 != cluster2: |
| 285 | + # Calculate Euclidean distance to ensure they're truly duplicates |
| 286 | + # Documents that are just in the same direction but far apart should not be merged |
| 287 | + euclidean_distance = np.linalg.norm(embedding1 - embedding2) |
| 288 | + |
| 289 | + # Normalize embeddings to get their magnitudes |
| 290 | + norm1 = np.linalg.norm(embedding1) |
| 291 | + norm2 = np.linalg.norm(embedding2) |
| 292 | + avg_norm = (norm1 + norm2) / 2.0 |
| 293 | + |
| 294 | + # Relative distance threshold: if distance is less than 1% of average magnitude, |
| 295 | + # they are likely true duplicates (same content, different path_or_url) |
| 296 | + # This prevents merging documents that are just in similar directions |
| 297 | + relative_distance_threshold = 0.01 * avg_norm if avg_norm > 0 else 0.1 |
| 298 | + |
| 299 | + if euclidean_distance <= relative_distance_threshold: |
| 300 | + merged_pairs.append((doc_id1, doc_id2, cluster1, cluster2, similarity)) |
| 301 | + logger.info(f"Found duplicate documents: {doc_id1} and {doc_id2} (similarity: {similarity:.4f}, distance: {euclidean_distance:.4f}) in different clusters {cluster1} and {cluster2}") |
| 302 | + |
| 303 | + # Merge duplicate documents into the same cluster |
| 304 | + if merged_pairs: |
| 305 | + logger.info(f"Merging {len(merged_pairs)} pairs of duplicate documents") |
| 306 | + |
| 307 | + # Build a graph of duplicate relationships using union-find |
| 308 | + parent = {} |
| 309 | + |
| 310 | + def find(x): |
| 311 | + if x not in parent: |
| 312 | + parent[x] = x |
| 313 | + if parent[x] != x: |
| 314 | + parent[x] = find(parent[x]) |
| 315 | + return parent[x] |
| 316 | + |
| 317 | + def union(x, y): |
| 318 | + px, py = find(x), find(y) |
| 319 | + if px != py: |
| 320 | + parent[px] = py |
| 321 | + |
| 322 | + # Build union-find structure |
| 323 | + for doc_id1, doc_id2, _, _, _ in merged_pairs: |
| 324 | + union(doc_id1, doc_id2) |
| 325 | + |
| 326 | + # Group documents by their root parent |
| 327 | + # Only include documents that are part of duplicate pairs |
| 328 | + duplicate_doc_ids = set() |
| 329 | + for doc_id1, doc_id2, _, _, _ in merged_pairs: |
| 330 | + duplicate_doc_ids.add(doc_id1) |
| 331 | + duplicate_doc_ids.add(doc_id2) |
| 332 | + |
| 333 | + groups = {} |
| 334 | + for doc_id in duplicate_doc_ids: |
| 335 | + root = find(doc_id) |
| 336 | + if root not in groups: |
| 337 | + groups[root] = [] |
| 338 | + groups[root].append(doc_id) |
| 339 | + |
| 340 | + # Merge each group into the same cluster |
| 341 | + for root, doc_group in groups.items(): |
| 342 | + if len(doc_group) < 2: |
| 343 | + continue |
| 344 | + |
| 345 | + # Find all clusters containing documents in this group |
| 346 | + clusters_in_group = set() |
| 347 | + for doc_id in doc_group: |
| 348 | + if doc_id in doc_to_cluster: |
| 349 | + clusters_in_group.add(doc_to_cluster[doc_id]) |
| 350 | + |
| 351 | + if len(clusters_in_group) > 1: |
| 352 | + # Merge all documents to the smallest cluster ID |
| 353 | + target_cluster = min(clusters_in_group) |
| 354 | + |
| 355 | + for doc_id in doc_group: |
| 356 | + current_cluster = doc_to_cluster.get(doc_id) |
| 357 | + if current_cluster is not None and current_cluster != target_cluster: |
| 358 | + # Move document to target cluster |
| 359 | + if current_cluster in clusters and doc_id in clusters[current_cluster]: |
| 360 | + clusters[current_cluster].remove(doc_id) |
| 361 | + if target_cluster not in clusters: |
| 362 | + clusters[target_cluster] = [] |
| 363 | + if doc_id not in clusters[target_cluster]: |
| 364 | + clusters[target_cluster].append(doc_id) |
| 365 | + doc_to_cluster[doc_id] = target_cluster |
| 366 | + logger.debug(f"Moved {doc_id} from cluster {current_cluster} to cluster {target_cluster}") |
| 367 | + |
| 368 | + # Remove empty clusters |
| 369 | + empty_clusters = [cid for cid, docs in clusters.items() if not docs] |
| 370 | + for cid in empty_clusters: |
| 371 | + del clusters[cid] |
| 372 | + logger.debug(f"Removed empty cluster {cid}") |
| 373 | + |
| 374 | + logger.info(f"Successfully merged duplicate documents. Final cluster count: {len(clusters)}") |
| 375 | + |
| 376 | + return clusters |
| 377 | + |
| 378 | + except Exception as e: |
| 379 | + logger.error(f"Error merging duplicate documents: {str(e)}", exc_info=True) |
| 380 | + # Return original clusters if merge fails |
| 381 | + return clusters |
| 382 | + |
| 383 | + |
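The duplicate test above requires both near-identical direction (cosine similarity at or above the threshold) and near-identical magnitude (Euclidean distance within roughly 1% of the average vector norm). A self-contained sketch of just that test with toy vectors; `looks_like_duplicate` is a hypothetical helper, not part of this diff:

```python
# Sketch of the two-part duplicate check: same direction AND same scale.
# Thresholds mirror the defaults used above; vectors are toy values.
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

def looks_like_duplicate(e1: np.ndarray, e2: np.ndarray,
                         sim_threshold: float = 0.98,
                         rel_dist_ratio: float = 0.01) -> bool:
    sim = cosine_similarity(e1.reshape(1, -1), e2.reshape(1, -1))[0][0]
    if sim < sim_threshold:
        return False
    dist = np.linalg.norm(e1 - e2)
    avg_norm = (np.linalg.norm(e1) + np.linalg.norm(e2)) / 2.0
    return dist <= (rel_dist_ratio * avg_norm if avg_norm > 0 else 0.1)

a = np.array([1.0, 2.0, 3.0])
print(looks_like_duplicate(a, a * 1.001))  # True: same direction, same scale
print(looks_like_duplicate(a, a * 3.0))    # False: same direction, different scale
```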
220 | 384 | def kmeans_cluster_documents(doc_embeddings: Dict[str, np.ndarray], k: Optional[int] = None) -> Dict[int, List[str]]: |
221 | 385 | """ |
222 | 386 | Cluster documents using K-means |
@@ -266,6 +430,13 @@ def kmeans_cluster_documents(doc_embeddings: Dict[str, np.ndarray], k: Optional[ |
266 | 430 | for cluster_id, docs in clusters.items(): |
267 | 431 | logger.info(f"Cluster {cluster_id}: {len(docs)} documents") |
268 | 432 |
| 433 | + # Post-process: merge duplicate documents that were split into different clusters |
| 434 | + clusters = merge_duplicate_documents_in_clusters(clusters, doc_embeddings, similarity_threshold=0.98) |
| 435 | + |
| 436 | + # Log final cluster sizes after merge |
| 437 | + for cluster_id, docs in clusters.items(): |
| 438 | + logger.info(f"Final cluster {cluster_id}: {len(docs)} documents") |
| 439 | + |
269 | 440 | return clusters |
270 | 441 |
271 | 442 | except Exception as e: |
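A hedged usage sketch of the new post-processing step, assuming `merge_duplicate_documents_in_clusters` from this diff is in scope and the module's logger is configured; file names, vectors, and cluster assignments are toy data:

```python
# Two near-identical documents placed in different clusters collapse into one
# cluster after post-processing. Expected result: {0: ["a.pdf", "b.pdf", "a_copy.pdf"]}.
import numpy as np

doc_embeddings = {
    "a.pdf":      np.array([1.0, 0.0, 0.0]),
    "a_copy.pdf": np.array([1.0, 0.0, 0.0001]),  # same content, different path_or_url
    "b.pdf":      np.array([0.0, 1.0, 0.0]),
}
clusters = {0: ["a.pdf", "b.pdf"], 1: ["a_copy.pdf"]}

merged = merge_duplicate_documents_in_clusters(
    clusters, doc_embeddings, similarity_threshold=0.98
)
print(merged)  # "a_copy.pdf" should join cluster 0; empty cluster 1 is removed
```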