-
Notifications
You must be signed in to change notification settings - Fork 25.6k
Fix iterating for best centroid when algorithm is neighbour aware and decrease SAMPLES_PER_CLUSTER_DEFAULT #130069
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
5c8d5a3
392c278
c72bf96
c7cb787
fcce59f
ceed8b4
da0075b
004224d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -87,17 +87,17 @@ private boolean stepLloyd( | |
|
|
||
| for (int i = 0; i < sampleSize; i++) { | ||
| float[] vector = vectors.vectorValue(i); | ||
| int[] neighborOffsets = null; | ||
| int centroidIdx = -1; | ||
| final int assignment = assignments[i]; | ||
| final int bestCentroidOffset; | ||
| if (neighborhoods != null) { | ||
| neighborOffsets = neighborhoods.get(assignments[i]); | ||
| centroidIdx = assignments[i]; | ||
| bestCentroidOffset = getBestCentroidFromNeighbours(centroids, vector, assignment, neighborhoods.get(assignment)); | ||
| } else { | ||
| bestCentroidOffset = getBestCentroid(centroids, vector); | ||
| } | ||
| int bestCentroidOffset = getBestCentroidOffset(centroids, vector, centroidIdx, neighborOffsets); | ||
| if (assignments[i] != bestCentroidOffset) { | ||
| if (assignment != bestCentroidOffset) { | ||
| assignments[i] = bestCentroidOffset; | ||
| changed = true; | ||
| } | ||
| assignments[i] = bestCentroidOffset; | ||
| centroidCounts[bestCentroidOffset]++; | ||
| for (int d = 0; d < dim; d++) { | ||
| nextCentroids[bestCentroidOffset][d] += vector[d]; | ||
|
|
@@ -116,23 +116,28 @@ private boolean stepLloyd( | |
| return changed; | ||
| } | ||
|
|
||
| int getBestCentroidOffset(float[][] centroids, float[] vector, int centroidIdx, int[] centroidOffsets) { | ||
| int getBestCentroidFromNeighbours(float[][] centroids, float[] vector, int centroidIdx, int[] centroidOffsets) { | ||
| int bestCentroidOffset = centroidIdx; | ||
| float minDsq; | ||
| if (centroidIdx > 0 && centroidIdx < centroids.length) { | ||
| minDsq = VectorUtil.squareDistance(vector, centroids[centroidIdx]); | ||
| } else { | ||
| minDsq = Float.MAX_VALUE; | ||
| assert centroidIdx >= 0 && centroidIdx < centroids.length; | ||
| float minDsq = VectorUtil.squareDistance(vector, centroids[centroidIdx]); | ||
| for (int offset : centroidOffsets) { | ||
| float dsq = VectorUtil.squareDistance(vector, centroids[offset]); | ||
| if (dsq < minDsq) { | ||
| minDsq = dsq; | ||
| bestCentroidOffset = offset; | ||
| } | ||
| } | ||
| return bestCentroidOffset; | ||
| } | ||
|
|
||
| int k = 0; | ||
| for (int j = 0; j < centroids.length; j++) { | ||
| if (centroidOffsets == null || j == centroidOffsets[k]) { | ||
| float dsq = VectorUtil.squareDistance(vector, centroids[j]); | ||
| if (dsq < minDsq) { | ||
| minDsq = dsq; | ||
| bestCentroidOffset = j; | ||
| } | ||
| int getBestCentroid(float[][] centroids, float[] vector) { | ||
| int bestCentroidOffset = 0; | ||
| float minDsq = Float.MAX_VALUE; | ||
| for (int i = 0; i < centroids.length; i++) { | ||
| float dsq = VectorUtil.squareDistance(vector, centroids[i]); | ||
| if (dsq < minDsq) { | ||
| minDsq = dsq; | ||
| bestCentroidOffset = i; | ||
| } | ||
| } | ||
| return bestCentroidOffset; | ||
|
|
@@ -271,7 +276,8 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, L | |
| return; | ||
| } | ||
|
|
||
| int[] assignments = new int[n]; | ||
| int[] assignments = kMeansIntermediate.assignments(); | ||
| assert assignments.length == n; | ||
| float[][] nextCentroids = new float[centroids.length][vectors.dimension()]; | ||
| for (int i = 0; i < maxIterations; i++) { | ||
| if (stepLloyd(vectors, centroids, nextCentroids, assignments, sampleSize, neighborhoods) == false) { | ||
|
|
@@ -291,7 +297,7 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, L | |
| * @param maxIterations the max iterations to shift centroids | ||
| */ | ||
| public static void cluster(FloatVectorValues vectors, float[][] centroids, int sampleSize, int maxIterations) throws IOException { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this is only used for tests and kinda silly now you can just get rid of this or I can clean it up in a subsequent PR
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's clean up in a follow up PR |
||
| KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids); | ||
| KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, new int[vectors.size()], vectors::ordToDoc); | ||
| KMeansLocal kMeans = new KMeansLocal(sampleSize, maxIterations); | ||
| kMeans.cluster(vectors, kMeansIntermediate); | ||
| } | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
might be worth including some of the runs you were doing in the PR comments just so we can look back at them if we need to to confirm recall wasn't hurt by doing this
I'll run a couple runs myself here real quick too to double check with a different model
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ran this whole PR and just the sampling change only on glove 200d 1m, 3m and 10m and for both saw no major drops in recall