Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions server/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,7 @@
org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat,
org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat,
org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat,
org.elasticsearch.index.codec.vectors.es93.ES93HnswVectorsFormat,
org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat;

provides org.apache.lucene.codecs.Codec
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues;
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
Expand Down Expand Up @@ -250,7 +249,7 @@ public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(FieldInfo fieldI
final IndexInput finalVectorDataInput = vectorDataInput;
final RandomVectorScorerSupplier randomVectorScorerSupplier = vectorsScorer.getRandomVectorScorerSupplier(
fieldInfo.getVectorSimilarityFunction(),
new OffHeapFloatVectorValues.DenseOffHeapVectorValues(
new OffHeapBFloat16VectorValues.DenseOffHeapVectorValues(
fieldInfo.getVectorDimension(),
docsWithField.cardinality(),
finalVectorDataInput,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.index.codec.vectors.es93;

import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.elasticsearch.index.codec.vectors.AbstractHnswVectorsFormat;

import java.io.IOException;
import java.util.concurrent.ExecutorService;

public class ES93HnswVectorsFormat extends AbstractHnswVectorsFormat {

static final String NAME = "ES93HnswVectorsFormat";

private final FlatVectorsFormat flatVectorsFormat;

public ES93HnswVectorsFormat() {
super(NAME);
flatVectorsFormat = new ES93GenericFlatVectorsFormat();
}

public ES93HnswVectorsFormat(int maxConn, int beamWidth, boolean bfloat16, boolean useDirectIO) {
super(NAME, maxConn, beamWidth);
flatVectorsFormat = new ES93GenericFlatVectorsFormat(bfloat16, useDirectIO);
}

public ES93HnswVectorsFormat(
int maxConn,
int beamWidth,
boolean bfloat16,
boolean useDirectIO,
int numMergeWorkers,
ExecutorService mergeExec
) {
super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec);
flatVectorsFormat = new ES93GenericFlatVectorsFormat(bfloat16, useDirectIO);
}

@Override
protected FlatVectorsFormat flatVectorsFormat() {
return flatVectorsFormat;
}

@Override
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), numMergeWorkers, mergeExec);
}

@Override
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsForma
org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat
org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat
org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat
org.elasticsearch.index.codec.vectors.es93.ES93HnswVectorsFormat
org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.index.codec.vectors.es93;

import org.apache.lucene.index.VectorEncoding;

import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static org.hamcrest.Matchers.closeTo;

public class ES93HnswBFloat16VectorsFormatTests extends ES93HnswVectorsFormatTests {

@Override
protected boolean useBFloat16() {
return true;
}

@Override
protected VectorEncoding randomVectorEncoding() {
return VectorEncoding.FLOAT32;
}

@Override
public void testEmptyByteVectorData() throws Exception {
// no bytes
}

@Override
public void testMergingWithDifferentByteKnnFields() throws Exception {
// no bytes
}

@Override
public void testByteVectorScorerIteration() throws Exception {
// no bytes
}

@Override
public void testSortedIndexBytes() throws Exception {
// no bytes
}

@Override
public void testMismatchedFields() throws Exception {
// no bytes
}

@Override
public void testRandomBytes() throws Exception {
// no bytes
}

@Override
public void testRandom() throws Exception {
AssertionError err = expectThrows(AssertionError.class, super::testRandom);
assertFloatsWithinBounds(err);
}

@Override
public void testRandomWithUpdatesAndGraph() throws Exception {
AssertionError err = expectThrows(AssertionError.class, super::testRandomWithUpdatesAndGraph);
assertFloatsWithinBounds(err);
}

@Override
public void testSparseVectors() throws Exception {
AssertionError err = expectThrows(AssertionError.class, super::testSparseVectors);
assertFloatsWithinBounds(err);
}

@Override
public void testVectorValuesReportCorrectDocs() throws Exception {
AssertionError err = expectThrows(AssertionError.class, super::testVectorValuesReportCorrectDocs);
assertFloatsWithinBounds(err);
}

private static final Pattern FLOAT_ASSERTION_FAILURE = Pattern.compile(".*expected:<([0-9.-]+)> but was:<([0-9.-]+)>");

private static void assertFloatsWithinBounds(AssertionError error) {
Matcher m = FLOAT_ASSERTION_FAILURE.matcher(error.getMessage());
if (m.matches() == false) {
throw error; // nothing to do with us, just rethrow
}

// numbers just need to be in the same vicinity
double expected = Double.parseDouble(m.group(1));
double actual = Double.parseDouble(m.group(2));
double allowedError = expected * 0.01; // within 1%
assertThat(error.getMessage(), actual, closeTo(expected, allowedError));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.index.codec.vectors.es93;

import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.CodecReader;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.SameThreadExecutorService;
import org.elasticsearch.common.logging.LogConfigurator;
import org.elasticsearch.index.codec.vectors.BFloat16;

import java.io.IOException;
import java.util.Locale;

import static java.lang.String.format;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.oneOf;

public class ES93HnswVectorsFormatTests extends BaseKnnVectorsFormatTestCase {

static {
LogConfigurator.loadLog4jPlugins();
LogConfigurator.configureESLogging(); // native access requires logging to be initialized
}

private KnnVectorsFormat format;

protected boolean useBFloat16() {
return false;
}

@Override
public void setUp() throws Exception {
format = new ES93HnswVectorsFormat(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, useBFloat16(), random().nextBoolean());
super.setUp();
}

@Override
protected Codec getCodec() {
return TestUtil.alwaysKnnVectorsFormat(format);
}

public void testToString() {
FilterCodec customCodec = new FilterCodec("foo", Codec.getDefault()) {
@Override
public KnnVectorsFormat knnVectorsFormat() {
return new ES93HnswVectorsFormat(10, 20, false, false);
}
};
String expectedPattern = "ES93HnswVectorsFormat(name=ES93HnswVectorsFormat, maxConn=10, beamWidth=20,"
+ " flatVectorFormat=ES93GenericFlatVectorsFormat(name=ES93GenericFlatVectorsFormat,"
+ " format=Lucene99FlatVectorsFormat(name=Lucene99FlatVectorsFormat, flatVectorScorer=%s())))";
var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer");
var memSegScorer = format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer");
assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer)));
}

public void testLimits() {
expectThrows(IllegalArgumentException.class, () -> new ES93HnswVectorsFormat(-1, 20, false, false));
expectThrows(IllegalArgumentException.class, () -> new ES93HnswVectorsFormat(0, 20, false, false));
expectThrows(IllegalArgumentException.class, () -> new ES93HnswVectorsFormat(20, 0, false, false));
expectThrows(IllegalArgumentException.class, () -> new ES93HnswVectorsFormat(20, -1, false, false));
expectThrows(IllegalArgumentException.class, () -> new ES93HnswVectorsFormat(512 + 1, 20, false, false));
expectThrows(IllegalArgumentException.class, () -> new ES93HnswVectorsFormat(20, 3201, false, false));
expectThrows(
IllegalArgumentException.class,
() -> new ES93HnswVectorsFormat(20, 100, false, false, 1, new SameThreadExecutorService())
);
}

public void testSimpleOffHeapSize() throws IOException {
float[] vector = randomVector(random().nextInt(12, 500));
try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
Document doc = new Document();
doc.add(new KnnFloatVectorField("f", vector, DOT_PRODUCT));
w.addDocument(doc);
w.commit();
try (IndexReader reader = DirectoryReader.open(w)) {
LeafReader r = getOnlyLeafReader(reader);
if (r instanceof CodecReader codecReader) {
KnnVectorsReader knnVectorsReader = codecReader.getVectorReader();
if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) {
knnVectorsReader = fieldsReader.getFieldReader("f");
}
var fieldInfo = r.getFieldInfos().fieldInfo("f");
var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo);
int bytes = useBFloat16() ? BFloat16.BYTES : Float.BYTES;
assertEquals(vector.length * bytes, (long) offHeap.get("vec"));
assertEquals(1L, (long) offHeap.get("vex"));
assertEquals(2, offHeap.size());
}
}
}
}
}