-
Notifications
You must be signed in to change notification settings - Fork 25.7k
ESQL - dense vector support cosine normalization #132721
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 10 commits
763fe63
80b48cf
40edca3
8bd7f79
f9447f7
7d2625c
968c73f
84f56ac
4371b69
c19db57
e98bbcb
0cb587a
e16c7cf
c9d7974
276f98c
c9b42dc
8936d33
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,11 +30,14 @@ | |
| import org.elasticsearch.index.mapper.BlockLoader.DoubleBuilder; | ||
| import org.elasticsearch.index.mapper.BlockLoader.IntBuilder; | ||
| import org.elasticsearch.index.mapper.BlockLoader.LongBuilder; | ||
| import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; | ||
| import org.elasticsearch.index.mapper.vectors.VectorEncoderDecoder; | ||
| import org.elasticsearch.search.fetch.StoredFieldsSpec; | ||
|
|
||
| import java.io.IOException; | ||
|
|
||
| import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.COSINE_MAGNITUDE_FIELD_SUFFIX; | ||
|
|
||
| /** | ||
| * A reader that supports reading doc-values from a Lucene segment in Block fashion. | ||
| */ | ||
|
|
@@ -516,10 +519,12 @@ public String toString() { | |
| public static class DenseVectorBlockLoader extends DocValuesBlockLoader { | ||
| private final String fieldName; | ||
| private final int dimensions; | ||
| private final DenseVectorFieldMapper.DenseVectorFieldType fieldType; | ||
|
|
||
| public DenseVectorBlockLoader(String fieldName, int dimensions) { | ||
| public DenseVectorBlockLoader(String fieldName, int dimensions, DenseVectorFieldMapper.DenseVectorFieldType fieldType) { | ||
| this.fieldName = fieldName; | ||
| this.dimensions = dimensions; | ||
| this.fieldType = fieldType; | ||
| } | ||
|
|
||
| @Override | ||
|
|
@@ -531,20 +536,27 @@ public Builder builder(BlockFactory factory, int expectedCount) { | |
| public AllReader reader(LeafReaderContext context) throws IOException { | ||
| FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(fieldName); | ||
| if (floatVectorValues != null) { | ||
| return new DenseVectorValuesBlockReader(floatVectorValues, dimensions); | ||
| if (fieldType.isNormalized()) { | ||
| NumericDocValues magnitudeDocValues = context.reader() | ||
| .getNumericDocValues(fieldType.name() + COSINE_MAGNITUDE_FIELD_SUFFIX); | ||
| return new FloatDenseVectorNormalizedValuesBlockReader(floatVectorValues, dimensions, magnitudeDocValues); | ||
| } | ||
| return new FloatDenseVectorValuesBlockReader(floatVectorValues, dimensions); | ||
| } | ||
|
|
||
| return new ConstantNullsReader(); | ||
| } | ||
| } | ||
|
|
||
| private static class DenseVectorValuesBlockReader extends BlockDocValuesReader { | ||
| private final FloatVectorValues floatVectorValues; | ||
| private final KnnVectorValues.DocIndexIterator iterator; | ||
| private final int dimensions; | ||
| private abstract static class DenseVectorValuesBlockReader<T extends KnnVectorValues> extends BlockDocValuesReader { | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using an abstract class is previous work for supporting byte element type. |
||
|
|
||
| DenseVectorValuesBlockReader(FloatVectorValues floatVectorValues, int dimensions) { | ||
| this.floatVectorValues = floatVectorValues; | ||
| iterator = floatVectorValues.iterator(); | ||
| protected final T vectorValues; | ||
| protected final KnnVectorValues.DocIndexIterator iterator; | ||
| protected final int dimensions; | ||
|
|
||
| DenseVectorValuesBlockReader(T vectorValues, int dimensions) { | ||
| this.vectorValues = vectorValues; | ||
| iterator = vectorValues.iterator(); | ||
| this.dimensions = dimensions; | ||
| } | ||
|
|
||
|
|
@@ -569,26 +581,74 @@ private void read(int doc, BlockLoader.FloatBuilder builder) throws IOException | |
| builder.appendNull(); | ||
| } else if (iterator.docID() == doc || iterator.advance(doc) == doc) { | ||
| builder.beginPositionEntry(); | ||
| float[] floats = floatVectorValues.vectorValue(iterator.index()); | ||
| assert floats.length == dimensions | ||
| : "unexpected dimensions for vector value; expected " + dimensions + " but got " + floats.length; | ||
| for (float aFloat : floats) { | ||
| builder.appendFloat(aFloat); | ||
| } | ||
| appendDoc(builder); | ||
| builder.endPositionEntry(); | ||
| } else { | ||
| builder.appendNull(); | ||
| } | ||
| } | ||
|
|
||
| protected abstract void appendDoc(BlockLoader.FloatBuilder builder) throws IOException; | ||
|
|
||
| @Override | ||
| public int docId() { | ||
| return iterator.docID(); | ||
| } | ||
| } | ||
|
|
||
| private static class FloatDenseVectorValuesBlockReader extends DenseVectorValuesBlockReader<FloatVectorValues> { | ||
|
|
||
| FloatDenseVectorValuesBlockReader(FloatVectorValues floatVectorValues, int dimensions) { | ||
| super(floatVectorValues, dimensions); | ||
| } | ||
|
|
||
| protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException { | ||
| float[] floats = vectorValues.vectorValue(iterator.index()); | ||
| assert floats.length == dimensions | ||
| : "unexpected dimensions for vector value; expected " + dimensions + " but got " + floats.length; | ||
| for (float aFloat : floats) { | ||
| builder.appendFloat(aFloat); | ||
| } | ||
| } | ||
|
|
||
| @Override | ||
| public String toString() { | ||
| return "BlockDocValuesReader.FloatDenseVectorValuesBlockReader"; | ||
| } | ||
| } | ||
|
|
||
| private static class FloatDenseVectorNormalizedValuesBlockReader extends DenseVectorValuesBlockReader<FloatVectorValues> { | ||
| private final NumericDocValues magnitudeDocValues; | ||
|
|
||
| FloatDenseVectorNormalizedValuesBlockReader( | ||
| FloatVectorValues floatVectorValues, | ||
| int dimensions, | ||
| NumericDocValues magnitudeDocValues | ||
| ) { | ||
| super(floatVectorValues, dimensions); | ||
| this.magnitudeDocValues = magnitudeDocValues; | ||
| } | ||
|
|
||
| @Override | ||
| protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException { | ||
| float[] floats = vectorValues.vectorValue(iterator.index()); | ||
|
||
| assert floats.length == dimensions | ||
| : "unexpected dimensions for vector value; expected " + dimensions + " but got " + floats.length; | ||
|
|
||
| float magnitude = 1.0f; | ||
| // If all vectors are normalized, no doc values will be present. The vector may be normalized already, so we may not have a | ||
| // stored magnitude for all docs | ||
| if ((magnitudeDocValues != null) && magnitudeDocValues.advanceExact(iterator.docID())) { | ||
| magnitude = Float.intBitsToFloat((int) magnitudeDocValues.longValue()); | ||
| } | ||
| for (float aFloat : floats) { | ||
| builder.appendFloat(aFloat * magnitude); | ||
| } | ||
| } | ||
|
|
||
| @Override | ||
| public String toString() { | ||
| return "BlockDocValuesReader.FloatVectorValuesBlockReader"; | ||
| return "BlockDocValuesReader.FloatDenseVectorNormalizedValuesBlockReader"; | ||
| } | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -734,31 +734,29 @@ IndexFieldData.Builder fielddataBuilder(DenseVectorFieldType denseVectorFieldTyp | |
| this, | ||
| denseVectorFieldType.dims, | ||
| denseVectorFieldType.indexed, | ||
| denseVectorFieldType.indexVersionCreated.onOrAfter(NORMALIZE_COSINE) | ||
| && denseVectorFieldType.indexed | ||
| && denseVectorFieldType.similarity.equals(VectorSimilarity.COSINE) ? r -> new FilterLeafReader(r) { | ||
| @Override | ||
| public CacheHelper getCoreCacheHelper() { | ||
| return r.getCoreCacheHelper(); | ||
| } | ||
| denseVectorFieldType.isNormalized() && denseVectorFieldType.indexed ? r -> new FilterLeafReader(r) { | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Used the |
||
| @Override | ||
| public CacheHelper getCoreCacheHelper() { | ||
| return r.getCoreCacheHelper(); | ||
| } | ||
|
|
||
| @Override | ||
| public CacheHelper getReaderCacheHelper() { | ||
| return r.getReaderCacheHelper(); | ||
| } | ||
| @Override | ||
| public CacheHelper getReaderCacheHelper() { | ||
| return r.getReaderCacheHelper(); | ||
| } | ||
|
|
||
| @Override | ||
| public FloatVectorValues getFloatVectorValues(String fieldName) throws IOException { | ||
| FloatVectorValues values = in.getFloatVectorValues(fieldName); | ||
| if (values == null) { | ||
| return null; | ||
| } | ||
| return new DenormalizedCosineFloatVectorValues( | ||
| values, | ||
| in.getNumericDocValues(fieldName + COSINE_MAGNITUDE_FIELD_SUFFIX) | ||
| ); | ||
| @Override | ||
| public FloatVectorValues getFloatVectorValues(String fieldName) throws IOException { | ||
| FloatVectorValues values = in.getFloatVectorValues(fieldName); | ||
| if (values == null) { | ||
| return null; | ||
| } | ||
| } : r -> r | ||
| return new DenormalizedCosineFloatVectorValues( | ||
| values, | ||
| in.getNumericDocValues(fieldName + COSINE_MAGNITUDE_FIELD_SUFFIX) | ||
| ); | ||
| } | ||
| } : r -> r | ||
| ); | ||
| } | ||
|
|
||
|
|
@@ -820,9 +818,7 @@ public void parseKnnVectorAndIndex(DocumentParserContext context, DenseVectorFie | |
| fieldMapper.checkDimensionMatches(index, context); | ||
| checkVectorBounds(vector); | ||
| checkVectorMagnitude(fieldMapper.fieldType().similarity, errorFloatElementsAppender(vector), squaredMagnitude); | ||
| if (fieldMapper.indexCreatedVersion.onOrAfter(NORMALIZE_COSINE) | ||
| && fieldMapper.fieldType().similarity.equals(VectorSimilarity.COSINE) | ||
| && isNotUnitVector(squaredMagnitude)) { | ||
| if (fieldMapper.fieldType().isNormalized() && isNotUnitVector(squaredMagnitude)) { | ||
| float length = (float) Math.sqrt(squaredMagnitude); | ||
| for (int i = 0; i < vector.length; i++) { | ||
| vector[i] /= length; | ||
|
|
@@ -2491,6 +2487,10 @@ public Query createExactKnnQuery(VectorData queryVector, Float vectorSimilarity) | |
| return knnQuery; | ||
| } | ||
|
|
||
| public boolean isNormalized() { | ||
| return indexVersionCreated.onOrAfter(NORMALIZE_COSINE) && VectorSimilarity.COSINE.equals(similarity); | ||
| } | ||
|
|
||
| private Query createExactKnnBitQuery(byte[] queryVector) { | ||
| elementType.checkDimensions(dims, queryVector.length); | ||
| return new DenseVectorQuery.Bytes(queryVector, name()); | ||
|
|
@@ -2511,9 +2511,7 @@ private Query createExactKnnFloatQuery(float[] queryVector) { | |
| if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) { | ||
| float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector); | ||
| elementType.checkVectorMagnitude(similarity, ElementType.errorFloatElementsAppender(queryVector), squaredMagnitude); | ||
| if (similarity == VectorSimilarity.COSINE | ||
| && indexVersionCreated.onOrAfter(NORMALIZE_COSINE) | ||
| && isNotUnitVector(squaredMagnitude)) { | ||
| if (isNormalized() && isNotUnitVector(squaredMagnitude)) { | ||
| float length = (float) Math.sqrt(squaredMagnitude); | ||
| queryVector = Arrays.copyOf(queryVector, queryVector.length); | ||
| for (int i = 0; i < queryVector.length; i++) { | ||
|
|
@@ -2703,9 +2701,7 @@ private Query createKnnFloatQuery( | |
| if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) { | ||
| float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector); | ||
| elementType.checkVectorMagnitude(similarity, ElementType.errorFloatElementsAppender(queryVector), squaredMagnitude); | ||
| if (similarity == VectorSimilarity.COSINE | ||
| && indexVersionCreated.onOrAfter(NORMALIZE_COSINE) | ||
| && isNotUnitVector(squaredMagnitude)) { | ||
| if (isNormalized() && isNotUnitVector(squaredMagnitude)) { | ||
| float length = (float) Math.sqrt(squaredMagnitude); | ||
| queryVector = Arrays.copyOf(queryVector, queryVector.length); | ||
| for (int i = 0; i < queryVector.length; i++) { | ||
|
|
@@ -2795,7 +2791,7 @@ int getVectorDimensions() { | |
| return dims; | ||
| } | ||
|
|
||
| ElementType getElementType() { | ||
| public ElementType getElementType() { | ||
| return elementType; | ||
| } | ||
|
|
||
|
|
@@ -2816,7 +2812,7 @@ public BlockLoader blockLoader(MappedFieldType.BlockLoaderContext blContext) { | |
| } | ||
|
|
||
| if (indexed) { | ||
| return new BlockDocValuesReader.DenseVectorBlockLoader(name(), dims); | ||
| return new BlockDocValuesReader.DenseVectorBlockLoader(name(), dims, this); | ||
| } | ||
|
|
||
| if (hasDocValues() && (blContext.fieldExtractPreference() != FieldExtractPreference.STORED || isSyntheticSource)) { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Passing the DenseVectorFieldType looked better than including other specific attributes needed for normalization (fieldType.name(), IndexVersion). I opted for creating a
isNormalized()method in the DenseVectorFieldType instead.