Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,18 @@ public ES93HnswBinaryQuantizedVectorsFormat() {
/**
* Constructs a format without explicit graph construction parameters: only the format name is
* passed to the superclass, and the flat vectors format is created from the given options.
*
* @param useBFloat16 forwarded to {@link ES93BinaryQuantizedVectorsFormat}; presumably selects
*                    bfloat16 for the raw vectors (TODO confirm against that class)
* @param useDirectIO whether to use direct IO when reading raw vectors
*/
public ES93HnswBinaryQuantizedVectorsFormat(boolean useBFloat16, boolean useDirectIO) {
super(NAME);
flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(useBFloat16, useDirectIO);
}

/**
* Constructs a format using the given graph construction parameters.
*
* @param maxConn the maximum number of connections to a node in the HNSW graph
* @param beamWidth the size of the queue maintained during graph construction.
* @param useBFloat16 whether to use bfloat16 (presumably for the raw vectors; confirm against the flat vectors format)
* @param useDirectIO whether to use direct IO when reading raw vectors
*/
public ES93HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, boolean useBFloat16, boolean useDirectIO) {
Expand All @@ -70,8 +80,8 @@ public ES93HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, boolean
public ES93HnswBinaryQuantizedVectorsFormat(
int maxConn,
int beamWidth,
boolean useDirectIO,
boolean useBFloat16,
boolean useDirectIO,
int numMergeWorkers,
ExecutorService mergeExec
) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.index.codec.vectors;

import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.CodecReader;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.AcceptDocs;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.SameThreadExecutorService;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.common.logging.LogConfigurator;
import org.hamcrest.Matcher;

import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ExecutorService;

import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;

/**
 * Shared base class for tests of HNSW-based {@link KnnVectorsFormat} implementations.
 *
 * <p>Subclasses provide the format under test through the {@code createFormat} overloads. The
 * no-argument variant is used to build the codec returned by {@link #getCodec()}, while the
 * parameterized variants are exercised by {@link #testLimits()}.
 */
public abstract class BaseHnswVectorsFormatTestCase extends BaseKnnVectorsFormatTestCase {

    static {
        // native access requires logging to be initialized
        LogConfigurator.loadLog4jPlugins();
        LogConfigurator.configureESLogging();
    }

    /** Format instance for the current test; assigned in {@link #setUp()}. */
    private KnnVectorsFormat format;

    /** Creates the format under test without explicit construction parameters. */
    protected abstract KnnVectorsFormat createFormat();

    /** Creates the format under test with the given {@code maxConn} and {@code beamWidth}. */
    protected abstract KnnVectorsFormat createFormat(int maxConn, int beamWidth);

    /** Creates the format under test with graph parameters plus merge worker count and executor. */
    protected abstract KnnVectorsFormat createFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService service);

    /**
     * Creates the format and then delegates to the parent set-up. The format is assigned first
     * so that it is already available while {@code super.setUp()} executes.
     */
    @Override
    public void setUp() throws Exception {
        this.format = createFormat();
        super.setUp();
    }

    /** Returns a codec that uses the format created in {@link #setUp()} for every KNN vector field. */
    @Override
    protected Codec getCodec() {
        return TestUtil.alwaysKnnVectorsFormat(format);
    }

    /**
     * Checks that a set of invalid construction arguments is rejected with an
     * {@link IllegalArgumentException}: a series of {@code (maxConn, beamWidth)} pairs, followed by
     * a single merge worker combined with an explicit executor.
     */
    public void testLimits() {
        // each row is a (maxConn, beamWidth) pair expected to be refused
        final int[][] rejectedGraphParams = { { -1, 20 }, { 0, 20 }, { 20, 0 }, { 20, -1 }, { 512 + 1, 20 }, { 20, 3201 } };
        for (int[] params : rejectedGraphParams) {
            expectThrows(IllegalArgumentException.class, () -> createFormat(params[0], params[1]));
        }
        expectThrows(IllegalArgumentException.class, () -> createFormat(20, 100, 1, new SameThreadExecutorService()));
    }

    /**
     * Indexes one random vector per similarity function, reads it back, and runs a nearest-vector
     * search with a second random vector, asserting that the reported score is non-negative and
     * within {@code 0.01} of the score computed directly by the similarity function.
     *
     * <p>The indexed vector is created once and reused across all similarity functions; it is
     * l2-normalized in place when the cosine similarity is tested.
     */
    public void testSingleVectorCase() throws Exception {
        final int dims = random().nextInt(12, 500);
        final float[] indexed = randomVector(dims);
        for (VectorSimilarityFunction sim : VectorSimilarityFunction.values()) {
            try (Directory directory = newDirectory(); IndexWriter writer = new IndexWriter(directory, newIndexWriterConfig())) {
                Document document = new Document();
                normalizeIfCosine(indexed, sim);
                document.add(new KnnFloatVectorField("f", indexed, sim));
                writer.addDocument(document);
                writer.commit();
                try (IndexReader reader = DirectoryReader.open(writer)) {
                    LeafReader leaf = getOnlyLeafReader(reader);
                    FloatVectorValues stored = leaf.getFloatVectorValues("f");
                    KnnVectorValues.DocIndexIterator storedDocs = stored.iterator();
                    assertThat(stored.size(), equalTo(1));
                    while (storedDocs.nextDoc() != NO_MORE_DOCS) {
                        assertArrayEquals(indexed, stored.vectorValue(storedDocs.index()), 0.00001f);
                    }

                    float[] query = randomVector(indexed.length);
                    normalizeIfCosine(query, sim);
                    float expectedScore = sim.compare(indexed, query);

                    AcceptDocs acceptDocs = AcceptDocs.fromLiveDocs(leaf.getLiveDocs(), leaf.maxDoc());
                    TopDocs hits = leaf.searchNearestVectors("f", query, 1, acceptDocs, Integer.MAX_VALUE);
                    assertEquals(1, hits.totalHits.value());

                    float reportedScore = hits.scoreDocs[0].score;
                    assertThat(reportedScore, greaterThanOrEqualTo(0f));
                    // With a single vector in the segment the score is expected to be very close to the true score
                    assertEquals(expectedScore, reportedScore, 0.01f);
                }
            }
        }
    }

    /**
     * Writes a single dot-product vector field {@code "f"} using the supplied directory and writer
     * config, then checks the off-heap byte sizes reported by the field's vectors reader against
     * {@code matchesMap}.
     *
     * <p>If the only leaf reader is not a {@link CodecReader}, no assertion is performed.
     *
     * @param dir        directory to write the index into
     * @param config     index writer configuration to use
     * @param vector     the vector to index
     * @param matchesMap matcher applied to the reported off-heap size map
     * @throws IOException if indexing or reading fails
     */
    protected static void testSimpleOffHeapSize(
        Directory dir,
        IndexWriterConfig config,
        float[] vector,
        Matcher<? super Map<String, Long>> matchesMap
    ) throws IOException {
        try (IndexWriter writer = new IndexWriter(dir, config)) {
            Document document = new Document();
            document.add(new KnnFloatVectorField("f", vector, DOT_PRODUCT));
            writer.addDocument(document);
            writer.commit();
            try (IndexReader reader = DirectoryReader.open(writer)) {
                LeafReader leaf = getOnlyLeafReader(reader);
                if (!(leaf instanceof CodecReader codecReader)) {
                    return;
                }
                KnnVectorsReader vectorsReader = codecReader.getVectorReader();
                if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader perFieldReader) {
                    // unwrap the per-field reader to reach the reader of field "f"
                    vectorsReader = perFieldReader.getFieldReader("f");
                }
                var fieldInfo = leaf.getFieldInfos().fieldInfo("f");
                var reportedSizes = vectorsReader.getOffHeapByteSize(fieldInfo);
                assertThat(reportedSizes, matchesMap);
            }
        }
    }

    /** L2-normalizes {@code vector} in place when {@code similarity} is the cosine similarity; otherwise does nothing. */
    private static void normalizeIfCosine(float[] vector, VectorSimilarityFunction similarity) {
        if (similarity == VectorSimilarityFunction.COSINE) {
            VectorUtil.l2normalize(vector);
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,12 @@

import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.hamcrest.Matchers.aMapWithSize;
import static org.hamcrest.Matchers.arrayWithSize;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.hasEntry;

// @com.carrotsearch.randomizedtesting.annotations.Repeat(iterations = 50) // tests.directory sys property?
public class ES814HnswScalarQuantizedVectorsFormatTests extends BaseKnnVectorsFormatTestCase {

static {
Expand Down Expand Up @@ -178,7 +182,7 @@ private void testSingleVectorPerSegment(VectorSimilarityFunction sim) throws Exc
AcceptDocs.fromLiveDocs(leafReader.getLiveDocs(), leafReader.maxDoc()),
100
);
assertEquals(hits.scoreDocs.length, 3);
assertThat(hits.scoreDocs, arrayWithSize(3));
assertEquals("B", storedFields.document(hits.scoreDocs[0].doc).get("id"));
assertEquals("A", storedFields.document(hits.scoreDocs[1].doc).get("id"));
assertEquals("C", storedFields.document(hits.scoreDocs[2].doc).get("id"));
Expand All @@ -202,10 +206,11 @@ public void testSimpleOffHeapSize() throws IOException {
}
var fieldInfo = r.getFieldInfos().fieldInfo("f");
var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo);
assertEquals(3, offHeap.size());
assertEquals(vector.length * Float.BYTES, (long) offHeap.get("vec"));
assertEquals(1L, (long) offHeap.get("vex"));
assertTrue(offHeap.get("veq") > 0L);

assertThat(offHeap, aMapWithSize(3));
assertThat(offHeap, hasEntry("vex", 1L));
assertThat(offHeap, hasEntry(equalTo("veq"), greaterThan(0L)));
assertThat(offHeap, hasEntry("vec", (long) vector.length * Float.BYTES));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@

import java.io.IOException;

import static org.hamcrest.Matchers.aMapWithSize;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.hasEntry;

public class ES815HnswBitVectorsFormatTests extends BaseKnnBitVectorsFormatTestCase {

static final Codec codec = TestUtil.alwaysKnnVectorsFormat(new ES815HnswBitVectorsFormat());
Expand Down Expand Up @@ -56,9 +61,10 @@ public void testSimpleOffHeapSize() throws IOException {
}
var fieldInfo = r.getFieldInfos().fieldInfo("f");
var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo);
assertEquals(2, offHeap.size());
assertTrue(offHeap.get("vec") > 0L);
assertEquals(1L, (long) offHeap.get("vex"));

assertThat(offHeap, aMapWithSize(2));
assertThat(offHeap, hasEntry("vex", 1L));
assertThat(offHeap, hasEntry(equalTo("vec"), greaterThan(0L)));
}
}
}
Expand Down
Loading