diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 3e85ef79d2e5f..f22e50e391382 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -37,6 +37,7 @@ import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.VectorUtil; import org.elasticsearch.common.ParsingException; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.IndexVersion; @@ -96,6 +97,7 @@ import static org.elasticsearch.common.Strings.format; import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; import static org.elasticsearch.index.IndexVersions.DEFAULT_DENSE_VECTOR_TO_INT8_HNSW; +import static org.elasticsearch.search.vectors.KnnSearchBuilder.NUM_CANDS_FIELD; /** * A {@link FieldMapper} for indexing a dense vector of floats. @@ -104,6 +106,7 @@ public class DenseVectorFieldMapper extends FieldMapper { public static final String COSINE_MAGNITUDE_FIELD_SUFFIX = "._magnitude"; private static final float EPS = 1e-3f; public static final int BBQ_MIN_DIMS = 64; + public static final int NUM_CANDS_LIMIT = 10_000; public static boolean isNotUnitVector(float magnitude) { return Math.abs(magnitude - 1.0f) > EPS; @@ -125,7 +128,10 @@ public static boolean isNotUnitVector(float magnitude) { public static final short MIN_DIMS_FOR_DYNAMIC_FLOAT_MAPPING = 128; // minimum number of dims for floats to be dynamically mapped to // vector public static final int MAGNITUDE_BYTES = 4; - public static final int OVERSAMPLE_LIMIT = 10_000; // Max oversample allowed + + public static final String M_FIELD = "m"; + public static final String EF_CONSTRUCTION_FIELD = "ef_construction"; + public static final String MAX_SEARCH_EF_FIELD = "max_search_ef"; private static DenseVectorFieldMapper toType(FieldMapper in) { return (DenseVectorFieldMapper) in; @@ -215,6 +221,7 @@ public Builder(String name, IndexVersion indexVersionCreated) { ? new Int8HnswIndexOptions( Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN, Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH, + NUM_CANDS_LIMIT, null, null ) @@ -234,6 +241,9 @@ public Builder(String name, IndexVersion indexVersionCreated) { if (v != null) { v.validateElementType(elementType.getValue()); } + if (v != null) { + v.validateNumCandidates(); + } }) .acceptsNull() .setMergeValidator( @@ -1218,9 +1228,11 @@ public final String toString() { abstract static class IndexOptions implements ToXContent { final VectorIndexType type; + final RescoreVector rescoreVector; - IndexOptions(VectorIndexType type) { + IndexOptions(VectorIndexType type, RescoreVector rescoreVector) { this.type = type; + this.rescoreVector = rescoreVector; } abstract KnnVectorsFormat getVectorsFormat(ElementType elementType); @@ -1242,6 +1254,28 @@ public void validateDimension(int dim) { throw new IllegalArgumentException(type.name + " only supports even dimensions; provided=" + dim); } + public void validateNumCandidates() { + + } + + public int maxSearchEf() { + return Integer.MAX_VALUE; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field("type", type); + innerXContent(builder, params); + if (rescoreVector != null) { + rescoreVector.toXContent(builder, params); + } + builder.endObject(); + return builder; + } + + abstract public XContentBuilder innerXContent(XContentBuilder builder, Params params) throws IOException; + abstract boolean doEquals(IndexOptions other); abstract int doHashCode(); @@ -1255,40 +1289,98 @@ public final boolean equals(Object other) { return false; } IndexOptions otherOptions = (IndexOptions) other; - return Objects.equals(type, otherOptions.type) && doEquals(otherOptions); + return Objects.equals(type, otherOptions.type) + && Objects.equals(rescoreVector, otherOptions.rescoreVector) + && doEquals(otherOptions); } @Override public final int hashCode() { - return Objects.hash(type, doHashCode()); + return Objects.hash(type, rescoreVector, doHashCode()); + } + + // toString + public String toString() { + return Strings.toString(this); } } - abstract static class QuantizedIndexOptions extends IndexOptions { - final RescoreVector rescoreVector; + abstract static class AbstractHnswIndexOptions extends IndexOptions { + protected final int m; + protected final int efConstruction; + protected final int maxSearchEf; - QuantizedIndexOptions(VectorIndexType type, RescoreVector rescoreVector) { - super(type); - this.rescoreVector = rescoreVector; + AbstractHnswIndexOptions(VectorIndexType type, int m, int efConstruction, int maxSearchEf, RescoreVector rescoreVector) { + super(type, rescoreVector); + this.m = m; + this.efConstruction = efConstruction; + this.maxSearchEf = maxSearchEf; + } + + @Override + public void validateNumCandidates() { + if (maxSearchEf <= 0) { + throw new IllegalArgumentException("[" + MAX_SEARCH_EF_FIELD + "] must be greater than 0"); + } + } + + @Override + public int maxSearchEf() { + return maxSearchEf; } + + @Override + public XContentBuilder innerXContent(XContentBuilder builder, Params params) throws IOException { + builder.field("m", m); + builder.field("ef_construction", efConstruction); + builder.field("max_search_ef", maxSearchEf); + innerHnswXContent(builder, params); + return builder; + } + + abstract public XContentBuilder innerHnswXContent(XContentBuilder builder, Params params) throws IOException; + + @Override + public boolean doEquals(IndexOptions o) { + AbstractHnswIndexOptions other = (AbstractHnswIndexOptions) o; + return Objects.equals(m, other.m) + && Objects.equals(efConstruction, other.efConstruction) + && Objects.equals(maxSearchEf, other.maxSearchEf) + && doChildEquals(other); + } + + @Override + public int doHashCode() { + return Objects.hash(type, rescoreVector, doChildHashCode()); + } + + abstract int doChildHashCode(); + + abstract boolean doChildEquals(IndexOptions other); } public enum VectorIndexType { HNSW("hnsw", false) { @Override public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { - Object mNode = indexOptionsMap.remove("m"); - Object efConstructionNode = indexOptionsMap.remove("ef_construction"); + Object mNode = indexOptionsMap.remove(M_FIELD); + Object efConstructionNode = indexOptionsMap.remove(EF_CONSTRUCTION_FIELD); + Object maxSearchEfNode = indexOptionsMap.remove(MAX_SEARCH_EF_FIELD); if (mNode == null) { mNode = Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; } if (efConstructionNode == null) { efConstructionNode = Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; } + if (maxSearchEfNode == null) { + maxSearchEfNode = NUM_CANDS_LIMIT; + } int m = XContentMapValues.nodeIntegerValue(mNode); int efConstruction = XContentMapValues.nodeIntegerValue(efConstructionNode); + int maxSearchEf = XContentMapValues.nodeIntegerValue(maxSearchEfNode); + MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); - return new HnswIndexOptions(m, efConstruction); + return new HnswIndexOptions(m, efConstruction, maxSearchEf); } @Override @@ -1304,8 +1396,9 @@ public boolean supportsDimension(int dims) { INT8_HNSW("int8_hnsw", true) { @Override public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { - Object mNode = indexOptionsMap.remove("m"); - Object efConstructionNode = indexOptionsMap.remove("ef_construction"); + Object mNode = indexOptionsMap.remove(M_FIELD); + Object efConstructionNode = indexOptionsMap.remove(EF_CONSTRUCTION_FIELD); + Object maxSearchEfNode = indexOptionsMap.remove(MAX_SEARCH_EF_FIELD); Object confidenceIntervalNode = indexOptionsMap.remove("confidence_interval"); if (mNode == null) { mNode = Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; @@ -1313,8 +1406,12 @@ public IndexOptions parseIndexOptions(String fieldName, Map indexOpti if (efConstructionNode == null) { efConstructionNode = Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; } + if (maxSearchEfNode == null) { + maxSearchEfNode = NUM_CANDS_LIMIT; + } int m = XContentMapValues.nodeIntegerValue(mNode); int efConstruction = XContentMapValues.nodeIntegerValue(efConstructionNode); + int maxSearchEf = XContentMapValues.nodeIntegerValue(maxSearchEfNode); Float confidenceInterval = null; if (confidenceIntervalNode != null) { confidenceInterval = (float) XContentMapValues.nodeDoubleValue(confidenceIntervalNode); @@ -1324,7 +1421,7 @@ public IndexOptions parseIndexOptions(String fieldName, Map indexOpti rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap); } MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); - return new Int8HnswIndexOptions(m, efConstruction, confidenceInterval, rescoreVector); + return new Int8HnswIndexOptions(m, efConstruction, maxSearchEf, confidenceInterval, rescoreVector); } @Override @@ -1339,8 +1436,9 @@ public boolean supportsDimension(int dims) { }, INT4_HNSW("int4_hnsw", true) { public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { - Object mNode = indexOptionsMap.remove("m"); - Object efConstructionNode = indexOptionsMap.remove("ef_construction"); + Object mNode = indexOptionsMap.remove(M_FIELD); + Object efConstructionNode = indexOptionsMap.remove(EF_CONSTRUCTION_FIELD); + Object maxSearchEfNode = indexOptionsMap.remove(MAX_SEARCH_EF_FIELD); Object confidenceIntervalNode = indexOptionsMap.remove("confidence_interval"); if (mNode == null) { mNode = Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; @@ -1348,8 +1446,12 @@ public IndexOptions parseIndexOptions(String fieldName, Map indexOpti if (efConstructionNode == null) { efConstructionNode = Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; } + if (maxSearchEfNode == null) { + maxSearchEfNode = NUM_CANDS_LIMIT; + } int m = XContentMapValues.nodeIntegerValue(mNode); int efConstruction = XContentMapValues.nodeIntegerValue(efConstructionNode); + int maxSearchEf = XContentMapValues.nodeIntegerValue(maxSearchEfNode); Float confidenceInterval = null; if (confidenceIntervalNode != null) { confidenceInterval = (float) XContentMapValues.nodeDoubleValue(confidenceIntervalNode); @@ -1359,7 +1461,7 @@ public IndexOptions parseIndexOptions(String fieldName, Map indexOpti rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap); } MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); - return new Int4HnswIndexOptions(m, efConstruction, confidenceInterval, rescoreVector); + return new Int4HnswIndexOptions(m, efConstruction, maxSearchEf, confidenceInterval, rescoreVector); } @Override @@ -1444,22 +1546,27 @@ public boolean supportsDimension(int dims) { BBQ_HNSW("bbq_hnsw", true) { @Override public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap, IndexVersion indexVersion) { - Object mNode = indexOptionsMap.remove("m"); - Object efConstructionNode = indexOptionsMap.remove("ef_construction"); + Object mNode = indexOptionsMap.remove(M_FIELD); + Object efConstructionNode = indexOptionsMap.remove(EF_CONSTRUCTION_FIELD); + Object maxSearchEfNode = indexOptionsMap.remove(MAX_SEARCH_EF_FIELD); if (mNode == null) { mNode = Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; } if (efConstructionNode == null) { efConstructionNode = Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; } + if (maxSearchEfNode == null) { + maxSearchEfNode = NUM_CANDS_LIMIT; + } int m = XContentMapValues.nodeIntegerValue(mNode); int efConstruction = XContentMapValues.nodeIntegerValue(efConstructionNode); + int maxSearchEf = XContentMapValues.nodeIntegerValue(maxSearchEfNode); RescoreVector rescoreVector = null; if (indexVersion.onOrAfter(ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS)) { rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap); } MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); - return new BBQHnswIndexOptions(m, efConstruction, rescoreVector); + return new BBQHnswIndexOptions(m, efConstruction, maxSearchEf, rescoreVector); } @Override @@ -1522,7 +1629,7 @@ public String toString() { } } - static class Int8FlatIndexOptions extends QuantizedIndexOptions { + static class Int8FlatIndexOptions extends IndexOptions { private final Float confidenceInterval; Int8FlatIndexOptions(Float confidenceInterval, RescoreVector rescoreVector) { @@ -1531,16 +1638,13 @@ static class Int8FlatIndexOptions extends QuantizedIndexOptions { } @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field("type", type); + public XContentBuilder innerXContent(XContentBuilder builder, Params params) throws IOException { if (confidenceInterval != null) { builder.field("confidence_interval", confidenceInterval); } if (rescoreVector != null) { rescoreVector.toXContent(builder, params); } - builder.endObject(); return builder; } @@ -1553,12 +1657,12 @@ KnnVectorsFormat getVectorsFormat(ElementType elementType) { @Override boolean doEquals(IndexOptions o) { Int8FlatIndexOptions that = (Int8FlatIndexOptions) o; - return Objects.equals(confidenceInterval, that.confidenceInterval) && Objects.equals(rescoreVector, that.rescoreVector); + return Objects.equals(confidenceInterval, that.confidenceInterval); } @Override int doHashCode() { - return Objects.hash(confidenceInterval, rescoreVector); + return Objects.hash(confidenceInterval); } @Override @@ -1574,14 +1678,11 @@ boolean updatableTo(IndexOptions update) { static class FlatIndexOptions extends IndexOptions { FlatIndexOptions() { - super(VectorIndexType.FLAT); + super(VectorIndexType.FLAT, null); } @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field("type", type); - builder.endObject(); + public XContentBuilder innerXContent(XContentBuilder builder, Params params) throws IOException { return builder; } @@ -1609,15 +1710,11 @@ public int doHashCode() { } } - static class Int4HnswIndexOptions extends QuantizedIndexOptions { - private final int m; - private final int efConstruction; + static class Int4HnswIndexOptions extends AbstractHnswIndexOptions { private final float confidenceInterval; - Int4HnswIndexOptions(int m, int efConstruction, Float confidenceInterval, RescoreVector rescoreVector) { - super(VectorIndexType.INT4_HNSW, rescoreVector); - this.m = m; - this.efConstruction = efConstruction; + Int4HnswIndexOptions(int m, int efConstruction, int maxSearchEf, Float confidenceInterval, RescoreVector rescoreVector) { + super(VectorIndexType.INT4_HNSW, m, efConstruction, maxSearchEf, rescoreVector); // The default confidence interval for int4 is dynamic quantiles, this provides the best relevancy and is // effectively required for int4 to behave well across a wide range of data. this.confidenceInterval = confidenceInterval == null ? 0f : confidenceInterval; @@ -1630,46 +1727,20 @@ public KnnVectorsFormat getVectorsFormat(ElementType elementType) { } @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field("type", type); - builder.field("m", m); - builder.field("ef_construction", efConstruction); + public XContentBuilder innerHnswXContent(XContentBuilder builder, Params params) throws IOException { builder.field("confidence_interval", confidenceInterval); - if (rescoreVector != null) { - rescoreVector.toXContent(builder, params); - } - builder.endObject(); return builder; } @Override - public boolean doEquals(IndexOptions o) { + boolean doChildEquals(IndexOptions o) { Int4HnswIndexOptions that = (Int4HnswIndexOptions) o; - return m == that.m - && efConstruction == that.efConstruction - && Objects.equals(confidenceInterval, that.confidenceInterval) - && Objects.equals(rescoreVector, that.rescoreVector); - } - - @Override - public int doHashCode() { - return Objects.hash(m, efConstruction, confidenceInterval, rescoreVector); + return Objects.equals(confidenceInterval, that.confidenceInterval); } @Override - public String toString() { - return "{type=" - + type - + ", m=" - + m - + ", ef_construction=" - + efConstruction - + ", confidence_interval=" - + confidenceInterval - + ", rescore_vector=" - + (rescoreVector == null ? "none" : rescoreVector) - + "}"; + int doChildHashCode() { + return Objects.hash(confidenceInterval); } @Override @@ -1685,7 +1756,7 @@ boolean updatableTo(IndexOptions update) { } } - static class Int4FlatIndexOptions extends QuantizedIndexOptions { + static class Int4FlatIndexOptions extends IndexOptions { private final float confidenceInterval; Int4FlatIndexOptions(Float confidenceInterval, RescoreVector rescoreVector) { @@ -1702,14 +1773,8 @@ public KnnVectorsFormat getVectorsFormat(ElementType elementType) { } @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field("type", type); + public XContentBuilder innerXContent(XContentBuilder builder, Params params) throws IOException { builder.field("confidence_interval", confidenceInterval); - if (rescoreVector != null) { - rescoreVector.toXContent(builder, params); - } - builder.endObject(); return builder; } @@ -1718,17 +1783,12 @@ public boolean doEquals(IndexOptions o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; Int4FlatIndexOptions that = (Int4FlatIndexOptions) o; - return Objects.equals(confidenceInterval, that.confidenceInterval) && Objects.equals(rescoreVector, that.rescoreVector); + return Objects.equals(confidenceInterval, that.confidenceInterval); } @Override public int doHashCode() { - return Objects.hash(confidenceInterval, rescoreVector); - } - - @Override - public String toString() { - return "{type=" + type + ", confidence_interval=" + confidenceInterval + ", rescore_vector=" + rescoreVector + "}"; + return Objects.hash(confidenceInterval); } @Override @@ -1742,15 +1802,11 @@ boolean updatableTo(IndexOptions update) { } - static class Int8HnswIndexOptions extends QuantizedIndexOptions { - private final int m; - private final int efConstruction; + static class Int8HnswIndexOptions extends AbstractHnswIndexOptions { private final Float confidenceInterval; - Int8HnswIndexOptions(int m, int efConstruction, Float confidenceInterval, RescoreVector rescoreVector) { - super(VectorIndexType.INT8_HNSW, rescoreVector); - this.m = m; - this.efConstruction = efConstruction; + Int8HnswIndexOptions(int m, int efConstruction, int maxSearchEf, Float confidenceInterval, RescoreVector rescoreVector) { + super(VectorIndexType.INT8_HNSW, m, efConstruction, maxSearchEf, rescoreVector); this.confidenceInterval = confidenceInterval; } @@ -1761,50 +1817,25 @@ public KnnVectorsFormat getVectorsFormat(ElementType elementType) { } @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field("type", type); - builder.field("m", m); - builder.field("ef_construction", efConstruction); + public XContentBuilder innerHnswXContent(XContentBuilder builder, Params params) throws IOException { if (confidenceInterval != null) { builder.field("confidence_interval", confidenceInterval); } if (rescoreVector != null) { rescoreVector.toXContent(builder, params); } - builder.endObject(); return builder; } @Override - public boolean doEquals(IndexOptions o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + boolean doChildEquals(IndexOptions o) { Int8HnswIndexOptions that = (Int8HnswIndexOptions) o; - return m == that.m - && efConstruction == that.efConstruction - && Objects.equals(confidenceInterval, that.confidenceInterval) - && Objects.equals(rescoreVector, that.rescoreVector); - } - - @Override - public int doHashCode() { - return Objects.hash(m, efConstruction, confidenceInterval, rescoreVector); + return Objects.equals(confidenceInterval, that.confidenceInterval); } @Override - public String toString() { - return "{type=" - + type - + ", m=" - + m - + ", ef_construction=" - + efConstruction - + ", confidence_interval=" - + confidenceInterval - + ", rescore_vector=" - + (rescoreVector == null ? "none" : rescoreVector) - + "}"; + int doChildHashCode() { + return Objects.hash(confidenceInterval); } @Override @@ -1825,14 +1856,10 @@ boolean updatableTo(IndexOptions update) { } } - static class HnswIndexOptions extends IndexOptions { - private final int m; - private final int efConstruction; + static class HnswIndexOptions extends AbstractHnswIndexOptions { - HnswIndexOptions(int m, int efConstruction) { - super(VectorIndexType.HNSW); - this.m = m; - this.efConstruction = efConstruction; + HnswIndexOptions(int m, int efConstruction, int maxSearchEf) { + super(VectorIndexType.HNSW, m, efConstruction, maxSearchEf, null); } @Override @@ -1857,42 +1884,25 @@ boolean updatableTo(IndexOptions update) { } @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field("type", type); - builder.field("m", m); - builder.field("ef_construction", efConstruction); - builder.endObject(); + public XContentBuilder innerHnswXContent(XContentBuilder builder, Params params) throws IOException { return builder; } @Override - public boolean doEquals(IndexOptions o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - HnswIndexOptions that = (HnswIndexOptions) o; - return m == that.m && efConstruction == that.efConstruction; + boolean doChildEquals(IndexOptions o) { + return Objects.equals(type, o.type); } @Override - public int doHashCode() { - return Objects.hash(m, efConstruction); - } - - @Override - public String toString() { - return "{type=" + type + ", m=" + m + ", ef_construction=" + efConstruction + "}"; + int doChildHashCode() { + return Objects.hash(type); } } - static class BBQHnswIndexOptions extends QuantizedIndexOptions { - private final int m; - private final int efConstruction; + static class BBQHnswIndexOptions extends AbstractHnswIndexOptions { - BBQHnswIndexOptions(int m, int efConstruction, RescoreVector rescoreVector) { - super(VectorIndexType.BBQ_HNSW, rescoreVector); - this.m = m; - this.efConstruction = efConstruction; + BBQHnswIndexOptions(int m, int efConstruction, int maxSearchEf, RescoreVector rescoreVector) { + super(VectorIndexType.BBQ_HNSW, m, efConstruction, maxSearchEf, rescoreVector); } @Override @@ -1907,39 +1917,30 @@ boolean updatableTo(IndexOptions update) { } @Override - boolean doEquals(IndexOptions other) { - BBQHnswIndexOptions that = (BBQHnswIndexOptions) other; - return m == that.m && efConstruction == that.efConstruction && Objects.equals(rescoreVector, that.rescoreVector); + public void validateDimension(int dim) { + if (type.supportsDimension(dim)) { + return; + } + throw new IllegalArgumentException(type.name + " does not support dimensions fewer than " + BBQ_MIN_DIMS + "; provided=" + dim); } @Override - int doHashCode() { - return Objects.hash(m, efConstruction, rescoreVector); + public XContentBuilder innerHnswXContent(XContentBuilder builder, Params params) throws IOException { + return builder; } @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field("type", type); - builder.field("m", m); - builder.field("ef_construction", efConstruction); - if (rescoreVector != null) { - rescoreVector.toXContent(builder, params); - } - builder.endObject(); - return builder; + boolean doChildEquals(IndexOptions o) { + return Objects.equals(type, o.type); } @Override - public void validateDimension(int dim) { - if (type.supportsDimension(dim)) { - return; - } - throw new IllegalArgumentException(type.name + " does not support dimensions fewer than " + BBQ_MIN_DIMS + "; provided=" + dim); + int doChildHashCode() { + return Objects.hash(type); } } - static class BBQFlatIndexOptions extends QuantizedIndexOptions { + static class BBQFlatIndexOptions extends IndexOptions { private final int CLASS_NAME_HASH = this.getClass().getName().hashCode(); BBQFlatIndexOptions(RescoreVector rescoreVector) { @@ -1968,13 +1969,10 @@ int doHashCode() { } @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field("type", type); + public XContentBuilder innerXContent(XContentBuilder builder, Params params) throws IOException { if (rescoreVector != null) { rescoreVector.toXContent(builder, params); } - builder.endObject(); return builder; } @@ -2161,6 +2159,13 @@ public Query createKnnQuery( "to perform knn search on field [" + name() + "], its mapping must have [index] set to [true]" ); } + + if (numCands > indexOptions.maxSearchEf()) { + throw new IllegalArgumentException( + "[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + indexOptions.maxSearchEf() + "]" + ); + } + return switch (getElementType()) { case BYTE -> createKnnByteQuery(queryVector.asByteVector(), k, numCands, filter, similarityThreshold, parentFilter); case FLOAT -> createKnnFloatQuery( @@ -2262,15 +2267,13 @@ && isNotUnitVector(squaredMagnitude)) { // By default utilize the quantized oversample is configured // allow the user provided at query time overwrite Float oversample = queryOversample; - if (oversample == null - && indexOptions instanceof QuantizedIndexOptions quantizedIndexOptions - && quantizedIndexOptions.rescoreVector != null) { - oversample = quantizedIndexOptions.rescoreVector.oversample; + if (oversample == null && indexOptions.rescoreVector != null) { + oversample = indexOptions.rescoreVector.oversample; } boolean rescore = needsRescore(oversample); if (rescore) { // Will get k * oversample for rescoring, and get the top k - adjustedK = Math.min((int) Math.ceil(k * oversample), OVERSAMPLE_LIMIT); + adjustedK = Math.min((int) Math.ceil(k * oversample), indexOptions.maxSearchEf()); numCands = Math.max(adjustedK, numCands); } Query knnQuery = parentFilter != null diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java index 9b9718efcf523..c0e34904efb59 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java @@ -264,9 +264,6 @@ private KnnSearchBuilder( "[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot be less than " + "[" + K_FIELD.getPreferredName() + "]" ); } - if (numCandidates > NUM_CANDS_LIMIT) { - throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]"); - } if (queryVector == null && queryVectorBuilder == null) { throw new IllegalArgumentException( format( diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java index 565fd7325a5ac..71d456ad445d4 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java @@ -183,9 +183,6 @@ private KnnVectorQueryBuilder( if (k != null && k < 1) { throw new IllegalArgumentException("[" + K_FIELD.getPreferredName() + "] must be greater than 0"); } - if (numCands != null && numCands > NUM_CANDS_LIMIT) { - throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]"); - } if (k != null && numCands != null && numCands < k) { throw new IllegalArgumentException( "[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot be less than [" + K_FIELD.getPreferredName() + "]" diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java index 036d2e62c8f48..c5e5e3d363967 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java @@ -64,6 +64,7 @@ import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.NUM_CANDS_LIMIT; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; @@ -853,6 +854,23 @@ protected void registerParameters(ParameterChecker checker) throws IOException { .endObject() ) ); + + checker.registerUpdateCheck( + b -> b.field("type", "dense_vector") + .field("dims", dims) + .field("index", true) + .startObject("index_options") + .field("type", "int4_hnsw") + .endObject(), + b -> b.field("type", "dense_vector") + .field("dims", dims) + .field("index", true) + .startObject("index_options") + .field("type", "int4_hnsw") + .field("max_search_ef", 1000) + .endObject(), + m -> assertTrue(m.toString().contains("\"max_search_ef\":1000")) + ); } @Override @@ -1052,6 +1070,7 @@ public void testMergeDims() throws IOException { .field("type", "int8_hnsw") .field("m", 16) .field("ef_construction", 100) + .field("max_search_ef", NUM_CANDS_LIMIT) .endObject(); b.endObject(); }); @@ -2205,6 +2224,22 @@ public void testInvalidVectorDimensions() { } } + public void testMaxSearchEfBounds() { + Exception e = expectThrows(MapperParsingException.class, () -> createDocumentMapper(fieldMapping(b -> { + b.field("type", "dense_vector"); + b.field("dims", dims); + b.field("index", true); + b.field("similarity", "dot_product"); + b.startObject("index_options"); + b.field("type", "hnsw"); + b.field("m", 5); + b.field("ef_construction", 50); + b.field("max_search_ef", 0); // Invalid value + b.endObject(); + }))); + assertThat(e.getMessage(), containsString("Failed to parse mapping: [max_search_ef] must be greater than 0")); + } + @Override protected IngestScriptSupport ingestScriptSupport() { throw new AssumptionViolatedException("not supported"); diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index e98038b7a0759..10533025ef1f5 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -36,7 +36,7 @@ import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.BBQ_MIN_DIMS; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.BYTE; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.FLOAT; -import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.OVERSAMPLE_LIMIT; +import static org.elasticsearch.search.vectors.KnnSearchBuilder.NUM_CANDS_FIELD; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; @@ -55,23 +55,33 @@ private static DenseVectorFieldMapper.RescoreVector randomRescoreVector() { private DenseVectorFieldMapper.IndexOptions randomIndexOptionsNonQuantized() { return randomFrom( - new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000)), + new DenseVectorFieldMapper.HnswIndexOptions( + randomIntBetween(1, 100), + randomIntBetween(1, 10_000), + randomIntBetween(1_000, 10_000) + ), new DenseVectorFieldMapper.FlatIndexOptions() ); } private DenseVectorFieldMapper.IndexOptions randomIndexOptionsAll() { return randomFrom( - new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000)), + new DenseVectorFieldMapper.HnswIndexOptions( + randomIntBetween(1, 100), + randomIntBetween(1, 10_000), + randomIntBetween(1_000, 10_000) + ), new DenseVectorFieldMapper.Int8HnswIndexOptions( randomIntBetween(1, 100), randomIntBetween(1, 10_000), + randomIntBetween(1_000, 10_000), randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) ), new DenseVectorFieldMapper.Int4HnswIndexOptions( randomIntBetween(1, 100), randomIntBetween(1, 10_000), + randomIntBetween(1_000, 10_000), randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) ), @@ -87,6 +97,7 @@ private DenseVectorFieldMapper.IndexOptions randomIndexOptionsAll() { new DenseVectorFieldMapper.BBQHnswIndexOptions( randomIntBetween(1, 100), randomIntBetween(1, 10_000), + randomIntBetween(1_000, 10_000), randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) ), new DenseVectorFieldMapper.BBQFlatIndexOptions(randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector())) @@ -98,18 +109,21 @@ private DenseVectorFieldMapper.IndexOptions randomIndexOptionsHnswQuantized() { new DenseVectorFieldMapper.Int8HnswIndexOptions( randomIntBetween(1, 100), randomIntBetween(1, 10_000), + randomIntBetween(1_000, 10_000), randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) ), new DenseVectorFieldMapper.Int4HnswIndexOptions( randomIntBetween(1, 100), randomIntBetween(1, 10_000), + randomIntBetween(1_000, 10_000), randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)), randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) ), new DenseVectorFieldMapper.BBQHnswIndexOptions( randomIntBetween(1, 100), randomIntBetween(1, 10_000), + randomIntBetween(1_000, 10_000), randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()) ) ); @@ -398,6 +412,35 @@ public void testCreateKnnQueryMaxDims() { } } + public void testCreateKnnQuerValidateNumCandidates() { + int dims = randomIntBetween(BBQ_MIN_DIMS, 2048); + DenseVectorFieldType field = new DenseVectorFieldType( + "f", + IndexVersion.current(), + FLOAT, + dims, + true, + VectorSimilarity.COSINE, + new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000), 1000), + Collections.emptyMap() + ); + float[] queryVector = new float[dims]; + for (int i = 0; i < dims; i++) { + queryVector[i] = randomFloat(); + } + + int numCands = 500; + Query query = field.createKnnQuery(VectorData.fromFloats(queryVector), 10, numCands, null, null, null, null); + assertThat(query, instanceOf(ESKnnFloatVectorQuery.class)); + + int numCands2 = 1500; + IllegalArgumentException e = expectThrows( + IllegalArgumentException.class, + () -> field.createKnnQuery(VectorData.fromFloats(queryVector), 10, numCands2, null, null, null, null) + ); + assertThat(e.getMessage(), containsString("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [1000]")); + } + public void testByteCreateKnnQuery() { DenseVectorFieldType unindexedField = new DenseVectorFieldType( "f", @@ -473,6 +516,7 @@ public void testRescoreOversampleUsedWithoutQuantization() { } public void testRescoreOversampleModifiesNumCandidates() { + DenseVectorFieldMapper.IndexOptions indexOptions = randomIndexOptionsHnswQuantized(); DenseVectorFieldType fieldType = new DenseVectorFieldType( "f", IndexVersion.current(), @@ -480,7 +524,7 @@ public void testRescoreOversampleModifiesNumCandidates() { 3, true, VectorSimilarity.COSINE, - randomIndexOptionsHnswQuantized(), + indexOptions, Collections.emptyMap() ); @@ -489,7 +533,7 @@ public void testRescoreOversampleModifiesNumCandidates() { // If numCands < k, update numCands to k checkRescoreQueryParameters(fieldType, 10, 20, 2.5F, 25, 25, 10); // Oversampling limits for k - checkRescoreQueryParameters(fieldType, 1000, 1000, 11.0F, OVERSAMPLE_LIMIT, OVERSAMPLE_LIMIT, 1000); + checkRescoreQueryParameters(fieldType, 1000, 1000, 11.0F, indexOptions.maxSearchEf(), indexOptions.maxSearchEf(), 1000); } private static void checkRescoreQueryParameters( diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java index b3764d528ff0f..1586835016e5f 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java @@ -46,7 +46,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.OVERSAMPLE_LIMIT; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.NUM_CANDS_LIMIT; import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -199,7 +199,7 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que Integer numCands = queryBuilder.numCands(); if (queryBuilder.rescoreVectorBuilder() != null && isQuantizedElementType()) { Float oversample = queryBuilder.rescoreVectorBuilder().oversample(); - k = Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * oversample)); + k = Math.min(NUM_CANDS_LIMIT, (int) Math.ceil(k * oversample)); numCands = Math.max(numCands, k); } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java index 8cca3f9ed8a21..a85e9a933b288 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java @@ -238,14 +238,6 @@ public void testNumCandsLessThanK() { assertThat(e.getMessage(), containsString("[num_candidates] cannot be less than [k]")); } - public void testNumCandsExceedsLimit() { - IllegalArgumentException e = expectThrows( - IllegalArgumentException.class, - () -> new KnnSearchBuilder("field", randomVector(3), 100, 10002, null, null) - ); - assertThat(e.getMessage(), containsString("[num_candidates] cannot exceed [10000]")); - } - public void testInvalidK() { IllegalArgumentException e = expectThrows( IllegalArgumentException.class, diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchRequestParserTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchRequestParserTests.java index 4e4d2158a9574..bd1d192283f2c 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchRequestParserTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchRequestParserTests.java @@ -179,22 +179,6 @@ public void testNumCandsLessThanK() throws IOException { assertThat(e.getMessage(), containsString("[num_candidates] cannot be less than [k]")); } - public void testNumCandsExceedsLimit() throws IOException { - XContentType xContentType = randomFrom(XContentType.values()); - XContentBuilder builder = XContentBuilder.builder(xContentType.xContent()) - .startObject() - .startObject(KnnSearchRequestParser.KNN_SECTION_FIELD.getPreferredName()) - .field(KnnSearch.FIELD_FIELD.getPreferredName(), "field") - .field(KnnSearch.K_FIELD.getPreferredName(), 100) - .field(KnnSearch.NUM_CANDS_FIELD.getPreferredName(), 10002) - .field(KnnSearch.QUERY_VECTOR_FIELD.getPreferredName(), new float[] { 1.0f, 2.0f, 3.0f }) - .endObject() - .endObject(); - - IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> parseSearchRequest(builder)); - assertThat(e.getMessage(), containsString("[num_candidates] cannot exceed [10000]")); - } - public void testInvalidK() throws IOException { XContentType xContentType = randomFrom(XContentType.values()); XContentBuilder builder = XContentBuilder.builder(xContentType.xContent())