Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -264,9 +264,6 @@ private KnnSearchBuilder(
"[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot be less than " + "[" + K_FIELD.getPreferredName() + "]"
);
}
if (numCandidates > NUM_CANDS_LIMIT) {
throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]");
}
if (queryVector == null && queryVectorBuilder == null) {
throw new IllegalArgumentException(
format(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,6 @@ private KnnVectorQueryBuilder(
if (k != null && k < 1) {
throw new IllegalArgumentException("[" + K_FIELD.getPreferredName() + "] must be greater than 0");
}
if (numCands != null && numCands > NUM_CANDS_LIMIT) {
throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]");
}
if (k != null && numCands != null && numCands < k) {
throw new IllegalArgumentException(
"[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot be less than [" + K_FIELD.getPreferredName() + "]"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@

import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.NUM_CANDS_LIMIT;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
Expand Down Expand Up @@ -853,6 +854,23 @@ protected void registerParameters(ParameterChecker checker) throws IOException {
.endObject()
)
);

checker.registerUpdateCheck(
b -> b.field("type", "dense_vector")
.field("dims", dims)
.field("index", true)
.startObject("index_options")
.field("type", "int4_hnsw")
.endObject(),
b -> b.field("type", "dense_vector")
.field("dims", dims)
.field("index", true)
.startObject("index_options")
.field("type", "int4_hnsw")
.field("max_search_ef", 1000)
.endObject(),
m -> assertTrue(m.toString().contains("\"max_search_ef\":1000"))
);
}

@Override
Expand Down Expand Up @@ -1052,6 +1070,7 @@ public void testMergeDims() throws IOException {
.field("type", "int8_hnsw")
.field("m", 16)
.field("ef_construction", 100)
.field("max_search_ef", NUM_CANDS_LIMIT)
.endObject();
b.endObject();
});
Expand Down Expand Up @@ -2205,6 +2224,22 @@ public void testInvalidVectorDimensions() {
}
}

public void testMaxSearchEfBounds() {
Exception e = expectThrows(MapperParsingException.class, () -> createDocumentMapper(fieldMapping(b -> {
b.field("type", "dense_vector");
b.field("dims", dims);
b.field("index", true);
b.field("similarity", "dot_product");
b.startObject("index_options");
b.field("type", "hnsw");
b.field("m", 5);
b.field("ef_construction", 50);
b.field("max_search_ef", 0); // Invalid value
b.endObject();
})));
assertThat(e.getMessage(), containsString("Failed to parse mapping: [max_search_ef] must be greater than 0"));
}

@Override
protected IngestScriptSupport ingestScriptSupport() {
throw new AssumptionViolatedException("not supported");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.BBQ_MIN_DIMS;
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.BYTE;
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.FLOAT;
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.OVERSAMPLE_LIMIT;
import static org.elasticsearch.search.vectors.KnnSearchBuilder.NUM_CANDS_FIELD;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
Expand All @@ -55,23 +55,33 @@ private static DenseVectorFieldMapper.RescoreVector randomRescoreVector() {

private DenseVectorFieldMapper.IndexOptions randomIndexOptionsNonQuantized() {
return randomFrom(
new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000)),
new DenseVectorFieldMapper.HnswIndexOptions(
randomIntBetween(1, 100),
randomIntBetween(1, 10_000),
randomIntBetween(1_000, 10_000)
),
new DenseVectorFieldMapper.FlatIndexOptions()
);
}

private DenseVectorFieldMapper.IndexOptions randomIndexOptionsAll() {
return randomFrom(
new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000)),
new DenseVectorFieldMapper.HnswIndexOptions(
randomIntBetween(1, 100),
randomIntBetween(1, 10_000),
randomIntBetween(1_000, 10_000)
),
new DenseVectorFieldMapper.Int8HnswIndexOptions(
randomIntBetween(1, 100),
randomIntBetween(1, 10_000),
randomIntBetween(1_000, 10_000),
randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)),
randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector())
),
new DenseVectorFieldMapper.Int4HnswIndexOptions(
randomIntBetween(1, 100),
randomIntBetween(1, 10_000),
randomIntBetween(1_000, 10_000),
randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)),
randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector())
),
Expand All @@ -87,6 +97,7 @@ private DenseVectorFieldMapper.IndexOptions randomIndexOptionsAll() {
new DenseVectorFieldMapper.BBQHnswIndexOptions(
randomIntBetween(1, 100),
randomIntBetween(1, 10_000),
randomIntBetween(1_000, 10_000),
randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector())
),
new DenseVectorFieldMapper.BBQFlatIndexOptions(randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector()))
Expand All @@ -98,18 +109,21 @@ private DenseVectorFieldMapper.IndexOptions randomIndexOptionsHnswQuantized() {
new DenseVectorFieldMapper.Int8HnswIndexOptions(
randomIntBetween(1, 100),
randomIntBetween(1, 10_000),
randomIntBetween(1_000, 10_000),
randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)),
randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector())
),
new DenseVectorFieldMapper.Int4HnswIndexOptions(
randomIntBetween(1, 100),
randomIntBetween(1, 10_000),
randomIntBetween(1_000, 10_000),
randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)),
randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector())
),
new DenseVectorFieldMapper.BBQHnswIndexOptions(
randomIntBetween(1, 100),
randomIntBetween(1, 10_000),
randomIntBetween(1_000, 10_000),
randomFrom((DenseVectorFieldMapper.RescoreVector) null, randomRescoreVector())
)
);
Expand Down Expand Up @@ -398,6 +412,35 @@ public void testCreateKnnQueryMaxDims() {
}
}

