|
/**
 * @file ClusteringStrategy.ts
 * @description Agglomerative clustering to merge drifted speaker centroids.
 *
 * Over the course of a long session, a single speaker's vocal characteristics
 * may drift enough that two separate centroids are created for them.
 * {@link ClusteringStrategy.mergeClusters} detects this by computing pairwise
 * cosine similarity between all centroids and iteratively merging the closest
 * pair until either no pair exceeds the merge threshold or the centroid count
 * equals `expectedSpeakers`.
 *
 * @module diarization/ClusteringStrategy
 */
| 14 | + |
| 15 | +import { cosineSimilarity } from './SpeakerEmbeddingCache.js'; |
| 16 | + |
// ---------------------------------------------------------------------------
// Main class
// ---------------------------------------------------------------------------
| 20 | + |
| 21 | +/** |
| 22 | + * Agglomerative speaker-centroid merging strategy. |
| 23 | + * |
| 24 | + * This is an optional post-processing step applied by {@link DiarizationSession} |
| 25 | + * when the number of tracked centroids exceeds the expected speaker count. |
| 26 | + * |
| 27 | + * @example |
| 28 | + * ```ts |
| 29 | + * const strategy = new ClusteringStrategy(0.85); |
| 30 | + * const mapping = strategy.mergeClusters(cache.centroids, 2); |
| 31 | + * // mapping: Map<oldId, canonicalId> |
| 32 | + * ``` |
| 33 | + */ |
| 34 | +export class ClusteringStrategy { |
| 35 | + // ------------------------------------------------------------------------- |
| 36 | + // Constructor |
| 37 | + // ------------------------------------------------------------------------- |
| 38 | + |
| 39 | + /** |
| 40 | + * @param mergeThreshold - Minimum cosine similarity between two centroids |
| 41 | + * for them to be considered the same speaker and merged. |
| 42 | + * @defaultValue 0.85 |
| 43 | + */ |
| 44 | + constructor(private readonly mergeThreshold: number = 0.85) {} |
| 45 | + |
| 46 | + // ------------------------------------------------------------------------- |
| 47 | + // Public API |
| 48 | + // ------------------------------------------------------------------------- |
| 49 | + |
| 50 | + /** |
| 51 | + * Identify centroid pairs that should be merged and return a renaming map. |
| 52 | + * |
| 53 | + * The algorithm: |
| 54 | + * 1. Compute all pairwise cosine similarities. |
| 55 | + * 2. If `expectedSpeakers` is set and the current count exceeds it, merge |
| 56 | + * the closest pair regardless of the threshold. |
| 57 | + * 3. Otherwise merge pairs that exceed `mergeThreshold`. |
| 58 | + * 4. Repeat until no further merges are possible or the count matches |
| 59 | + * `expectedSpeakers`. |
| 60 | + * |
| 61 | + * The returned `Map<string, string>` maps every old centroid ID that was |
| 62 | + * subsumed into a canonical ID. IDs that were not merged are not present in |
| 63 | + * the map. Callers should rename all occurrences of a key to its value. |
| 64 | + * |
| 65 | + * @param centroids - Current centroid snapshot (id → embedding). |
| 66 | + * @param expectedSpeakers - Optional upper bound on speaker count. |
| 67 | + * @returns Rename map: `oldId → canonicalId`. |
| 68 | + */ |
| 69 | + mergeClusters( |
| 70 | + centroids: Map<string, Float32Array>, |
| 71 | + expectedSpeakers?: number, |
| 72 | + ): Map<string, string> { |
| 73 | + // Build a mutable working copy so we can iteratively merge. |
| 74 | + const working = new Map<string, Float32Array>(centroids); |
| 75 | + // Accumulated rename mapping. |
| 76 | + const renameMap = new Map<string, string>(); |
| 77 | + |
| 78 | + while (true) { |
| 79 | + const ids = Array.from(working.keys()); |
| 80 | + const count = ids.length; |
| 81 | + |
| 82 | + // Nothing to merge. |
| 83 | + if (count < 2) break; |
| 84 | + |
| 85 | + // Find the closest pair. |
| 86 | + let bestSim = -Infinity; |
| 87 | + let bestI = 0; |
| 88 | + let bestJ = 1; |
| 89 | + |
| 90 | + for (let i = 0; i < count; i++) { |
| 91 | + for (let j = i + 1; j < count; j++) { |
| 92 | + const sim = cosineSimilarity(working.get(ids[i]!)!, working.get(ids[j]!)!); |
| 93 | + if (sim > bestSim) { |
| 94 | + bestSim = sim; |
| 95 | + bestI = i; |
| 96 | + bestJ = j; |
| 97 | + } |
| 98 | + } |
| 99 | + } |
| 100 | + |
| 101 | + // Decide whether to merge. |
| 102 | + const shouldMergeDueToThreshold = bestSim >= this.mergeThreshold; |
| 103 | + const shouldMergeDueToCount = |
| 104 | + expectedSpeakers !== undefined && count > expectedSpeakers; |
| 105 | + |
| 106 | + if (!shouldMergeDueToThreshold && !shouldMergeDueToCount) break; |
| 107 | + |
| 108 | + // Merge ids[bestJ] into ids[bestI] (keep the lexicographically earlier |
| 109 | + // ID as the canonical one for determinism). |
| 110 | + const keepId = ids[bestI]!; |
| 111 | + const dropId = ids[bestJ]!; |
| 112 | + |
| 113 | + // Average the two centroids (equal weight — simple heuristic). |
| 114 | + const keepEmb = working.get(keepId)!; |
| 115 | + const dropEmb = working.get(dropId)!; |
| 116 | + const merged = new Float32Array(keepEmb.length); |
| 117 | + for (let k = 0; k < keepEmb.length; k++) { |
| 118 | + merged[k] = (keepEmb[k]! + dropEmb[k]!) / 2; |
| 119 | + } |
| 120 | + |
| 121 | + working.set(keepId, merged); |
| 122 | + working.delete(dropId); |
| 123 | + |
| 124 | + // Record the rename. Chase any existing mappings so the final map is |
| 125 | + // transitively resolved. |
| 126 | + renameMap.set(dropId, keepId); |
| 127 | + |
| 128 | + // Resolve transitive renames: if dropId was itself a canonical target of |
| 129 | + // an earlier merge, update those entries to point to keepId. |
| 130 | + for (const [old, target] of renameMap) { |
| 131 | + if (target === dropId) { |
| 132 | + renameMap.set(old, keepId); |
| 133 | + } |
| 134 | + } |
| 135 | + } |
| 136 | + |
| 137 | + return renameMap; |
| 138 | + } |
| 139 | +} |
0 commit comments