|
24 | 24 | import org.apache.lucene.document.Field; |
25 | 25 | import org.apache.lucene.document.FieldType; |
26 | 26 | import org.apache.lucene.document.IntPoint; |
| 27 | +import org.apache.lucene.document.KnnByteVectorField; |
27 | 28 | import org.apache.lucene.document.KnnFloatVectorField; |
28 | 29 | import org.apache.lucene.document.LatLonShape; |
29 | 30 | import org.apache.lucene.document.LongPoint; |
|
67 | 68 | import org.elasticsearch.common.lucene.Lucene; |
68 | 69 | import org.elasticsearch.core.IOUtils; |
69 | 70 | import org.elasticsearch.index.codec.postings.ES812PostingsFormat; |
| 71 | +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; |
70 | 72 | import org.elasticsearch.index.shard.ShardId; |
71 | 73 | import org.elasticsearch.index.store.LuceneFilesExtensions; |
72 | 74 | import org.elasticsearch.test.ESTestCase; |
@@ -254,15 +256,27 @@ public void testKnnVectors() throws Exception { |
254 | 256 | VectorSimilarityFunction similarity = randomFrom(VectorSimilarityFunction.values()); |
255 | 257 | int numDocs = between(1000, 5000); |
256 | 258 | int dimension = between(10, 200); |
| 259 | + DenseVectorFieldMapper.ElementType elementType = randomFrom(DenseVectorFieldMapper.ElementType.values()); |
257 | 260 |
|
258 | | - indexRandomly(dir, codec, numDocs, doc -> { |
259 | | - float[] vector = randomVector(dimension); |
260 | | - doc.add(new KnnFloatVectorField("vector", vector, similarity)); |
261 | | - }); |
| 261 | + if (elementType == DenseVectorFieldMapper.ElementType.FLOAT) { |
| 262 | + indexRandomly(dir, codec, numDocs, doc -> { |
| 263 | + float[] vector = randomVector(dimension); |
| 264 | + doc.add(new KnnFloatVectorField("vector", vector, similarity)); |
| 265 | + }); |
| 266 | + } else { |
| 267 | + indexRandomly(dir, codec, numDocs, doc -> { |
| 268 | + byte[] vector = new byte[dimension]; |
| 269 | + random().nextBytes(vector); |
| 270 | + doc.add(new KnnByteVectorField("vector", vector, similarity)); |
| 271 | + }); |
| 272 | + } |
262 | 273 | final IndexDiskUsageStats stats = IndexDiskUsageAnalyzer.analyze(testShardId(), lastCommit(dir), () -> {}); |
263 | 274 | logger.info("--> stats {}", stats); |
264 | 275 |
|
265 | | - long dataBytes = (long) numDocs * dimension * Float.BYTES; // size of flat vector data |
| 276 | + // expected size of flat vector data |
| 277 | + long dataBytes = elementType == DenseVectorFieldMapper.ElementType.FLOAT |
| 278 | + ? ((long) numDocs * dimension * Float.BYTES) |
| 279 | + : ((long) numDocs * dimension); |
266 | 280 | long indexBytesEstimate = (long) numDocs * (Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN / 4); // rough size of HNSW graph |
267 | 281 | assertThat("numDocs=" + numDocs + ";dimension=" + dimension, stats.total().getKnnVectorsBytes(), greaterThan(dataBytes)); |
268 | 282 | long connectionOverhead = stats.total().getKnnVectorsBytes() - dataBytes; |
|
0 commit comments