|
37 | 37 | import java.util.concurrent.TimeUnit; |
38 | 38 | import java.util.function.Function; |
39 | 39 | import java.util.function.Predicate; |
| 40 | +import java.util.stream.Collectors; |
40 | 41 | import java.util.stream.IntStream; |
41 | 42 |
|
42 | 43 | import static org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer.quantizeQuery; |
@@ -378,6 +379,136 @@ public void testLarge() throws IOException { |
378 | 379 | } |
379 | 380 | } |
380 | 381 |
|
| 382 | + // Test that the scorer works well when the IndexInput is greater than the directory segment chunk size |
| 383 | + public void testDatasetGreaterThanChunkSize() throws IOException { |
| 384 | + assumeTrue(notSupportedMsg(), supported()); |
| 385 | + var factory = AbstractVectorTestCase.factory.get(); |
| 386 | + |
| 387 | + try (Directory dir = new MMapDirectory(createTempDir("testDatasetGreaterThanChunkSize"), 8192)) { |
| 388 | + final int dims = 1024; |
| 389 | + final int size = 128; |
| 390 | + final float correction = randomFloat(); |
| 391 | + |
| 392 | + String fileName = "testDatasetGreaterThanChunkSize-" + dims; |
| 393 | + logger.info("Testing " + fileName); |
| 394 | + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { |
| 395 | + for (int i = 0; i < size; i++) { |
| 396 | + var vec = vector(i, dims); |
| 397 | + var off = (float) i; |
| 398 | + out.writeBytes(vec, 0, vec.length); |
| 399 | + out.writeInt(Float.floatToIntBits(off)); |
| 400 | + } |
| 401 | + } |
| 402 | + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { |
| 403 | + for (int times = 0; times < TIMES; times++) { |
| 404 | + int idx0 = randomIntBetween(0, size - 1); |
| 405 | + int idx1 = size - 1; |
| 406 | + float off0 = (float) idx0; |
| 407 | + float off1 = (float) idx1; |
| 408 | + for (var sim : List.of(COSINE, DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) { |
| 409 | + var values = vectorValues(dims, size, in, VectorSimilarityType.of(sim)); |
| 410 | + float expected = luceneScore(sim, vector(idx0, dims), vector(idx1, dims), correction, off0, off1); |
| 411 | + var supplier = factory.getInt7SQVectorScorerSupplier(sim, in, values, correction).get(); |
| 412 | + var scorer = supplier.scorer(); |
| 413 | + scorer.setScoringOrdinal(idx0); |
| 414 | + assertThat(scorer.score(idx1), equalTo(expected)); |
| 415 | + } |
| 416 | + } |
| 417 | + } |
| 418 | + } |
| 419 | + } |
| 420 | + |
| 421 | + public void testBulk() throws IOException { |
| 422 | + assumeTrue(notSupportedMsg(), supported()); |
| 423 | + var factory = AbstractVectorTestCase.factory.get(); |
| 424 | + |
| 425 | + final int dims = 1024; |
| 426 | + final int size = randomIntBetween(0, 102); |
| 427 | + // Set maxChunkSize to be less than dims * size |
| 428 | + try (Directory dir = new MMapDirectory(createTempDir("testBulk"))) { |
| 429 | + String fileName = "testBulk-" + dims; |
| 430 | + logger.info("Testing " + fileName); |
| 431 | + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { |
| 432 | + for (int i = 0; i < size; i++) { |
| 433 | + var vec = vector(i, dims); |
| 434 | + var off = (float) i; |
| 435 | + out.writeBytes(vec, 0, vec.length); |
| 436 | + out.writeInt(Float.floatToIntBits(off)); |
| 437 | + } |
| 438 | + } |
| 439 | + |
| 440 | + List<Integer> ids = IntStream.range(0, size).boxed().collect(Collectors.toList()); |
| 441 | + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { |
| 442 | + for (int times = 0; times < TIMES; times++) { |
| 443 | + int idx0 = randomIntBetween(0, size - 1); |
| 444 | + int[] nodes = shuffledList(ids).stream().mapToInt(i -> i).toArray(); |
| 445 | + for (var sim : List.of(COSINE, DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) { |
| 446 | + QuantizedByteVectorValues values = vectorValues(dims, size, in, VectorSimilarityType.of(sim)); |
| 447 | + float[] expected = new float[size]; |
| 448 | + float[] scores = new float[size]; |
| 449 | + var referenceScorer = luceneScoreSupplier(values, VectorSimilarityType.of(sim)).scorer(); |
| 450 | + referenceScorer.setScoringOrdinal(idx0); |
| 451 | + referenceScorer.bulkScore(nodes, expected, nodes.length); |
| 452 | + var supplier = factory |
| 453 | + .getInt7SQVectorScorerSupplier(sim, in, values, values.getScalarQuantizer().getConstantMultiplier()) |
| 454 | + .orElseThrow(); |
| 455 | + var testScorer = supplier.scorer(); |
| 456 | + testScorer.setScoringOrdinal(idx0); |
| 457 | + testScorer.bulkScore(nodes, scores, nodes.length); |
| 458 | + assertArrayEquals(expected, scores, 1e-6f); |
| 459 | + } |
| 460 | + } |
| 461 | + } |
| 462 | + } |
| 463 | + } |
| 464 | + |
| 465 | + // Test that the scorer works well when the IndexInput is greater than the directory segment chunk size. |
| 466 | + // For bulk this is especially important, as it tries to get a whole segment from IndexInput to pass it to |
| 467 | + // the native functions. |
| 468 | + public void testBulkWithDatasetGreaterThanChunkSize() throws IOException { |
| 469 | + assumeTrue(notSupportedMsg(), supported()); |
| 470 | + var factory = AbstractVectorTestCase.factory.get(); |
| 471 | + |
| 472 | + final int dims = 1024; |
| 473 | + final int size = 128; |
| 474 | + // Set maxChunkSize to be less than dims * size |
| 475 | + try (Directory dir = new MMapDirectory(createTempDir("testBulkWithDatasetGreaterThanChunkSize"), 8192)) { |
| 476 | + String fileName = "testBulkWithDatasetGreaterThanChunkSize-" + dims; |
| 477 | + logger.info("Testing " + fileName); |
| 478 | + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { |
| 479 | + for (int i = 0; i < size; i++) { |
| 480 | + var vec = vector(i, dims); |
| 481 | + var off = (float) i; |
| 482 | + out.writeBytes(vec, 0, vec.length); |
| 483 | + out.writeInt(Float.floatToIntBits(off)); |
| 484 | + } |
| 485 | + } |
| 486 | + |
| 487 | + List<Integer> ids = IntStream.range(0, size).boxed().collect(Collectors.toList()); |
| 488 | + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { |
| 489 | + for (int times = 0; times < TIMES; times++) { |
| 490 | + int idx0 = randomIntBetween(0, size - 1); |
| 491 | + int[] nodes = shuffledList(ids).stream().mapToInt(i -> i).toArray(); |
| 492 | + for (var sim : List.of(COSINE, DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) { |
| 493 | + QuantizedByteVectorValues values = vectorValues(dims, size, in, VectorSimilarityType.of(sim)); |
| 494 | + float[] expected = new float[size]; |
| 495 | + float[] scores = new float[size]; |
| 496 | + var referenceScorer = luceneScoreSupplier(values, VectorSimilarityType.of(sim)).scorer(); |
| 497 | + referenceScorer.setScoringOrdinal(idx0); |
| 498 | + referenceScorer.bulkScore(nodes, expected, nodes.length); |
| 499 | + var supplier = factory |
| 500 | + .getInt7SQVectorScorerSupplier(sim, in, values, values.getScalarQuantizer().getConstantMultiplier()) |
| 501 | + .orElseThrow(); |
| 502 | + var testScorer = supplier.scorer(); |
| 503 | + testScorer.setScoringOrdinal(idx0); |
| 504 | + testScorer.bulkScore(nodes, scores, nodes.length); |
| 505 | + assertArrayEquals(expected, scores, 1e-6f); |
| 506 | + } |
| 507 | + } |
| 508 | + } |
| 509 | + } |
| 510 | + } |
| 511 | + |
381 | 512 | public void testRace() throws Exception { |
382 | 513 | testRaceImpl(COSINE); |
383 | 514 | testRaceImpl(DOT_PRODUCT); |
@@ -474,7 +605,8 @@ public static float luceneScore( |
474 | 605 | return scorer.score(a, aOffsetValue, b, bOffsetValue); |
475 | 606 | } |
476 | 607 |
|
477 | | - RandomVectorScorerSupplier luceneScoreSupplier(QuantizedByteVectorValues values, VectorSimilarityFunction sim) throws IOException { |
| 608 | + static RandomVectorScorerSupplier luceneScoreSupplier(QuantizedByteVectorValues values, VectorSimilarityFunction sim) |
| 609 | + throws IOException { |
478 | 610 | return new Lucene99ScalarQuantizedVectorScorer(null).getRandomVectorScorerSupplier(sim, values); |
479 | 611 | } |
480 | 612 |
|
|
0 commit comments