| 
15 | 15 | import org.apache.lucene.store.IndexInput;  | 
16 | 16 | import org.apache.lucene.store.IndexOutput;  | 
17 | 17 | import org.apache.lucene.store.MMapDirectory;  | 
18 |  | -import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;  | 
 | 18 | +import org.apache.lucene.util.VectorUtil;  | 
 | 19 | +import org.elasticsearch.index.codec.vectors.BQSpaceUtils;  | 
 | 20 | +import org.elasticsearch.index.codec.vectors.BQVectorUtils;  | 
 | 21 | +import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;  | 
 | 22 | +import org.elasticsearch.simdvec.ES91Int4VectorsScorer;  | 
19 | 23 | import org.elasticsearch.simdvec.ES91OSQVectorsScorer;  | 
20 | 24 | 
 
  | 
21 |  | -import static org.hamcrest.Matchers.lessThan;  | 
 | 25 | +import java.io.IOException;  | 
22 | 26 | 
 
  | 
23 | 27 | public class ES91OSQVectorScorerTests extends BaseVectorizationTests {  | 
24 | 28 | 
 
  | 
25 | 29 |     public void testQuantizeScore() throws Exception {  | 
26 | 30 |         final int dimensions = random().nextInt(1, 2000);  | 
27 |  | -        final int length = OptimizedScalarQuantizer.discretize(dimensions, 64) / 8;  | 
 | 31 | +        final int length = BQVectorUtils.discretize(dimensions, 64) / 8;  | 
28 | 32 |         final int numVectors = random().nextInt(1, 100);  | 
29 | 33 |         final byte[] vector = new byte[length];  | 
30 | 34 |         try (Directory dir = new MMapDirectory(createTempDir())) {  | 
@@ -53,102 +57,208 @@ public void testQuantizeScore() throws Exception {  | 
53 | 57 |     }  | 
54 | 58 | 
 
  | 
55 | 59 |     public void testScore() throws Exception {  | 
56 |  | -        final int maxDims = 512;  | 
 | 60 | +        final int maxDims = random().nextInt(1, 1000) * 2;  | 
57 | 61 |         final int dimensions = random().nextInt(1, maxDims);  | 
58 |  | -        final int length = OptimizedScalarQuantizer.discretize(dimensions, 64) / 8;  | 
59 |  | -        final int numVectors = ES91OSQVectorsScorer.BULK_SIZE * random().nextInt(1, 10);  | 
60 |  | -        final byte[] vector = new byte[length];  | 
 | 62 | +        final int length = BQVectorUtils.discretize(dimensions, 64) / 8;  | 
 | 63 | +        final int numVectors = random().nextInt(10, 50);  | 
 | 64 | +        float[][] vectors = new float[numVectors][dimensions];  | 
 | 65 | +        final int[] scratch = new int[dimensions];  | 
 | 66 | +        final byte[] qVector = new byte[length];  | 
 | 67 | +        final float[] centroid = new float[dimensions];  | 
 | 68 | +        VectorSimilarityFunction similarityFunction = randomFrom(VectorSimilarityFunction.values());  | 
 | 69 | +        randomVector(centroid, similarityFunction);  | 
 | 70 | +        OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(similarityFunction);  | 
61 | 71 |         int padding = random().nextInt(100);  | 
62 | 72 |         byte[] paddingBytes = new byte[padding];  | 
63 | 73 |         try (Directory dir = new MMapDirectory(createTempDir())) {  | 
64 | 74 |             try (IndexOutput out = dir.createOutput("testScore.bin", IOContext.DEFAULT)) {  | 
65 | 75 |                 random().nextBytes(paddingBytes);  | 
66 | 76 |                 out.writeBytes(paddingBytes, 0, padding);  | 
 | 77 | +                for (float[] vector : vectors) {  | 
 | 78 | +                    randomVector(vector, similarityFunction);  | 
 | 79 | +                    OptimizedScalarQuantizer.QuantizationResult result = quantizer.scalarQuantize(  | 
 | 80 | +                        vector.clone(),  | 
 | 81 | +                        scratch,  | 
 | 82 | +                        (byte) 1,  | 
 | 83 | +                        centroid  | 
 | 84 | +                    );  | 
 | 85 | +                    BQVectorUtils.packAsBinary(scratch, qVector);  | 
 | 86 | +                    out.writeBytes(qVector, 0, qVector.length);  | 
 | 87 | +                    out.writeInt(Float.floatToIntBits(result.lowerInterval()));  | 
 | 88 | +                    out.writeInt(Float.floatToIntBits(result.upperInterval()));  | 
 | 89 | +                    out.writeInt(Float.floatToIntBits(result.additionalCorrection()));  | 
 | 90 | +                    out.writeShort((short) result.quantizedComponentSum());  | 
 | 91 | +                }  | 
 | 92 | +            }  | 
 | 93 | +            final float[] query = new float[dimensions];  | 
 | 94 | +            randomVector(query, similarityFunction);  | 
 | 95 | +            OptimizedScalarQuantizer.QuantizationResult queryCorrections = quantizer.scalarQuantize(  | 
 | 96 | +                query.clone(),  | 
 | 97 | +                scratch,  | 
 | 98 | +                (byte) 4,  | 
 | 99 | +                centroid  | 
 | 100 | +            );  | 
 | 101 | +            final byte[] quantizeQuery = new byte[4 * length];  | 
 | 102 | +            BQSpaceUtils.transposeHalfByte(scratch, quantizeQuery);  | 
 | 103 | +            final float centroidDp = VectorUtil.dotProduct(centroid, centroid);  | 
 | 104 | +            final float[] floatScratch = new float[3];  | 
 | 105 | +            try (IndexInput in = dir.openInput("testScore.bin", IOContext.DEFAULT)) {  | 
 | 106 | +                in.seek(padding);  | 
 | 107 | +                assertEquals(in.length(), padding + (long) numVectors * (length + 14));  | 
 | 108 | +                final IndexInput slice = in.slice("test", in.getFilePointer(), (long) (length + 14) * numVectors);  | 
 | 109 | +                // Work on a slice that has just the right number of bytes to make the test fail with an  | 
 | 110 | +                // index-out-of-bounds in case the implementation reads more than the allowed number of  | 
 | 111 | +                // padding bytes.  | 
67 | 112 |                 for (int i = 0; i < numVectors; i++) {  | 
68 |  | -                    random().nextBytes(vector);  | 
69 |  | -                    out.writeBytes(vector, 0, length);  | 
70 |  | -                    float lower = random().nextFloat();  | 
71 |  | -                    float upper = random().nextFloat() + lower / 2;  | 
72 |  | -                    float additionalCorrection = random().nextFloat();  | 
73 |  | -                    int targetComponentSum = randomIntBetween(0, dimensions / 2);  | 
74 |  | -                    out.writeInt(Float.floatToIntBits(lower));  | 
75 |  | -                    out.writeInt(Float.floatToIntBits(upper));  | 
76 |  | -                    out.writeShort((short) targetComponentSum);  | 
77 |  | -                    out.writeInt(Float.floatToIntBits(additionalCorrection));  | 
 | 113 | +                    final ES91OSQVectorsScorer defaultScorer = defaultProvider().newES91OSQVectorsScorer(slice, dimensions);  | 
 | 114 | +                    final ES91OSQVectorsScorer panamaScorer = maybePanamaProvider().newES91OSQVectorsScorer(in, dimensions);  | 
 | 115 | +                    long qDist = defaultScorer.quantizeScore(quantizeQuery);  | 
 | 116 | +                    slice.readFloats(floatScratch, 0, 3);  | 
 | 117 | +                    int quantizedComponentSum = slice.readShort();  | 
 | 118 | +                    float defaulScore = defaultScorer.score(  | 
 | 119 | +                        queryCorrections.lowerInterval(),  | 
 | 120 | +                        queryCorrections.upperInterval(),  | 
 | 121 | +                        queryCorrections.quantizedComponentSum(),  | 
 | 122 | +                        queryCorrections.additionalCorrection(),  | 
 | 123 | +                        similarityFunction,  | 
 | 124 | +                        centroidDp,  | 
 | 125 | +                        floatScratch[0],  | 
 | 126 | +                        floatScratch[1],  | 
 | 127 | +                        quantizedComponentSum,  | 
 | 128 | +                        floatScratch[2],  | 
 | 129 | +                        qDist  | 
 | 130 | +                    );  | 
 | 131 | +                    qDist = panamaScorer.quantizeScore(quantizeQuery);  | 
 | 132 | +                    in.readFloats(floatScratch, 0, 3);  | 
 | 133 | +                    quantizedComponentSum = in.readShort();  | 
 | 134 | +                    float panamaScore = panamaScorer.score(  | 
 | 135 | +                        queryCorrections.lowerInterval(),  | 
 | 136 | +                        queryCorrections.upperInterval(),  | 
 | 137 | +                        queryCorrections.quantizedComponentSum(),  | 
 | 138 | +                        queryCorrections.additionalCorrection(),  | 
 | 139 | +                        similarityFunction,  | 
 | 140 | +                        centroidDp,  | 
 | 141 | +                        floatScratch[0],  | 
 | 142 | +                        floatScratch[1],  | 
 | 143 | +                        quantizedComponentSum,  | 
 | 144 | +                        floatScratch[2],  | 
 | 145 | +                        qDist  | 
 | 146 | +                    );  | 
 | 147 | +                    assertEquals(defaulScore, panamaScore, 1e-2f);  | 
 | 148 | +                    assertEquals(((long) (i + 1) * (length + 14)), slice.getFilePointer());  | 
 | 149 | +                    assertEquals(padding + ((long) (i + 1) * (length + 14)), in.getFilePointer());  | 
78 | 150 |                 }  | 
79 | 151 |             }  | 
80 |  | -            final byte[] query = new byte[4 * length];  | 
81 |  | -            random().nextBytes(query);  | 
82 |  | -            float lower = random().nextFloat();  | 
83 |  | -            OptimizedScalarQuantizer.QuantizationResult result = new OptimizedScalarQuantizer.QuantizationResult(  | 
84 |  | -                lower,  | 
85 |  | -                random().nextFloat() + lower / 2,  | 
86 |  | -                random().nextFloat(),  | 
87 |  | -                randomIntBetween(0, dimensions * 2)  | 
 | 152 | +        }  | 
 | 153 | +    }  | 
 | 154 | + | 
 | 155 | +    public void testScoreBulk() throws Exception {  | 
 | 156 | +        final int maxDims = random().nextInt(1, 1000) * 2;  | 
 | 157 | +        final int dimensions = random().nextInt(1, maxDims);  | 
 | 158 | +        final int length = BQVectorUtils.discretize(dimensions, 64) / 8;  | 
 | 159 | +        final int numVectors = ES91OSQVectorsScorer.BULK_SIZE * random().nextInt(1, 10);  | 
 | 160 | +        float[][] vectors = new float[numVectors][dimensions];  | 
 | 161 | +        final int[] scratch = new int[dimensions];  | 
 | 162 | +        final byte[] qVector = new byte[length];  | 
 | 163 | +        final float[] centroid = new float[dimensions];  | 
 | 164 | +        VectorSimilarityFunction similarityFunction = randomFrom(VectorSimilarityFunction.values());  | 
 | 165 | +        randomVector(centroid, similarityFunction);  | 
 | 166 | +        OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(similarityFunction);  | 
 | 167 | +        int padding = random().nextInt(100);  | 
 | 168 | +        byte[] paddingBytes = new byte[padding];  | 
 | 169 | +        try (Directory dir = new MMapDirectory(createTempDir())) {  | 
 | 170 | +            try (IndexOutput out = dir.createOutput("testScore.bin", IOContext.DEFAULT)) {  | 
 | 171 | +                random().nextBytes(paddingBytes);  | 
 | 172 | +                out.writeBytes(paddingBytes, 0, padding);  | 
 | 173 | +                int limit = numVectors - ES91OSQVectorsScorer.BULK_SIZE + 1;  | 
 | 174 | +                OptimizedScalarQuantizer.QuantizationResult[] results =  | 
 | 175 | +                    new OptimizedScalarQuantizer.QuantizationResult[ES91Int4VectorsScorer.BULK_SIZE];  | 
 | 176 | +                for (int i = 0; i < limit; i += ES91OSQVectorsScorer.BULK_SIZE) {  | 
 | 177 | +                    for (int j = 0; j < ES91Int4VectorsScorer.BULK_SIZE; j++) {  | 
 | 178 | +                        randomVector(vectors[i + j], similarityFunction);  | 
 | 179 | +                        results[j] = quantizer.scalarQuantize(vectors[i + j].clone(), scratch, (byte) 1, centroid);  | 
 | 180 | +                        BQVectorUtils.packAsBinary(scratch, qVector);  | 
 | 181 | +                        out.writeBytes(qVector, 0, qVector.length);  | 
 | 182 | +                    }  | 
 | 183 | +                    writeCorrections(results, out);  | 
 | 184 | +                }  | 
 | 185 | +            }  | 
 | 186 | +            final float[] query = new float[dimensions];  | 
 | 187 | +            randomVector(query, similarityFunction);  | 
 | 188 | +            OptimizedScalarQuantizer.QuantizationResult queryCorrections = quantizer.scalarQuantize(  | 
 | 189 | +                query.clone(),  | 
 | 190 | +                scratch,  | 
 | 191 | +                (byte) 4,  | 
 | 192 | +                centroid  | 
88 | 193 |             );  | 
89 |  | -            final float centroidDp = random().nextFloat();  | 
90 |  | -            final float[] scores1 = new float[ES91OSQVectorsScorer.BULK_SIZE];  | 
91 |  | -            final float[] scores2 = new float[ES91OSQVectorsScorer.BULK_SIZE];  | 
92 |  | -            for (VectorSimilarityFunction similarityFunction : VectorSimilarityFunction.values()) {  | 
93 |  | -                try (IndexInput in = dir.openInput("testScore.bin", IOContext.DEFAULT)) {  | 
94 |  | -                    in.seek(padding);  | 
95 |  | -                    assertEquals(in.length(), padding + (long) numVectors * (length + 14));  | 
96 |  | -                    // Work on a slice that has just the right number of bytes to make the test fail with an  | 
97 |  | -                    // index-out-of-bounds in case the implementation reads more than the allowed number of  | 
98 |  | -                    // padding bytes.  | 
99 |  | -                    for (int i = 0; i < numVectors; i += ES91OSQVectorsScorer.BULK_SIZE) {  | 
100 |  | -                        final IndexInput slice = in.slice(  | 
101 |  | -                            "test",  | 
102 |  | -                            in.getFilePointer(),  | 
103 |  | -                            (long) (length + 14) * ES91OSQVectorsScorer.BULK_SIZE  | 
104 |  | -                        );  | 
105 |  | -                        final ES91OSQVectorsScorer defaultScorer = defaultProvider().newES91OSQVectorsScorer(slice, dimensions);  | 
106 |  | -                        final ES91OSQVectorsScorer panamaScorer = maybePanamaProvider().newES91OSQVectorsScorer(in, dimensions);  | 
107 |  | -                        defaultScorer.scoreBulk(  | 
108 |  | -                            query,  | 
109 |  | -                            result.lowerInterval(),  | 
110 |  | -                            result.upperInterval(),  | 
111 |  | -                            result.quantizedComponentSum(),  | 
112 |  | -                            result.additionalCorrection(),  | 
113 |  | -                            similarityFunction,  | 
114 |  | -                            centroidDp,  | 
115 |  | -                            scores1  | 
116 |  | -                        );  | 
117 |  | -                        panamaScorer.scoreBulk(  | 
118 |  | -                            query,  | 
119 |  | -                            result.lowerInterval(),  | 
120 |  | -                            result.upperInterval(),  | 
121 |  | -                            result.quantizedComponentSum(),  | 
122 |  | -                            result.additionalCorrection(),  | 
123 |  | -                            similarityFunction,  | 
124 |  | -                            centroidDp,  | 
125 |  | -                            scores2  | 
126 |  | -                        );  | 
127 |  | -                        for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {  | 
128 |  | -                            if (scores1[j] == scores2[j]) {  | 
129 |  | -                                continue;  | 
130 |  | -                            }  | 
131 |  | -                            if (scores1[j] > (maxDims * Byte.MAX_VALUE)) {  | 
132 |  | -                                float diff = Math.abs(scores1[j] - scores2[j]);  | 
133 |  | -                                assertThat(  | 
134 |  | -                                    "defaultScores: " + scores1[j] + " bulkScores: " + scores2[j],  | 
135 |  | -                                    diff / scores1[j],  | 
136 |  | -                                    lessThan(1e-5f)  | 
137 |  | -                                );  | 
138 |  | -                                assertThat(  | 
139 |  | -                                    "defaultScores: " + scores1[j] + " bulkScores: " + scores2[j],  | 
140 |  | -                                    diff / scores2[j],  | 
141 |  | -                                    lessThan(1e-5f)  | 
142 |  | -                                );  | 
143 |  | -                            } else {  | 
144 |  | -                                assertEquals(scores1[j], scores2[j], 1e-2f);  | 
145 |  | -                            }  | 
146 |  | -                        }  | 
147 |  | -                        assertEquals(((long) (ES91OSQVectorsScorer.BULK_SIZE) * (length + 14)), slice.getFilePointer());  | 
148 |  | -                        assertEquals(padding + ((long) (i + ES91OSQVectorsScorer.BULK_SIZE) * (length + 14)), in.getFilePointer());  | 
 | 194 | +            final byte[] quantizeQuery = new byte[4 * length];  | 
 | 195 | +            BQSpaceUtils.transposeHalfByte(scratch, quantizeQuery);  | 
 | 196 | +            final float centroidDp = VectorUtil.dotProduct(centroid, centroid);  | 
 | 197 | +            final float[] scoresDefault = new float[ES91OSQVectorsScorer.BULK_SIZE];  | 
 | 198 | +            final float[] scoresPanama = new float[ES91OSQVectorsScorer.BULK_SIZE];  | 
 | 199 | +            try (IndexInput in = dir.openInput("testScore.bin", IOContext.DEFAULT)) {  | 
 | 200 | +                in.seek(padding);  | 
 | 201 | +                assertEquals(in.length(), padding + (long) numVectors * (length + 14));  | 
 | 202 | +                // Work on a slice that has just the right number of bytes to make the test fail with an  | 
 | 203 | +                // index-out-of-bounds in case the implementation reads more than the allowed number of  | 
 | 204 | +                // padding bytes.  | 
 | 205 | +                for (int i = 0; i < numVectors; i += ES91OSQVectorsScorer.BULK_SIZE) {  | 
 | 206 | +                    final IndexInput slice = in.slice("test", in.getFilePointer(), (long) (length + 14) * ES91OSQVectorsScorer.BULK_SIZE);  | 
 | 207 | +                    final ES91OSQVectorsScorer defaultScorer = defaultProvider().newES91OSQVectorsScorer(slice, dimensions);  | 
 | 208 | +                    final ES91OSQVectorsScorer panamaScorer = maybePanamaProvider().newES91OSQVectorsScorer(in, dimensions);  | 
 | 209 | +                    float defaultMaxScore = defaultScorer.scoreBulk(  | 
 | 210 | +                        quantizeQuery,  | 
 | 211 | +                        queryCorrections.lowerInterval(),  | 
 | 212 | +                        queryCorrections.upperInterval(),  | 
 | 213 | +                        queryCorrections.quantizedComponentSum(),  | 
 | 214 | +                        queryCorrections.additionalCorrection(),  | 
 | 215 | +                        similarityFunction,  | 
 | 216 | +                        centroidDp,  | 
 | 217 | +                        scoresDefault  | 
 | 218 | +                    );  | 
 | 219 | +                    float panamaMaxScore = panamaScorer.scoreBulk(  | 
 | 220 | +                        quantizeQuery,  | 
 | 221 | +                        queryCorrections.lowerInterval(),  | 
 | 222 | +                        queryCorrections.upperInterval(),  | 
 | 223 | +                        queryCorrections.quantizedComponentSum(),  | 
 | 224 | +                        queryCorrections.additionalCorrection(),  | 
 | 225 | +                        similarityFunction,  | 
 | 226 | +                        centroidDp,  | 
 | 227 | +                        scoresPanama  | 
 | 228 | +                    );  | 
 | 229 | +                    assertEquals(defaultMaxScore, panamaMaxScore, 1e-2f);  | 
 | 230 | +                    for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {  | 
 | 231 | +                        assertEquals(scoresDefault[j], scoresPanama[j], 1e-2f);  | 
149 | 232 |                     }  | 
 | 233 | +                    assertEquals(((long) (ES91OSQVectorsScorer.BULK_SIZE) * (length + 14)), slice.getFilePointer());  | 
 | 234 | +                    assertEquals(padding + ((long) (i + ES91OSQVectorsScorer.BULK_SIZE) * (length + 14)), in.getFilePointer());  | 
150 | 235 |                 }  | 
151 | 236 |             }  | 
152 | 237 |         }  | 
153 | 238 |     }  | 
 | 239 | + | 
 | 240 | +    private static void writeCorrections(OptimizedScalarQuantizer.QuantizationResult[] corrections, IndexOutput out) throws IOException {  | 
 | 241 | +        for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {  | 
 | 242 | +            out.writeInt(Float.floatToIntBits(correction.lowerInterval()));  | 
 | 243 | +        }  | 
 | 244 | +        for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {  | 
 | 245 | +            out.writeInt(Float.floatToIntBits(correction.upperInterval()));  | 
 | 246 | +        }  | 
 | 247 | +        for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {  | 
 | 248 | +            int targetComponentSum = correction.quantizedComponentSum();  | 
 | 249 | +            out.writeShort((short) targetComponentSum);  | 
 | 250 | +        }  | 
 | 251 | +        for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {  | 
 | 252 | +            out.writeInt(Float.floatToIntBits(correction.additionalCorrection()));  | 
 | 253 | +        }  | 
 | 254 | +    }  | 
 | 255 | + | 
 | 256 | +    private void randomVector(float[] vector, VectorSimilarityFunction vectorSimilarityFunction) {  | 
 | 257 | +        for (int i = 0; i < vector.length; i++) {  | 
 | 258 | +            vector[i] = random().nextFloat();  | 
 | 259 | +        }  | 
 | 260 | +        if (vectorSimilarityFunction != VectorSimilarityFunction.EUCLIDEAN) {  | 
 | 261 | +            VectorUtil.l2normalize(vector);  | 
 | 262 | +        }  | 
 | 263 | +    }  | 
154 | 264 | }  | 
0 commit comments