Skip to content

Commit a925b54

Browse files
committed
addressing comments
1 parent 28508be commit a925b54

File tree

3 files changed

+135
-3
lines changed

3 files changed

+135
-3
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,11 @@ public int size() {
8080

8181
@Override
8282
public float[] centroid(int centroidOrdinal) throws IOException {
83-
readQuantizedCentroid(centroidOrdinal);
83+
readQuantizedAndRawCentroid(centroidOrdinal);
8484
return centroid;
8585
}
8686

87-
private void readQuantizedCentroid(int centroidOrdinal) throws IOException {
87+
private void readQuantizedAndRawCentroid(int centroidOrdinal) throws IOException {
8888
if (centroidOrdinal == currentCentroid) {
8989
return;
9090
}
@@ -97,7 +97,7 @@ private void readQuantizedCentroid(int centroidOrdinal) throws IOException {
9797

9898
@Override
9999
public float score(int centroidOrdinal) throws IOException {
100-
readQuantizedCentroid(centroidOrdinal);
100+
readQuantizedAndRawCentroid(centroidOrdinal);
101101
return int4QuantizedScore(
102102
quantized,
103103
queryParams,

server/src/main/java/org/elasticsearch/index/codec/vectors/NeighborQueue.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,19 @@ private long encode(int node, float score) {
9595
return order.apply((((long) NumericUtils.floatToSortableInt(score)) << 32) | (0xFFFFFFFFL & ~node));
9696
}
9797

98+
/** Returns the top element's node id. */
99+
int topNode() {
100+
return decodeNodeId(heap.top());
101+
}
102+
103+
/**
104+
* Returns the top element's node score. For the min heap this is the minimum score. For the max
105+
* heap this is the maximum score.
106+
*/
107+
float topScore() {
108+
return decodeScore(heap.top());
109+
}
110+
98111
private float decodeScore(long heapValue) {
99112
return NumericUtils.sortableIntToFloat((int) (order.apply(heapValue) >> 32));
100113
}
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
/*
2+
* @notice
3+
* Licensed to the Apache Software Foundation (ASF) under one or more
4+
* contributor license agreements. See the NOTICE file distributed with
5+
* this work for additional information regarding copyright ownership.
6+
* The ASF licenses this file to You under the Apache License, Version 2.0
7+
* (the "License"); you may not use this file except in compliance with
8+
* the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
* Modifications copyright (C) 2025 Elasticsearch B.V.
18+
*/
19+
20+
package org.elasticsearch.index.codec.vectors;
21+
22+
import org.elasticsearch.test.ESTestCase;
23+
24+
/**
25+
* copied and modified from Lucene
26+
*/
27+
public class NeighborQueueTests extends ESTestCase {
28+
public void testNeighborsProduct() {
29+
// make sure we have the sign correct
30+
NeighborQueue nn = new NeighborQueue(2, false);
31+
assertTrue(nn.insertWithOverflow(2, 0.5f));
32+
assertTrue(nn.insertWithOverflow(1, 0.2f));
33+
assertTrue(nn.insertWithOverflow(3, 1f));
34+
assertEquals(0.5f, nn.topScore(), 0);
35+
nn.pop();
36+
assertEquals(1f, nn.topScore(), 0);
37+
nn.pop();
38+
}
39+
40+
public void testNeighborsMaxHeap() {
41+
NeighborQueue nn = new NeighborQueue(2, true);
42+
assertTrue(nn.insertWithOverflow(2, 2));
43+
assertTrue(nn.insertWithOverflow(1, 1));
44+
assertFalse(nn.insertWithOverflow(3, 3));
45+
assertEquals(2f, nn.topScore(), 0);
46+
nn.pop();
47+
assertEquals(1f, nn.topScore(), 0);
48+
}
49+
50+
public void testTopMaxHeap() {
51+
NeighborQueue nn = new NeighborQueue(2, true);
52+
nn.add(1, 2);
53+
nn.add(2, 1);
54+
// lower scores are better; highest score on top
55+
assertEquals(2, nn.topScore(), 0);
56+
assertEquals(1, nn.topNode());
57+
}
58+
59+
public void testTopMinHeap() {
60+
NeighborQueue nn = new NeighborQueue(2, false);
61+
nn.add(1, 0.5f);
62+
nn.add(2, -0.5f);
63+
// higher scores are better; lowest score on top
64+
assertEquals(-0.5f, nn.topScore(), 0);
65+
assertEquals(2, nn.topNode());
66+
}
67+
68+
public void testClear() {
69+
NeighborQueue nn = new NeighborQueue(2, false);
70+
nn.add(1, 1.1f);
71+
nn.add(2, -2.2f);
72+
nn.clear();
73+
74+
assertEquals(0, nn.size());
75+
}
76+
77+
public void testMaxSizeQueue() {
78+
NeighborQueue nn = new NeighborQueue(2, false);
79+
nn.add(1, 1);
80+
nn.add(2, 2);
81+
assertEquals(2, nn.size());
82+
assertEquals(1, nn.topNode());
83+
84+
// insertWithOverflow does not extend the queue
85+
nn.insertWithOverflow(3, 3);
86+
assertEquals(2, nn.size());
87+
assertEquals(2, nn.topNode());
88+
89+
// add does extend the queue beyond maxSize
90+
nn.add(4, 1);
91+
assertEquals(3, nn.size());
92+
}
93+
94+
public void testUnboundedQueue() {
95+
NeighborQueue nn = new NeighborQueue(1, true);
96+
float maxScore = -2;
97+
int maxNode = -1;
98+
for (int i = 0; i < 256; i++) {
99+
// initial size is 32
100+
float score = random().nextFloat();
101+
if (score > maxScore) {
102+
maxScore = score;
103+
maxNode = i;
104+
}
105+
nn.add(i, score);
106+
}
107+
assertEquals(maxScore, nn.topScore(), 0);
108+
assertEquals(maxNode, nn.topNode());
109+
}
110+
111+
public void testInvalidArguments() {
112+
expectThrows(IllegalArgumentException.class, () -> new NeighborQueue(0, false));
113+
}
114+
115+
public void testToString() {
116+
assertEquals("Neighbors[0]", new NeighborQueue(2, false).toString());
117+
}
118+
119+
}

0 commit comments

Comments
 (0)