public void testCreateKnnQuerValidateNumCandidates() {
int dims = randomIntBetween(BBQ_MIN_DIMS, 2048);
DenseVectorFieldType field = new DenseVectorFieldType(
"f",
IndexVersion.current(),
FLOAT,
dims,
true,
VectorSimilarity.COSINE,
new DenseVectorFieldMapper.HnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000), 1000),
Collections.emptyMap()
);
float[] queryVector = new float[dims];
for (int i = 0; i < dims; i++) {
queryVector[i] = randomFloat();
}

int numCands = 500;
Query query = field.createKnnQuery(VectorData.fromFloats(queryVector), 10, numCands, null, null, null, null);
assertThat(query, instanceOf(ESKnnFloatVectorQuery.class));

int numCands2 = 1500;
IllegalArgumentException e = expectThrows(
IllegalArgumentException.class,
() -> field.createKnnQuery(VectorData.fromFloats(queryVector), 10, numCands2, null, null, null, null)
);
assertThat(e.getMessage(), containsString("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [1000]"));
}

public void testByteCreateKnnQuery() {
DenseVectorFieldType unindexedField = new DenseVectorFieldType(
"f",
Expand Down Expand Up @@ -473,14 +516,15 @@ public void testRescoreOversampleUsedWithoutQuantization() {
}

public void testRescoreOversampleModifiesNumCandidates() {
DenseVectorFieldMapper.IndexOptions indexOptions = randomIndexOptionsHnswQuantized();
DenseVectorFieldType fieldType = new DenseVectorFieldType(
"f",
IndexVersion.current(),
FLOAT,
3,
true,
VectorSimilarity.COSINE,
randomIndexOptionsHnswQuantized(),
indexOptions,
Collections.emptyMap()
);

Expand All @@ -489,7 +533,7 @@ public void testRescoreOversampleModifiesNumCandidates() {
// If numCands < k, update numCands to k
checkRescoreQueryParameters(fieldType, 10, 20, 2.5F, 25, 25, 10);
// Oversampling limits for k
checkRescoreQueryParameters(fieldType, 1000, 1000, 11.0F, OVERSAMPLE_LIMIT, OVERSAMPLE_LIMIT, 1000);
checkRescoreQueryParameters(fieldType, 1000, 1000, 11.0F, indexOptions.maxSearchEf(), indexOptions.maxSearchEf(), 1000);
}

private static void checkRescoreQueryParameters(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.OVERSAMPLE_LIMIT;
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.NUM_CANDS_LIMIT;
import static org.elasticsearch.search.SearchService.DEFAULT_SIZE;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
Expand Down Expand Up @@ -199,7 +199,7 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que
Integer numCands = queryBuilder.numCands();
if (queryBuilder.rescoreVectorBuilder() != null && isQuantizedElementType()) {
Float oversample = queryBuilder.rescoreVectorBuilder().oversample();
k = Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * oversample));
k = Math.min(NUM_CANDS_LIMIT, (int) Math.ceil(k * oversample));
numCands = Math.max(numCands, k);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,14 +238,6 @@ public void testNumCandsLessThanK() {
assertThat(e.getMessage(), containsString("[num_candidates] cannot be less than [k]"));
}

public void testNumCandsExceedsLimit() {
IllegalArgumentException e = expectThrows(
IllegalArgumentException.class,
() -> new KnnSearchBuilder("field", randomVector(3), 100, 10002, null, null)
);
assertThat(e.getMessage(), containsString("[num_candidates] cannot exceed [10000]"));
}

public void testInvalidK() {
IllegalArgumentException e = expectThrows(
IllegalArgumentException.class,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,22 +179,6 @@ public void testNumCandsLessThanK() throws IOException {
assertThat(e.getMessage(), containsString("[num_candidates] cannot be less than [k]"));
}

public void testNumCandsExceedsLimit() throws IOException {
XContentType xContentType = randomFrom(XContentType.values());
XContentBuilder builder = XContentBuilder.builder(xContentType.xContent())
.startObject()
.startObject(KnnSearchRequestParser.KNN_SECTION_FIELD.getPreferredName())
.field(KnnSearch.FIELD_FIELD.getPreferredName(), "field")
.field(KnnSearch.K_FIELD.getPreferredName(), 100)
.field(KnnSearch.NUM_CANDS_FIELD.getPreferredName(), 10002)
.field(KnnSearch.QUERY_VECTOR_FIELD.getPreferredName(), new float[] { 1.0f, 2.0f, 3.0f })
.endObject()
.endObject();

IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> parseSearchRequest(builder));
assertThat(e.getMessage(), containsString("[num_candidates] cannot exceed [10000]"));
}

public void testInvalidK() throws IOException {
XContentType xContentType = randomFrom(XContentType.values());
XContentBuilder builder = XContentBuilder.builder(xContentType.xContent())
Expand Down