From 1f30eaa83109233bbe806bc1702bcbcae70acd9b Mon Sep 17 00:00:00 2001 From: weizijun Date: Thu, 13 Mar 2025 19:44:00 +0800 Subject: [PATCH 1/4] improve IndexOptions --- .../vectors/DenseVectorFieldMapper.java | 187 ++++++++---------- 1 file changed, 82 insertions(+), 105 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 54cae36dd1647..27a6c3bd28389 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -1236,6 +1236,17 @@ public void validateDimension(int dim) { throw new IllegalArgumentException(type.name + " only supports even dimensions; provided=" + dim); } + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field("type", type); + innerXContent(builder, params); + builder.endObject(); + return builder; + } + + abstract public XContentBuilder innerXContent(XContentBuilder builder, Params params) throws IOException; + abstract boolean doEquals(IndexOptions other); abstract int doHashCode(); @@ -1258,6 +1269,53 @@ public final int hashCode() { } } + abstract static class AbstractHnswIndexOptions extends IndexOptions { + protected final int m; + protected final int efConstruction; + + AbstractHnswIndexOptions(VectorIndexType type, int m, int efConstruction) { + super(type); + this.m = m; + this.efConstruction = efConstruction; + } + + public int getM() { + return m; + } + + public int getEfConstruction() { + return efConstruction; + } + + @Override + public XContentBuilder innerXContent(XContentBuilder builder, Params params) throws IOException { + builder.field("m", m); + builder.field("ef_construction", efConstruction); + innerHnswXContent(builder, params); + return builder; + } + + abstract public XContentBuilder innerHnswXContent(XContentBuilder builder, Params params) throws IOException; + + @Override + public boolean doEquals(IndexOptions o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + HnswIndexOptions that = (HnswIndexOptions) o; + return m == that.m && efConstruction == that.efConstruction; + } + + @Override + public int doHashCode() { + return Objects.hash(m, efConstruction); + } + + @Override + public String toString() { + return "{type=" + type + ", m=" + m + ", ef_construction=" + efConstruction + "}"; + } + } + public enum VectorIndexType { HNSW("hnsw", false) { @Override @@ -1492,13 +1550,10 @@ static class Int8FlatIndexOptions extends IndexOptions { } @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field("type", type); + public XContentBuilder innerXContent(XContentBuilder builder, Params params) throws IOException { if (confidenceInterval != null) { builder.field("confidence_interval", confidenceInterval); } - builder.endObject(); return builder; } @@ -1536,10 +1591,7 @@ static class FlatIndexOptions extends IndexOptions { } @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field("type", type); - builder.endObject(); + public XContentBuilder innerXContent(XContentBuilder builder, Params params) throws IOException { return builder; } @@ -1567,15 +1619,11 @@ public int doHashCode() { } } - static class Int4HnswIndexOptions extends IndexOptions { - private final int m; - private final int efConstruction; + static class Int4HnswIndexOptions extends AbstractHnswIndexOptions { private final float confidenceInterval; Int4HnswIndexOptions(int m, int efConstruction, Float confidenceInterval) { - super(VectorIndexType.INT4_HNSW); - this.m = m; - this.efConstruction = efConstruction; + super(VectorIndexType.INT4_HNSW, m, efConstruction); // The default confidence interval for int4 is dynamic quantiles, this provides the best relevancy and is // effectively required for int4 to behave well across a wide range of data. this.confidenceInterval = confidenceInterval == null ? 0f : confidenceInterval; @@ -1588,25 +1636,19 @@ public KnnVectorsFormat getVectorsFormat(ElementType elementType) { } @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field("type", type); - builder.field("m", m); - builder.field("ef_construction", efConstruction); + public XContentBuilder innerHnswXContent(XContentBuilder builder, Params params) throws IOException { builder.field("confidence_interval", confidenceInterval); - builder.endObject(); return builder; } @Override public boolean doEquals(IndexOptions o) { - Int4HnswIndexOptions that = (Int4HnswIndexOptions) o; - return m == that.m && efConstruction == that.efConstruction && Objects.equals(confidenceInterval, that.confidenceInterval); + return super.doEquals(o) && Objects.equals(confidenceInterval, ((Int4HnswIndexOptions) o).confidenceInterval); } @Override public int doHashCode() { - return Objects.hash(m, efConstruction, confidenceInterval); + return super.doHashCode() + Objects.hash(confidenceInterval); } @Override @@ -1652,11 +1694,8 @@ public KnnVectorsFormat getVectorsFormat(ElementType elementType) { } @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field("type", type); + public XContentBuilder innerXContent(XContentBuilder builder, Params params) throws IOException { builder.field("confidence_interval", confidenceInterval); - builder.endObject(); return builder; } @@ -1689,15 +1728,11 @@ boolean updatableTo(IndexOptions update) { } - static class Int8HnswIndexOptions extends IndexOptions { - private final int m; - private final int efConstruction; + static class Int8HnswIndexOptions extends AbstractHnswIndexOptions { private final Float confidenceInterval; Int8HnswIndexOptions(int m, int efConstruction, Float confidenceInterval) { - super(VectorIndexType.INT8_HNSW); - this.m = m; - this.efConstruction = efConstruction; + super(VectorIndexType.INT8_HNSW, m, efConstruction); this.confidenceInterval = confidenceInterval; } @@ -1708,29 +1743,21 @@ public KnnVectorsFormat getVectorsFormat(ElementType elementType) { } @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field("type", type); - builder.field("m", m); - builder.field("ef_construction", efConstruction); + public XContentBuilder innerHnswXContent(XContentBuilder builder, Params params) throws IOException { if (confidenceInterval != null) { builder.field("confidence_interval", confidenceInterval); } - builder.endObject(); return builder; } @Override public boolean doEquals(IndexOptions o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - Int8HnswIndexOptions that = (Int8HnswIndexOptions) o; - return m == that.m && efConstruction == that.efConstruction && Objects.equals(confidenceInterval, that.confidenceInterval); + return super.doEquals(o) && Objects.equals(confidenceInterval, ((Int8HnswIndexOptions) o).confidenceInterval); } @Override public int doHashCode() { - return Objects.hash(m, efConstruction, confidenceInterval); + return super.doHashCode() + Objects.hash(confidenceInterval); } @Override @@ -1764,14 +1791,10 @@ boolean updatableTo(IndexOptions update) { } } - static class HnswIndexOptions extends IndexOptions { - private final int m; - private final int efConstruction; + static class HnswIndexOptions extends AbstractHnswIndexOptions { HnswIndexOptions(int m, int efConstruction) { - super(VectorIndexType.HNSW); - this.m = m; - this.efConstruction = efConstruction; + super(VectorIndexType.HNSW, m, efConstruction); } @Override @@ -1796,42 +1819,15 @@ boolean updatableTo(IndexOptions update) { } @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field("type", type); - builder.field("m", m); - builder.field("ef_construction", efConstruction); - builder.endObject(); + public XContentBuilder innerHnswXContent(XContentBuilder builder, Params params) throws IOException { return builder; } - - @Override - public boolean doEquals(IndexOptions o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - HnswIndexOptions that = (HnswIndexOptions) o; - return m == that.m && efConstruction == that.efConstruction; - } - - @Override - public int doHashCode() { - return Objects.hash(m, efConstruction); - } - - @Override - public String toString() { - return "{type=" + type + ", m=" + m + ", ef_construction=" + efConstruction + "}"; - } } - static class BBQHnswIndexOptions extends IndexOptions { - private final int m; - private final int efConstruction; + static class BBQHnswIndexOptions extends AbstractHnswIndexOptions { BBQHnswIndexOptions(int m, int efConstruction) { - super(VectorIndexType.BBQ_HNSW); - this.m = m; - this.efConstruction = efConstruction; + super(VectorIndexType.BBQ_HNSW, m, efConstruction); } @Override @@ -1845,27 +1841,6 @@ boolean updatableTo(IndexOptions update) { return update.type.equals(this.type); } - @Override - boolean doEquals(IndexOptions other) { - BBQHnswIndexOptions that = (BBQHnswIndexOptions) other; - return m == that.m && efConstruction == that.efConstruction; - } - - @Override - int doHashCode() { - return Objects.hash(m, efConstruction); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field("type", type); - builder.field("m", m); - builder.field("ef_construction", efConstruction); - builder.endObject(); - return builder; - } - @Override public void validateDimension(int dim) { if (type.supportsDimension(dim)) { @@ -1873,6 +1848,11 @@ public void validateDimension(int dim) { } throw new IllegalArgumentException(type.name + " does not support dimensions fewer than " + BBQ_MIN_DIMS + "; provided=" + dim); } + + @Override + public XContentBuilder innerHnswXContent(XContentBuilder builder, Params params) throws IOException { + return builder; + } } static class BBQFlatIndexOptions extends IndexOptions { @@ -1904,10 +1884,7 @@ int doHashCode() { } @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field("type", type); - builder.endObject(); + public XContentBuilder innerXContent(XContentBuilder builder, Params params) throws IOException { return builder; } From 2639d28a58254d195c54a7ffad5afee14aa52cc1 Mon Sep 17 00:00:00 2001 From: weizijun Date: Mon, 17 Mar 2025 13:44:21 +0800 Subject: [PATCH 2/4] add max_search_ef check --- .../vectors/DenseVectorFieldMapper.java | 87 ++++++++++++++----- .../vectors/DenseVectorFieldMapperTests.java | 1 + .../vectors/DenseVectorFieldTypeTests.java | 83 ++++++++++++++++-- ...AbstractKnnVectorQueryBuilderTestCase.java | 4 +- 4 files changed, 143 insertions(+), 32 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 27a6c3bd28389..6bb54c18db205 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -94,6 +94,7 @@ import static org.elasticsearch.common.Strings.format; import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; import static org.elasticsearch.index.IndexVersions.DEFAULT_DENSE_VECTOR_TO_INT8_HNSW; +import static org.elasticsearch.search.vectors.KnnSearchBuilder.NUM_CANDS_FIELD; /** * A {@link FieldMapper} for indexing a dense vector of floats. @@ -102,6 +103,7 @@ public class DenseVectorFieldMapper extends FieldMapper { public static final String COSINE_MAGNITUDE_FIELD_SUFFIX = "._magnitude"; private static final float EPS = 1e-3f; public static final int BBQ_MIN_DIMS = 64; + public static final int NUM_CANDS_LIMIT = 10_000; public static boolean isNotUnitVector(float magnitude) { return Math.abs(magnitude - 1.0f) > EPS; @@ -120,7 +122,6 @@ public static boolean isNotUnitVector(float magnitude) { public static final short MIN_DIMS_FOR_DYNAMIC_FLOAT_MAPPING = 128; // minimum number of dims for floats to be dynamically mapped to // vector public static final int MAGNITUDE_BYTES = 4; - public static final int OVERSAMPLE_LIMIT = 10_000; // Max oversample allowed private static DenseVectorFieldMapper toType(FieldMapper in) { return (DenseVectorFieldMapper) in; @@ -210,6 +211,7 @@ public Builder(String name, IndexVersion indexVersionCreated) { ? new Int8HnswIndexOptions( Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN, Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH, + NUM_CANDS_LIMIT, null ) : null, @@ -1236,6 +1238,14 @@ public void validateDimension(int dim) { throw new IllegalArgumentException(type.name + " only supports even dimensions; provided=" + dim); } + public void validateNumCandidates(int numCands) { + + } + + public int maxSearchEf() { + return Integer.MAX_VALUE; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); @@ -1272,25 +1282,32 @@ public final int hashCode() { abstract static class AbstractHnswIndexOptions extends IndexOptions { protected final int m; protected final int efConstruction; + protected final int maxSearchEf; - AbstractHnswIndexOptions(VectorIndexType type, int m, int efConstruction) { + AbstractHnswIndexOptions(VectorIndexType type, int m, int efConstruction, int maxSearchEf) { super(type); this.m = m; this.efConstruction = efConstruction; + this.maxSearchEf = maxSearchEf; } - public int getM() { - return m; + @Override + public void validateNumCandidates(int numCands) { + if (numCands > maxSearchEf) { + throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + maxSearchEf + "]"); + } } - public int getEfConstruction() { - return efConstruction; + @Override + public int maxSearchEf() { + return maxSearchEf; } @Override public XContentBuilder innerXContent(XContentBuilder builder, Params params) throws IOException { builder.field("m", m); builder.field("ef_construction", efConstruction); + builder.field("max_search_ef", maxSearchEf); innerHnswXContent(builder, params); return builder; } @@ -1301,18 +1318,18 @@ public XContentBuilder innerXContent(XContentBuilder builder, Params params) thr public boolean doEquals(IndexOptions o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - HnswIndexOptions that = (HnswIndexOptions) o; - return m == that.m && efConstruction == that.efConstruction; + AbstractHnswIndexOptions that = (AbstractHnswIndexOptions) o; + return m == that.m && efConstruction == that.efConstruction && maxSearchEf == that.maxSearchEf; } @Override public int doHashCode() { - return Objects.hash(m, efConstruction); + return Objects.hash(m, efConstruction, maxSearchEf); } @Override public String toString() { - return "{type=" + type + ", m=" + m + ", ef_construction=" + efConstruction + "}"; + return "{type=" + type + ", m=" + m + ", ef_construction=" + efConstruction + ", max_search_ef=" + "}"; } } @@ -1322,16 +1339,22 @@ public enum VectorIndexType { public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { Object mNode = indexOptionsMap.remove("m"); Object efConstructionNode = indexOptionsMap.remove("ef_construction"); + Object maxSearchEfNode = indexOptionsMap.remove("max_search_ef"); if (mNode == null) { mNode = Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; } if (efConstructionNode == null) { efConstructionNode = Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; } + if (maxSearchEfNode == null) { + maxSearchEfNode = NUM_CANDS_LIMIT; + } int m = XContentMapValues.nodeIntegerValue(mNode); int efConstruction = XContentMapValues.nodeIntegerValue(efConstructionNode); + int maxSearchEf = XContentMapValues.nodeIntegerValue(maxSearchEfNode); + MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); - return new HnswIndexOptions(m, efConstruction); + return new HnswIndexOptions(m, efConstruction, maxSearchEf); } @Override @@ -1349,6 +1372,7 @@ public boolean supportsDimension(int dims) { public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { Object mNode = indexOptionsMap.remove("m"); Object efConstructionNode = indexOptionsMap.remove("ef_construction"); + Object maxSearchEfNode = indexOptionsMap.remove("max_search_ef"); Object confidenceIntervalNode = indexOptionsMap.remove("confidence_interval"); if (mNode == null) { mNode = Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; @@ -1356,14 +1380,18 @@ public IndexOptions parseIndexOptions(String fieldName, Map indexOpti if (efConstructionNode == null) { efConstructionNode = Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; } + if (maxSearchEfNode == null) { + maxSearchEfNode = NUM_CANDS_LIMIT; + } int m = XContentMapValues.nodeIntegerValue(mNode); int efConstruction = XContentMapValues.nodeIntegerValue(efConstructionNode); + int maxSearchEf = XContentMapValues.nodeIntegerValue(maxSearchEfNode); Float confidenceInterval = null; if (confidenceIntervalNode != null) { confidenceInterval = (float) XContentMapValues.nodeDoubleValue(confidenceIntervalNode); } MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); - return new Int8HnswIndexOptions(m, efConstruction, confidenceInterval); + return new Int8HnswIndexOptions(m, efConstruction, maxSearchEf, confidenceInterval); } @Override @@ -1380,6 +1408,7 @@ public boolean supportsDimension(int dims) { public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { Object mNode = indexOptionsMap.remove("m"); Object efConstructionNode = indexOptionsMap.remove("ef_construction"); + Object maxSearchEfNode = indexOptionsMap.remove("max_search_ef"); Object confidenceIntervalNode = indexOptionsMap.remove("confidence_interval"); if (mNode == null) { mNode = Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; @@ -1387,14 +1416,18 @@ public IndexOptions parseIndexOptions(String fieldName, Map indexOpti if (efConstructionNode == null) { efConstructionNode = Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; } + if (maxSearchEfNode == null) { + maxSearchEfNode = NUM_CANDS_LIMIT; + } int m = XContentMapValues.nodeIntegerValue(mNode); int efConstruction = XContentMapValues.nodeIntegerValue(efConstructionNode); + int maxSearchEf = XContentMapValues.nodeIntegerValue(maxSearchEfNode); Float confidenceInterval = null; if (confidenceIntervalNode != null) { confidenceInterval = (float) XContentMapValues.nodeDoubleValue(confidenceIntervalNode); } MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); - return new Int4HnswIndexOptions(m, efConstruction, confidenceInterval); + return new Int4HnswIndexOptions(m, efConstruction, maxSearchEf, confidenceInterval); } @Override @@ -1473,16 +1506,21 @@ public boolean supportsDimension(int dims) { public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { Object mNode = indexOptionsMap.remove("m"); Object efConstructionNode = indexOptionsMap.remove("ef_construction"); + Object maxSearchEfNode = indexOptionsMap.remove("max_search_ef"); if (mNode == null) { mNode = Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; } if (efConstructionNode == null) { efConstructionNode = Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; } + if (maxSearchEfNode == null) { + maxSearchEfNode = NUM_CANDS_LIMIT; + } int m = XContentMapValues.nodeIntegerValue(mNode); int efConstruction = XContentMapValues.nodeIntegerValue(efConstructionNode); + int maxSearchEf = XContentMapValues.nodeIntegerValue(maxSearchEfNode); MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); - return new BBQHnswIndexOptions(m, efConstruction); + return new BBQHnswIndexOptions(m, efConstruction, maxSearchEf); } @Override @@ -1622,8 +1660,8 @@ public int doHashCode() { static class Int4HnswIndexOptions extends AbstractHnswIndexOptions { private final float confidenceInterval; - Int4HnswIndexOptions(int m, int efConstruction, Float confidenceInterval) { - super(VectorIndexType.INT4_HNSW, m, efConstruction); + Int4HnswIndexOptions(int m, int efConstruction, int maxSearchEf, Float confidenceInterval) { + super(VectorIndexType.INT4_HNSW, m, efConstruction, maxSearchEf); // The default confidence interval for int4 is dynamic quantiles, this provides the best relevancy and is // effectively required for int4 to behave well across a wide range of data. this.confidenceInterval = confidenceInterval == null ? 0f : confidenceInterval; @@ -1731,8 +1769,8 @@ boolean updatableTo(IndexOptions update) { static class Int8HnswIndexOptions extends AbstractHnswIndexOptions { private final Float confidenceInterval; - Int8HnswIndexOptions(int m, int efConstruction, Float confidenceInterval) { - super(VectorIndexType.INT8_HNSW, m, efConstruction); + Int8HnswIndexOptions(int m, int efConstruction, int maxSearchEf, Float confidenceInterval) { + super(VectorIndexType.INT8_HNSW, m, efConstruction, maxSearchEf); this.confidenceInterval = confidenceInterval; } @@ -1793,8 +1831,8 @@ boolean updatableTo(IndexOptions update) { static class HnswIndexOptions extends AbstractHnswIndexOptions { - HnswIndexOptions(int m, int efConstruction) { - super(VectorIndexType.HNSW, m, efConstruction); + HnswIndexOptions(int m, int efConstruction, int maxSearchEf) { + super(VectorIndexType.HNSW, m, efConstruction, maxSearchEf); } @Override @@ -1826,8 +1864,8 @@ public XContentBuilder innerHnswXContent(XContentBuilder builder, Params params) static class BBQHnswIndexOptions extends AbstractHnswIndexOptions { - BBQHnswIndexOptions(int m, int efConstruction) { - super(VectorIndexType.BBQ_HNSW, m, efConstruction); + BBQHnswIndexOptions(int m, int efConstruction, int maxSearchEf) { + super(VectorIndexType.BBQ_HNSW, m, efConstruction, maxSearchEf); } @Override @@ -2036,6 +2074,9 @@ public Query createKnnQuery( "to perform knn search on field [" + name() + "], its mapping must have [index] set to [true]" ); } + + indexOptions.validateNumCandidates(numCands); + return switch (getElementType()) { case BYTE -> createKnnByteQuery(queryVector.asByteVector(), k, numCands, filter, similarityThreshold, parentFilter); case FLOAT -> createKnnFloatQuery( @@ -2137,7 +2178,7 @@ && isNotUnitVector(squaredMagnitude)) { boolean rescore = needsRescore(oversample); if (rescore) { // Will get k * oversample for rescoring, and get the top k - adjustedK = Math.min((int) Math.ceil(k * oversample), OVERSAMPLE_LIMIT); + adjustedK = Math.min((int) Math.ceil(k * oversample), indexOptions.maxSearchEf()); numCands = Math.max(adjustedK, numCands); } Query knnQuery = parentFilter != null diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java index 3f574a29469c2..bfd91543a9148 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java @@ -937,6 +937,7 @@ public void testMergeDims() throws IOException { .field("type", "int8_hnsw") .field("m", 16) .field("ef_construction", 100) + .field("max_search_ef", DenseVectorFieldMapper.NUM_CANDS_LIMIT) .endObject(); b.endObject(); }); diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index 5c067cb2d0a27..81120c4fec247 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -36,7 +36,7 @@ import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.BBQ_MIN_DIMS; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.BYTE; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.FLOAT; -import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.OVERSAMPLE_LIMIT; +import static org.elasticsearch.search.vectors.KnnSearchBuilder.NUM_CANDS_FIELD; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; @@ -51,28 +51,30 @@ public DenseVectorFieldTypeTests() { private DenseVectorFieldMapper.IndexOptions randomIndexOptionsNonQuantized() { return randomFrom( - new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000)), + new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000), randomIntBetween(1_000, 10_000)), new DenseVectorFieldMapper.FlatIndexOptions() ); } private DenseVectorFieldMapper.IndexOptions randomIndexOptionsAll() { return randomFrom( - new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000)), + new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000), randomIntBetween(1_000, 10_000)), new DenseVectorFieldMapper.Int8HnswIndexOptions( randomIntBetween(1, 100), randomIntBetween(1, 10_000), + randomIntBetween(1_000, 10_000), randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)) ), new DenseVectorFieldMapper.Int4HnswIndexOptions( randomIntBetween(1, 100), randomIntBetween(1, 10_000), + randomIntBetween(1_000, 10_000), randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)) ), new DenseVectorFieldMapper.FlatIndexOptions(), new DenseVectorFieldMapper.Int8FlatIndexOptions(randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true))), new DenseVectorFieldMapper.Int4FlatIndexOptions(randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true))), - new DenseVectorFieldMapper.BBQHnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000)), + new DenseVectorFieldMapper.BBQHnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000), randomIntBetween(1_000, 10_000)), new DenseVectorFieldMapper.BBQFlatIndexOptions() ); } @@ -82,14 +84,16 @@ private DenseVectorFieldMapper.IndexOptions randomIndexOptionsHnswQuantized() { new DenseVectorFieldMapper.Int8HnswIndexOptions( randomIntBetween(1, 100), randomIntBetween(1, 10_000), + randomIntBetween(1_000, 10_000), randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)) ), new DenseVectorFieldMapper.Int4HnswIndexOptions( randomIntBetween(1, 100), randomIntBetween(1, 10_000), + randomIntBetween(1_000, 10_000), randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)) ), - new DenseVectorFieldMapper.BBQHnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000)) + new DenseVectorFieldMapper.BBQHnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000), randomIntBetween(1_000, 10_000)) ); } @@ -172,6 +176,37 @@ public void testFetchSourceValue() throws IOException { assertEquals(vector, fetchSourceValue(bft, vector)); } + public void testValidateNumCandidates() { + // Test case where numCands is less than or equal to maxSearchEf + { + int maxSearchEf = randomIntBetween(1, 1000); + int numCands = randomIntBetween(0, maxSearchEf); + DenseVectorFieldMapper.IndexOptions indexOptions = new DenseVectorFieldMapper.HnswIndexOptions( + randomIntBetween(1, 100), + randomIntBetween(1, 10_000), + maxSearchEf + ); + indexOptions.validateNumCandidates(numCands); + // No exception should be thrown + } + + // Test case where numCands is greater than maxSearchEf + { + int maxSearchEf = randomIntBetween(1, 1000); + int numCands = randomIntBetween(maxSearchEf + 1, maxSearchEf + 1000); + DenseVectorFieldMapper.IndexOptions indexOptions = new DenseVectorFieldMapper.HnswIndexOptions( + randomIntBetween(1, 100), + randomIntBetween(1, 10_000), + maxSearchEf + ); + IllegalArgumentException e = expectThrows( + IllegalArgumentException.class, + () -> indexOptions.validateNumCandidates(numCands) + ); + assertThat(e.getMessage(), containsString("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + maxSearchEf + "]")); + } + } + public void testCreateNestedKnnQuery() { BitSetProducer producer = context -> null; @@ -370,6 +405,39 @@ public void testCreateKnnQueryMaxDims() { } } + public void testCreateKnnQuerValidateNumCandidates() { + int dims = randomIntBetween(BBQ_MIN_DIMS, 2048); + DenseVectorFieldType field = new DenseVectorFieldType( + "f", + IndexVersion.current(), + FLOAT, + dims, + true, + VectorSimilarity.COSINE, + new DenseVectorFieldMapper.HnswIndexOptions( + randomIntBetween(1, 100), + randomIntBetween(1, 10_000), + 1000 + ), + Collections.emptyMap() + ); + float[] queryVector = new float[dims]; + for (int i = 0; i < dims; i++) { + queryVector[i] = randomFloat(); + } + + int numCands = 500; + Query query = field.createKnnQuery(VectorData.fromFloats(queryVector), 10, numCands, null, null, null, null); + assertThat(query, instanceOf(ESKnnFloatVectorQuery.class)); + + int numCands2 = 1500; + IllegalArgumentException e = expectThrows( + IllegalArgumentException.class, + () -> field.createKnnQuery(VectorData.fromFloats(queryVector), 10, numCands2, null, null, null, null) + ); + assertThat(e.getMessage(), containsString("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [1000]")); + } + public void testByteCreateKnnQuery() { DenseVectorFieldType unindexedField = new DenseVectorFieldType( "f", @@ -445,6 +513,7 @@ public void testRescoreOversampleUsedWithoutQuantization() { } public void testRescoreOversampleModifiesNumCandidates() { + DenseVectorFieldMapper.IndexOptions indexOptions = randomIndexOptionsHnswQuantized(); DenseVectorFieldType fieldType = new DenseVectorFieldType( "f", IndexVersion.current(), @@ -452,7 +521,7 @@ public void testRescoreOversampleModifiesNumCandidates() { 3, true, VectorSimilarity.COSINE, - randomIndexOptionsHnswQuantized(), + indexOptions, Collections.emptyMap() ); @@ -461,7 +530,7 @@ public void testRescoreOversampleModifiesNumCandidates() { // If numCands < k, update numCands to k checkRescoreQueryParameters(fieldType, 10, 20, 2.5F, 25, 25, 10); // Oversampling limits for k - checkRescoreQueryParameters(fieldType, 1000, 1000, 11.0F, OVERSAMPLE_LIMIT, OVERSAMPLE_LIMIT, 1000); + checkRescoreQueryParameters(fieldType, 1000, 1000, 11.0F, indexOptions.maxSearchEf(), indexOptions.maxSearchEf(), 1000); } private static void checkRescoreQueryParameters( diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java index b3764d528ff0f..1586835016e5f 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java @@ -46,7 +46,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.OVERSAMPLE_LIMIT; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.NUM_CANDS_LIMIT; import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -199,7 +199,7 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que Integer numCands = queryBuilder.numCands(); if (queryBuilder.rescoreVectorBuilder() != null && isQuantizedElementType()) { Float oversample = queryBuilder.rescoreVectorBuilder().oversample(); - k = Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * oversample)); + k = Math.min(NUM_CANDS_LIMIT, (int) Math.ceil(k * oversample)); numCands = Math.max(numCands, k); } From 59d6c6c61dee9d61957b2e3c33739834dc24903d Mon Sep 17 00:00:00 2001 From: weizijun Date: Mon, 17 Mar 2025 17:09:59 +0800 Subject: [PATCH 3/4] improve --- .../vectors/DenseVectorFieldMapper.java | 45 ++++++++------ .../search/vectors/KnnSearchBuilder.java | 3 - .../search/vectors/KnnVectorQueryBuilder.java | 3 - .../vectors/DenseVectorFieldMapperTests.java | 36 ++++++++++- .../vectors/DenseVectorFieldTypeTests.java | 61 +++++++------------ .../search/vectors/KnnSearchBuilderTests.java | 8 --- .../vectors/KnnSearchRequestParserTests.java | 16 ----- 7 files changed, 84 insertions(+), 88 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 6bb54c18db205..97e2ac5290b33 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -123,6 +123,10 @@ public static boolean isNotUnitVector(float magnitude) { // vector public static final int MAGNITUDE_BYTES = 4; + public static final String M_FIELD = "m"; + public static final String EF_CONSTRUCTION_FIELD = "ef_construction"; + public static final String MAX_SEARCH_EF_FIELD = "max_search_ef"; + private static DenseVectorFieldMapper toType(FieldMapper in) { return (DenseVectorFieldMapper) in; } @@ -230,6 +234,9 @@ public Builder(String name, IndexVersion indexVersionCreated) { if (v != null) { v.validateElementType(elementType.getValue()); } + if (v != null) { + v.validateNumCandidates(); + } }) .acceptsNull() .setMergeValidator( @@ -1238,7 +1245,7 @@ public void validateDimension(int dim) { throw new IllegalArgumentException(type.name + " only supports even dimensions; provided=" + dim); } - public void validateNumCandidates(int numCands) { + public void validateNumCandidates() { } @@ -1292,9 +1299,9 @@ abstract static class AbstractHnswIndexOptions extends IndexOptions { } @Override - public void validateNumCandidates(int numCands) { - if (numCands > maxSearchEf) { - throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + maxSearchEf + "]"); + public void validateNumCandidates() { + if (maxSearchEf <= 0) { + throw new IllegalArgumentException("[" + MAX_SEARCH_EF_FIELD + "] must be greater than 0"); } } @@ -1337,9 +1344,9 @@ public enum VectorIndexType { HNSW("hnsw", false) { @Override public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { - Object mNode = indexOptionsMap.remove("m"); - Object efConstructionNode = indexOptionsMap.remove("ef_construction"); - Object maxSearchEfNode = indexOptionsMap.remove("max_search_ef"); + Object mNode = indexOptionsMap.remove(M_FIELD); + Object efConstructionNode = indexOptionsMap.remove(EF_CONSTRUCTION_FIELD); + Object maxSearchEfNode = indexOptionsMap.remove(MAX_SEARCH_EF_FIELD); if (mNode == null) { mNode = Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; } @@ -1370,9 +1377,9 @@ public boolean supportsDimension(int dims) { INT8_HNSW("int8_hnsw", true) { @Override public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { - Object mNode = indexOptionsMap.remove("m"); - Object efConstructionNode = indexOptionsMap.remove("ef_construction"); - Object maxSearchEfNode = indexOptionsMap.remove("max_search_ef"); + Object mNode = indexOptionsMap.remove(M_FIELD); + Object efConstructionNode = indexOptionsMap.remove(EF_CONSTRUCTION_FIELD); + Object maxSearchEfNode = indexOptionsMap.remove(MAX_SEARCH_EF_FIELD); Object confidenceIntervalNode = indexOptionsMap.remove("confidence_interval"); if (mNode == null) { mNode = Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; @@ -1406,9 +1413,9 @@ public boolean supportsDimension(int dims) { }, INT4_HNSW("int4_hnsw", true) { public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { - Object mNode = indexOptionsMap.remove("m"); - Object efConstructionNode = indexOptionsMap.remove("ef_construction"); - Object maxSearchEfNode = indexOptionsMap.remove("max_search_ef"); + Object mNode = indexOptionsMap.remove(M_FIELD); + Object efConstructionNode = indexOptionsMap.remove(EF_CONSTRUCTION_FIELD); + Object maxSearchEfNode = indexOptionsMap.remove(MAX_SEARCH_EF_FIELD); Object confidenceIntervalNode = indexOptionsMap.remove("confidence_interval"); if (mNode == null) { mNode = Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; @@ -1504,9 +1511,9 @@ public boolean supportsDimension(int dims) { BBQ_HNSW("bbq_hnsw", true) { @Override public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { - Object mNode = indexOptionsMap.remove("m"); - Object efConstructionNode = indexOptionsMap.remove("ef_construction"); - Object maxSearchEfNode = indexOptionsMap.remove("max_search_ef"); + Object mNode = indexOptionsMap.remove(M_FIELD); + Object efConstructionNode = indexOptionsMap.remove(EF_CONSTRUCTION_FIELD); + Object maxSearchEfNode = indexOptionsMap.remove(MAX_SEARCH_EF_FIELD); if (mNode == null) { mNode = Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; } @@ -2075,7 +2082,11 @@ public Query createKnnQuery( ); } - indexOptions.validateNumCandidates(numCands); + if (numCands > indexOptions.maxSearchEf()) { + throw new IllegalArgumentException( + "[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + indexOptions.maxSearchEf() + "]" + ); + } return switch (getElementType()) { case BYTE -> createKnnByteQuery(queryVector.asByteVector(), k, numCands, filter, similarityThreshold, parentFilter); diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java index 9b9718efcf523..c0e34904efb59 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java @@ -264,9 +264,6 @@ private KnnSearchBuilder( "[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot be less than " + "[" + K_FIELD.getPreferredName() + "]" ); } - if (numCandidates > NUM_CANDS_LIMIT) { - throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]"); - } if (queryVector == null && queryVectorBuilder == null) { throw new IllegalArgumentException( format( diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java index 565fd7325a5ac..71d456ad445d4 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java @@ -183,9 +183,6 @@ private KnnVectorQueryBuilder( if (k != null && k < 1) { throw new IllegalArgumentException("[" + K_FIELD.getPreferredName() + "] must be greater than 0"); } - if (numCands != null && numCands > NUM_CANDS_LIMIT) { - throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]"); - } if (k != null && numCands != null && numCands < k) { throw new IllegalArgumentException( "[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot be less than [" + K_FIELD.getPreferredName() + "]" diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java index bfd91543a9148..c1261101ae19e 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java @@ -63,6 +63,7 @@ import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.NUM_CANDS_LIMIT; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; @@ -852,6 +853,23 @@ protected void registerParameters(ParameterChecker checker) throws IOException { .endObject() ) ); + + checker.registerUpdateCheck( + b -> b.field("type", "dense_vector") + .field("dims", dims) + .field("index", true) + .startObject("index_options") + .field("type", "int4_hnsw") + .endObject(), + b -> b.field("type", "dense_vector") + .field("dims", dims) + .field("index", true) + .startObject("index_options") + .field("type", "int4_hnsw") + .field("max_search_ef", 1000) + .endObject(), + m -> assertTrue(m.toString().contains("\"max_search_ef\":1000")) + ); } @Override @@ -937,7 +955,7 @@ public void testMergeDims() throws IOException { .field("type", "int8_hnsw") .field("m", 16) .field("ef_construction", 100) - .field("max_search_ef", DenseVectorFieldMapper.NUM_CANDS_LIMIT) + .field("max_search_ef", NUM_CANDS_LIMIT) .endObject(); b.endObject(); }); @@ -2091,6 +2109,22 @@ public void testInvalidVectorDimensions() { } } + public void testMaxSearchEfBounds() { + Exception e = expectThrows(MapperParsingException.class, () -> createDocumentMapper(fieldMapping(b -> { + b.field("type", "dense_vector"); + b.field("dims", dims); + b.field("index", true); + b.field("similarity", "dot_product"); + b.startObject("index_options"); + b.field("type", "hnsw"); + b.field("m", 5); + b.field("ef_construction", 50); + b.field("max_search_ef", 0); // Invalid value + b.endObject(); + }))); + assertThat(e.getMessage(), containsString("Failed to parse mapping: [max_search_ef] must be greater than 0")); + } + @Override protected IngestScriptSupport ingestScriptSupport() { throw new AssumptionViolatedException("not supported"); diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index 81120c4fec247..c9b5fad53a09b 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -51,14 +51,22 @@ public DenseVectorFieldTypeTests() { private DenseVectorFieldMapper.IndexOptions randomIndexOptionsNonQuantized() { return randomFrom( - new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000), randomIntBetween(1_000, 10_000)), + new DenseVectorFieldMapper.HnswIndexOptions( + randomIntBetween(1, 100), + randomIntBetween(1, 10_000), + randomIntBetween(1_000, 10_000) + ), new DenseVectorFieldMapper.FlatIndexOptions() ); } private DenseVectorFieldMapper.IndexOptions randomIndexOptionsAll() { return randomFrom( - new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000), randomIntBetween(1_000, 10_000)), + new DenseVectorFieldMapper.HnswIndexOptions( + randomIntBetween(1, 100), + randomIntBetween(1, 10_000), + randomIntBetween(1_000, 10_000) + ), new DenseVectorFieldMapper.Int8HnswIndexOptions( randomIntBetween(1, 100), randomIntBetween(1, 10_000), @@ -74,7 +82,11 @@ private DenseVectorFieldMapper.IndexOptions randomIndexOptionsAll() { new DenseVectorFieldMapper.FlatIndexOptions(), new DenseVectorFieldMapper.Int8FlatIndexOptions(randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true))), new DenseVectorFieldMapper.Int4FlatIndexOptions(randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true))), - new DenseVectorFieldMapper.BBQHnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000), randomIntBetween(1_000, 10_000)), + new DenseVectorFieldMapper.BBQHnswIndexOptions( + randomIntBetween(1, 100), + randomIntBetween(1, 10_000), + randomIntBetween(1_000, 10_000) + ), new DenseVectorFieldMapper.BBQFlatIndexOptions() ); } @@ -93,7 +105,11 @@ private DenseVectorFieldMapper.IndexOptions randomIndexOptionsHnswQuantized() { randomIntBetween(1_000, 10_000), randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)) ), - new DenseVectorFieldMapper.BBQHnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000), randomIntBetween(1_000, 10_000)) + new DenseVectorFieldMapper.BBQHnswIndexOptions( + randomIntBetween(1, 100), + randomIntBetween(1, 10_000), + randomIntBetween(1_000, 10_000) + ) ); } @@ -176,37 +192,6 @@ public void testFetchSourceValue() throws IOException { assertEquals(vector, fetchSourceValue(bft, vector)); } - public void testValidateNumCandidates() { - // Test case where numCands is less than or equal to maxSearchEf - { - int maxSearchEf = randomIntBetween(1, 1000); - int numCands = randomIntBetween(0, maxSearchEf); - DenseVectorFieldMapper.IndexOptions indexOptions = new DenseVectorFieldMapper.HnswIndexOptions( - randomIntBetween(1, 100), - randomIntBetween(1, 10_000), - maxSearchEf - ); - indexOptions.validateNumCandidates(numCands); - // No exception should be thrown - } - - // Test case where numCands is greater than maxSearchEf - { - int maxSearchEf = randomIntBetween(1, 1000); - int numCands = randomIntBetween(maxSearchEf + 1, maxSearchEf + 1000); - DenseVectorFieldMapper.IndexOptions indexOptions = new DenseVectorFieldMapper.HnswIndexOptions( - randomIntBetween(1, 100), - randomIntBetween(1, 10_000), - maxSearchEf - ); - IllegalArgumentException e = expectThrows( - IllegalArgumentException.class, - () -> indexOptions.validateNumCandidates(numCands) - ); - assertThat(e.getMessage(), containsString("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + maxSearchEf + "]")); - } - } - public void testCreateNestedKnnQuery() { BitSetProducer producer = context -> null; @@ -414,11 +399,7 @@ public void testCreateKnnQuerValidateNumCandidates() { dims, true, VectorSimilarity.COSINE, - new DenseVectorFieldMapper.HnswIndexOptions( - randomIntBetween(1, 100), - randomIntBetween(1, 10_000), - 1000 - ), + new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000), 1000), Collections.emptyMap() ); float[] queryVector = new float[dims]; diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java index 8cca3f9ed8a21..a85e9a933b288 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java @@ -238,14 +238,6 @@ public void testNumCandsLessThanK() { assertThat(e.getMessage(), containsString("[num_candidates] cannot be less than [k]")); } - public void testNumCandsExceedsLimit() { - IllegalArgumentException e = expectThrows( - IllegalArgumentException.class, - () -> new KnnSearchBuilder("field", randomVector(3), 100, 10002, null, null) - ); - assertThat(e.getMessage(), containsString("[num_candidates] cannot exceed [10000]")); - } - public void testInvalidK() { IllegalArgumentException e = expectThrows( IllegalArgumentException.class, diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchRequestParserTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchRequestParserTests.java index 4e4d2158a9574..bd1d192283f2c 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchRequestParserTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchRequestParserTests.java @@ -179,22 +179,6 @@ public void testNumCandsLessThanK() throws IOException { assertThat(e.getMessage(), containsString("[num_candidates] cannot be less than [k]")); } - public void testNumCandsExceedsLimit() throws IOException { - XContentType xContentType = randomFrom(XContentType.values()); - XContentBuilder builder = XContentBuilder.builder(xContentType.xContent()) - .startObject() - .startObject(KnnSearchRequestParser.KNN_SECTION_FIELD.getPreferredName()) - .field(KnnSearch.FIELD_FIELD.getPreferredName(), "field") - .field(KnnSearch.K_FIELD.getPreferredName(), 100) - .field(KnnSearch.NUM_CANDS_FIELD.getPreferredName(), 10002) - .field(KnnSearch.QUERY_VECTOR_FIELD.getPreferredName(), new float[] { 1.0f, 2.0f, 3.0f }) - .endObject() - .endObject(); - - IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> parseSearchRequest(builder)); - assertThat(e.getMessage(), containsString("[num_candidates] cannot exceed [10000]")); - } - public void testInvalidK() throws IOException { XContentType xContentType = randomFrom(XContentType.values()); XContentBuilder builder = XContentBuilder.builder(xContentType.xContent()) From 521162892257ae8fdcd6de119e2f5905e4f107d8 Mon Sep 17 00:00:00 2001 From: weizijun Date: Mon, 17 Mar 2025 18:33:42 +0800 Subject: [PATCH 4/4] spotless --- .../index/mapper/vectors/DenseVectorFieldMapper.java | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 1c56cf0133db1..f22e50e391382 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -1289,7 +1289,9 @@ public final boolean equals(Object other) { return false; } IndexOptions otherOptions = (IndexOptions) other; - return Objects.equals(type, otherOptions.type) && Objects.equals(rescoreVector, otherOptions.rescoreVector) && doEquals(otherOptions); + return Objects.equals(type, otherOptions.type) + && Objects.equals(rescoreVector, otherOptions.rescoreVector) + && doEquals(otherOptions); } @Override @@ -2265,8 +2267,7 @@ && isNotUnitVector(squaredMagnitude)) { // By default utilize the quantized oversample is configured // allow the user provided at query time overwrite Float oversample = queryOversample; - if (oversample == null - && indexOptions.rescoreVector != null) { + if (oversample == null && indexOptions.rescoreVector != null) { oversample = indexOptions.rescoreVector.oversample; } boolean rescore = needsRescore(oversample);