3030import org .apache .lucene .index .SegmentWriteState ;
3131import org .apache .lucene .index .VectorEncoding ;
3232import org .apache .lucene .index .VectorSimilarityFunction ;
33+ import org .apache .lucene .queries .function .FunctionScoreQuery ;
3334import org .apache .lucene .search .FieldExistsQuery ;
3435import org .apache .lucene .search .Query ;
3536import org .apache .lucene .search .join .BitSetProducer ;
@@ -121,6 +122,8 @@ public static boolean isNotUnitVector(float magnitude) {
121122 public static short MAX_DIMS_COUNT = 4096 ; // maximum allowed number of dimensions
122123 public static int MAX_DIMS_COUNT_BIT = 4096 * Byte .SIZE ; // maximum allowed number of dimensions counted in bits (4096 * Byte.SIZE); presumably the limit for BIT element type vectors - TODO confirm
123124
125+ public static final int OVERSAMPLE_LIMIT = 10_000 ; // Upper bound on the oversampled k (ceil(k * rescoreOversample)); num_candidates is then raised to at least that adjusted k
126+
124127 public static short MIN_DIMS_FOR_DYNAMIC_FLOAT_MAPPING = 128 ; // minimum number of dims for floats to be dynamically mapped to vector
125128 public static final int MAGNITUDE_BYTES = 4 ;
126129
@@ -2000,6 +2003,7 @@ public Query createKnnQuery(
20002003 VectorData queryVector ,
20012004 Integer k ,
20022005 int numCands ,
2006+ Float rescoreOversample ,
20032007 Query filter ,
20042008 Float similarityThreshold ,
20052009 BitSetProducer parentFilter
@@ -2010,21 +2014,50 @@ public Query createKnnQuery(
20102014 );
20112015 }
20122016 return switch (getElementType ()) {
2013- case BYTE -> createKnnByteQuery (queryVector .asByteVector (), k , numCands , filter , similarityThreshold , parentFilter );
2014- case FLOAT -> createKnnFloatQuery (queryVector .asFloatVector (), k , numCands , filter , similarityThreshold , parentFilter );
2015- case BIT -> createKnnBitQuery (queryVector .asByteVector (), k , numCands , filter , similarityThreshold , parentFilter );
2017+ case BYTE -> createKnnByteQuery (
2018+ queryVector .asByteVector (),
2019+ k ,
2020+ numCands ,
2021+ filter ,
2022+ rescoreOversample ,
2023+ similarityThreshold ,
2024+ parentFilter
2025+ );
2026+ case FLOAT -> createKnnFloatQuery (
2027+ queryVector .asFloatVector (),
2028+ k ,
2029+ numCands ,
2030+ rescoreOversample ,
2031+ filter ,
2032+ similarityThreshold ,
2033+ parentFilter
2034+ );
2035+ case BIT -> createKnnBitQuery (
2036+ queryVector .asByteVector (),
2037+ k ,
2038+ numCands ,
2039+ rescoreOversample ,
2040+ filter ,
2041+ similarityThreshold ,
2042+ parentFilter
2043+ );
20162044 };
20172045 }
20182046
20192047 private Query createKnnBitQuery (
20202048 byte [] queryVector ,
20212049 Integer k ,
20222050 int numCands ,
2051+ Float rescoreOversample ,
20232052 Query filter ,
20242053 Float similarityThreshold ,
20252054 BitSetProducer parentFilter
20262055 ) {
20272056 elementType .checkDimensions (dims , queryVector .length );
2057+ if (similarity == VectorSimilarity .DOT_PRODUCT || similarity == VectorSimilarity .COSINE ) {
2058+ float squaredMagnitude = VectorUtil .dotProduct (queryVector , queryVector );
2059+ elementType .checkVectorMagnitude (similarity , ElementType .errorByteElementsAppender (queryVector ), squaredMagnitude );
2060+ }
20282061 Query knnQuery = parentFilter != null
20292062 ? new ESDiversifyingChildrenByteKnnVectorQuery (name (), queryVector , filter , k , numCands , parentFilter )
20302063 : new ESKnnByteVectorQuery (name (), queryVector , k , numCands , filter );
@@ -2035,6 +2068,17 @@ private Query createKnnBitQuery(
20352068 similarity .score (similarityThreshold , elementType , dims )
20362069 );
20372070 }
2071+ if (rescoreOversample != null ) {
2072+ knnQuery = new FunctionScoreQuery (
2073+ knnQuery ,
2074+ new VectorSimilarityByteValueSource (
2075+ name (),
2076+ queryVector ,
2077+ similarity .vectorSimilarityFunction (indexVersionCreated , ElementType .BYTE )
2078+ )
2079+ );
2080+
2081+ }
20382082 return knnQuery ;
20392083 }
20402084
@@ -2043,6 +2087,7 @@ private Query createKnnByteQuery(
20432087 Integer k ,
20442088 int numCands ,
20452089 Query filter ,
2090+ Float rescoreOversample ,
20462091 Float similarityThreshold ,
20472092 BitSetProducer parentFilter
20482093 ) {
@@ -2052,23 +2097,38 @@ private Query createKnnByteQuery(
20522097 float squaredMagnitude = VectorUtil .dotProduct (queryVector , queryVector );
20532098 elementType .checkVectorMagnitude (similarity , ElementType .errorByteElementsAppender (queryVector ), squaredMagnitude );
20542099 }
2100+ int adjustedK = rescoreOversample == null ? k : Math .min (OVERSAMPLE_LIMIT , (int ) Math .ceil (k * rescoreOversample ));
2101+ int adjustedNumCands = Math .max (adjustedK , numCands );
2102+
20552103 Query knnQuery = parentFilter != null
2056- ? new ESDiversifyingChildrenByteKnnVectorQuery (name (), queryVector , filter , k , numCands , parentFilter )
2057- : new ESKnnByteVectorQuery (name (), queryVector , k , numCands , filter );
2104+ ? new ESDiversifyingChildrenByteKnnVectorQuery (name (), queryVector , filter , adjustedK , adjustedNumCands , parentFilter )
2105+ : new ESKnnByteVectorQuery (name (), queryVector , adjustedK , adjustedNumCands , filter );
20582106 if (similarityThreshold != null ) {
20592107 knnQuery = new VectorSimilarityQuery (
20602108 knnQuery ,
20612109 similarityThreshold ,
20622110 similarity .score (similarityThreshold , elementType , dims )
20632111 );
20642112 }
2113+ if (rescoreOversample != null ) {
2114+ knnQuery = new FunctionScoreQuery (
2115+ knnQuery ,
2116+ new VectorSimilarityByteValueSource (
2117+ name (),
2118+ queryVector ,
2119+ similarity .vectorSimilarityFunction (indexVersionCreated , ElementType .BYTE )
2120+ )
2121+ );
2122+
2123+ }
20652124 return knnQuery ;
20662125 }
20672126
20682127 private Query createKnnFloatQuery (
20692128 float [] queryVector ,
20702129 Integer k ,
20712130 int numCands ,
2131+ Float rescoreOversample ,
20722132 Query filter ,
20732133 Float similarityThreshold ,
20742134 BitSetProducer parentFilter
@@ -2088,16 +2148,30 @@ && isNotUnitVector(squaredMagnitude)) {
20882148 }
20892149 }
20902150 }
2151+
2152+ int adjustedK = rescoreOversample == null ? k : Math .min (OVERSAMPLE_LIMIT , (int ) Math .ceil (k * rescoreOversample ));
2153+ int adjustedNumCands = Math .max (adjustedK , numCands );
20912154 Query knnQuery = parentFilter != null
2092- ? new ESDiversifyingChildrenFloatKnnVectorQuery (name (), queryVector , filter , k , numCands , parentFilter )
2093- : new ESKnnFloatVectorQuery (name (), queryVector , k , numCands , filter );
2155+ ? new ESDiversifyingChildrenFloatKnnVectorQuery (name (), queryVector , filter , adjustedK , adjustedNumCands , parentFilter )
2156+ : new ESKnnFloatVectorQuery (name (), queryVector , adjustedK , adjustedNumCands , filter );
20942157 if (similarityThreshold != null ) {
20952158 knnQuery = new VectorSimilarityQuery (
20962159 knnQuery ,
20972160 similarityThreshold ,
20982161 similarity .score (similarityThreshold , elementType , dims )
20992162 );
21002163 }
2164+ if (rescoreOversample != null ) {
2165+ knnQuery = new FunctionScoreQuery (
2166+ knnQuery ,
2167+ new VectorSimilarityFloatValueSource (
2168+ name (),
2169+ queryVector ,
2170+ similarity .vectorSimilarityFunction (indexVersionCreated , ElementType .FLOAT )
2171+ )
2172+ );
2173+
2174+ }
21012175 return knnQuery ;
21022176 }
21032177
0 commit comments