Skip to content

Commit 7a34391

Browse files
carlosdelestelasticsearchmachine
andauthored
ESQL - dense vector support cosine normalization (#132721)
* Take into account normalization for dense vector support * Fix cherry pick * [CI] Auto commit changes from spotless * Remove debugging code * Check that we may not have magnitudes at all, or for normalized vectors * Better parameterized test * Refactor dimension check * Refactor index names * Remove comment * Fix test --------- Co-authored-by: elasticsearchmachine <[email protected]>
1 parent 0165233 commit 7a34391

File tree

2 files changed

+56
-22
lines changed

2 files changed

+56
-22
lines changed

server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838

3939
import java.io.IOException;
4040

41+
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.COSINE_MAGNITUDE_FIELD_SUFFIX;
4142
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.BYTE;
4243

4344
/**
@@ -540,6 +541,11 @@ public AllReader reader(LeafReaderContext context) throws IOException {
540541
case FLOAT -> {
541542
FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(fieldName);
542543
if (floatVectorValues != null) {
544+
if (fieldType.isNormalized()) {
545+
NumericDocValues magnitudeDocValues = context.reader()
546+
.getNumericDocValues(fieldType.name() + COSINE_MAGNITUDE_FIELD_SUFFIX);
547+
return new FloatDenseVectorNormalizedValuesBlockReader(floatVectorValues, dimensions, magnitudeDocValues);
548+
}
543549
return new FloatDenseVectorValuesBlockReader(floatVectorValues, dimensions);
544550
}
545551
}
@@ -584,6 +590,9 @@ public void read(int docId, BlockLoader.StoredFields storedFields, Builder build
584590
}
585591

586592
private void read(int doc, BlockLoader.FloatBuilder builder) throws IOException {
593+
assert vectorValues.dimension() == dimensions
594+
: "unexpected dimensions for vector value; expected " + dimensions + " but got " + vectorValues.dimension();
595+
587596
if (iterator.docID() > doc) {
588597
builder.appendNull();
589598
} else if (iterator.docID() == doc || iterator.advance(doc) == doc) {
@@ -611,8 +620,6 @@ private static class FloatDenseVectorValuesBlockReader extends DenseVectorValues
611620

612621
protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException {
613622
float[] floats = vectorValues.vectorValue(iterator.index());
614-
assert floats.length == dimensions
615-
: "unexpected dimensions for vector value; expected " + dimensions + " but got " + floats.length;
616623
for (float aFloat : floats) {
617624
builder.appendFloat(aFloat);
618625
}
@@ -624,15 +631,45 @@ public String toString() {
624631
}
625632
}
626633

634+
private static class FloatDenseVectorNormalizedValuesBlockReader extends DenseVectorValuesBlockReader<FloatVectorValues> {
635+
private final NumericDocValues magnitudeDocValues;
636+
637+
FloatDenseVectorNormalizedValuesBlockReader(
638+
FloatVectorValues floatVectorValues,
639+
int dimensions,
640+
NumericDocValues magnitudeDocValues
641+
) {
642+
super(floatVectorValues, dimensions);
643+
this.magnitudeDocValues = magnitudeDocValues;
644+
}
645+
646+
@Override
647+
protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException {
648+
float magnitude = 1.0f;
649+
// If all vectors are normalized, no doc values will be present. The vector may be normalized already, so we may not have a
650+
// stored magnitude for all docs
651+
if ((magnitudeDocValues != null) && magnitudeDocValues.advanceExact(iterator.docID())) {
652+
magnitude = Float.intBitsToFloat((int) magnitudeDocValues.longValue());
653+
}
654+
float[] floats = vectorValues.vectorValue(iterator.index());
655+
for (float aFloat : floats) {
656+
builder.appendFloat(aFloat * magnitude);
657+
}
658+
}
659+
660+
@Override
661+
public String toString() {
662+
return "BlockDocValuesReader.FloatDenseVectorNormalizedValuesBlockReader";
663+
}
664+
}
665+
627666
private static class ByteDenseVectorValuesBlockReader extends DenseVectorValuesBlockReader<ByteVectorValues> {
628667
ByteDenseVectorValuesBlockReader(ByteVectorValues floatVectorValues, int dimensions) {
629668
super(floatVectorValues, dimensions);
630669
}
631670

632671
protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException {
633672
byte[] bytes = vectorValues.vectorValue(iterator.index());
634-
assert bytes.length == dimensions
635-
: "unexpected dimensions for vector value; expected " + dimensions + " but got " + bytes.length;
636673
for (byte aFloat : bytes) {
637674
builder.appendFloat(aFloat);
638675
}

x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,29 +24,30 @@
2424

2525
import java.io.IOException;
2626
import java.util.ArrayList;
27+
import java.util.Arrays;
2728
import java.util.HashMap;
2829
import java.util.List;
2930
import java.util.Locale;
3031
import java.util.Map;
3132
import java.util.Set;
33+
import java.util.stream.Collectors;
3234

3335
import static org.elasticsearch.index.IndexSettings.INDEX_MAPPER_SOURCE_MODE_SETTING;
3436
import static org.elasticsearch.index.mapper.SourceFieldMapper.Mode.SYNTHETIC;
3537
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
3638

3739
public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase {
3840

39-
public static final Set<String> ALL_DENSE_VECTOR_INDEX_TYPES = Set.of(
40-
"int8_hnsw",
41-
"hnsw",
42-
"int4_hnsw",
43-
"bbq_hnsw",
44-
"int8_flat",
45-
"int4_flat",
46-
"bbq_flat",
47-
"flat"
48-
);
49-
public static final Set<String> NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES = Set.of("hnsw", "flat");
41+
public static final Set<String> ALL_DENSE_VECTOR_INDEX_TYPES = Arrays.stream(DenseVectorFieldMapper.VectorIndexType.values())
42+
.filter(DenseVectorFieldMapper.VectorIndexType::isEnabled)
43+
.map(v -> v.getName().toLowerCase(Locale.ROOT))
44+
.collect(Collectors.toSet());
45+
46+
public static final Set<String> NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES = Arrays.stream(DenseVectorFieldMapper.VectorIndexType.values())
47+
.filter(t -> t.isEnabled() && t.isQuantized() == false)
48+
.map(v -> v.getName().toLowerCase(Locale.ROOT))
49+
.collect(Collectors.toSet());
50+
5051
public static final float DELTA = 1e-7F;
5152

5253
private final ElementType elementType;
@@ -57,15 +58,10 @@ public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase {
5758
@ParametersFactory
5859
public static Iterable<Object[]> parameters() throws Exception {
5960
List<Object[]> params = new ArrayList<>();
60-
List<DenseVectorFieldMapper.VectorSimilarity> similarities = List.of(
61-
DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT,
62-
DenseVectorFieldMapper.VectorSimilarity.L2_NORM,
63-
DenseVectorFieldMapper.VectorSimilarity.MAX_INNER_PRODUCT
64-
);
6561

6662
for (ElementType elementType : List.of(ElementType.BYTE, ElementType.FLOAT)) {
67-
// Test all similarities for element types
68-
for (DenseVectorFieldMapper.VectorSimilarity similarity : similarities) {
63+
// Test all similarities
64+
for (DenseVectorFieldMapper.VectorSimilarity similarity : DenseVectorFieldMapper.VectorSimilarity.values()) {
6965
params.add(new Object[] { elementType, similarity, true, false });
7066
}
7167

@@ -74,6 +70,7 @@ public static Iterable<Object[]> parameters() throws Exception {
7470
// No indexing, synthetic source
7571
params.add(new Object[] { elementType, null, false, true });
7672
}
73+
7774
return params;
7875
}
7976

0 commit comments

Comments
 (0)