Skip to content

Commit 248498d

Browse files
authored
Simplify vector field classes (#96925)
To customize the maximum number of vector dimensions, all that is needed is a custom FieldType that works around the maximum-dimension limit in the Lucene field type. Custom field classes are unnecessary, so this commit removes them.
1 parent 2f93cf5 commit 248498d

File tree

8 files changed

+74
-158
lines changed

8 files changed

+74
-158
lines changed

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@
1212
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat;
1313
import org.apache.lucene.document.BinaryDocValuesField;
1414
import org.apache.lucene.document.Field;
15+
import org.apache.lucene.document.FieldType;
1516
import org.apache.lucene.document.KnnByteVectorField;
1617
import org.apache.lucene.document.KnnFloatVectorField;
1718
import org.apache.lucene.index.BinaryDocValues;
1819
import org.apache.lucene.index.ByteVectorValues;
1920
import org.apache.lucene.index.FloatVectorValues;
2021
import org.apache.lucene.index.LeafReader;
22+
import org.apache.lucene.index.VectorEncoding;
2123
import org.apache.lucene.index.VectorSimilarityFunction;
2224
import org.apache.lucene.search.FieldExistsQuery;
2325
import org.apache.lucene.search.KnnByteVectorQuery;
@@ -171,6 +173,40 @@ public DenseVectorFieldMapper build(MapperBuilderContext context) {
171173
}
172174
}
173175

176+
private static FieldType getDenseVectorFieldType(
177+
int dimension,
178+
VectorEncoding vectorEncoding,
179+
VectorSimilarityFunction similarityFunction
180+
) {
181+
if (dimension == 0) {
182+
throw new IllegalArgumentException("cannot index an empty vector");
183+
}
184+
if (dimension > DenseVectorFieldMapper.MAX_DIMS_COUNT) {
185+
throw new IllegalArgumentException("cannot index vectors with dimension greater than " + DenseVectorFieldMapper.MAX_DIMS_COUNT);
186+
}
187+
if (similarityFunction == null) {
188+
throw new IllegalArgumentException("similarity function must not be null");
189+
}
190+
FieldType fieldType = new FieldType() {
191+
@Override
192+
public int vectorDimension() {
193+
return dimension;
194+
}
195+
196+
@Override
197+
public VectorEncoding vectorEncoding() {
198+
return vectorEncoding;
199+
}
200+
201+
@Override
202+
public VectorSimilarityFunction vectorSimilarityFunction() {
203+
return similarityFunction;
204+
}
205+
};
206+
fieldType.freeze();
207+
return fieldType;
208+
}
209+
174210
public enum ElementType {
175211

176212
BYTE(1) {
@@ -192,7 +228,11 @@ public void readAndWriteValue(ByteBuffer byteBuffer, XContentBuilder b) throws I
192228

193229
@Override
194230
KnnByteVectorField createKnnVectorField(String name, byte[] vector, VectorSimilarityFunction function) {
195-
return new XKnnByteVectorField(name, vector, function);
231+
if (vector == null) {
232+
throw new IllegalArgumentException("vector value must not be null");
233+
}
234+
FieldType denseVectorFieldType = getDenseVectorFieldType(vector.length, VectorEncoding.BYTE, function);
235+
return new KnnByteVectorField(name, vector, denseVectorFieldType);
196236
}
197237

198238
@Override
@@ -382,7 +422,11 @@ public void readAndWriteValue(ByteBuffer byteBuffer, XContentBuilder b) throws I
382422

383423
@Override
384424
KnnFloatVectorField createKnnVectorField(String name, float[] vector, VectorSimilarityFunction function) {
385-
return new XKnnFloatVectorField(name, vector, function);
425+
if (vector == null) {
426+
throw new IllegalArgumentException("vector value must not be null");
427+
}
428+
FieldType denseVectorFieldType = getDenseVectorFieldType(vector.length, VectorEncoding.FLOAT32, function);
429+
return new KnnFloatVectorField(name, vector, denseVectorFieldType);
386430
}
387431

388432
@Override

server/src/main/java/org/elasticsearch/index/mapper/vectors/XKnnByteVectorField.java

Lines changed: 0 additions & 65 deletions
This file was deleted.

server/src/main/java/org/elasticsearch/index/mapper/vectors/XKnnFloatVectorField.java

Lines changed: 0 additions & 64 deletions
This file was deleted.

server/src/test/java/org/elasticsearch/action/admin/indices/diskusage/IndexDiskUsageAnalyzerTests.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.apache.lucene.document.Field;
2525
import org.apache.lucene.document.FieldType;
2626
import org.apache.lucene.document.IntPoint;
27+
import org.apache.lucene.document.KnnFloatVectorField;
2728
import org.apache.lucene.document.LatLonShape;
2829
import org.apache.lucene.document.LongPoint;
2930
import org.apache.lucene.document.NumericDocValuesField;
@@ -66,7 +67,6 @@
6667
import org.apache.lucene.util.FixedBitSet;
6768
import org.elasticsearch.common.lucene.Lucene;
6869
import org.elasticsearch.core.IOUtils;
69-
import org.elasticsearch.index.mapper.vectors.XKnnFloatVectorField;
7070
import org.elasticsearch.index.shard.ShardId;
7171
import org.elasticsearch.index.store.LuceneFilesExtensions;
7272
import org.elasticsearch.test.ESTestCase;
@@ -257,7 +257,7 @@ public void testKnnVectors() throws Exception {
257257

258258
indexRandomly(dir, codec, numDocs, doc -> {
259259
float[] vector = randomVector(dimension);
260-
doc.add(new XKnnFloatVectorField("vector", vector, similarity));
260+
doc.add(new KnnFloatVectorField("vector", vector, similarity));
261261
});
262262
final IndexDiskUsageStats stats = IndexDiskUsageAnalyzer.analyze(testShardId(), lastCommit(dir), () -> {});
263263
logger.info("--> stats {}", stats);
@@ -520,7 +520,7 @@ static void addRandomTermVectors(Document doc) {
520520
static void addRandomKnnVectors(Document doc) {
521521
int numFields = randomFrom(1, 3);
522522
for (int f = 0; f < numFields; f++) {
523-
doc.add(new XKnnFloatVectorField("knnvector-" + f, randomVector(DEFAULT_VECTOR_DIMENSION)));
523+
doc.add(new KnnFloatVectorField("knnvector-" + f, randomVector(DEFAULT_VECTOR_DIMENSION)));
524524
}
525525
}
526526

server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import org.apache.lucene.codecs.KnnVectorsFormat;
1515
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat;
1616
import org.apache.lucene.document.BinaryDocValuesField;
17+
import org.apache.lucene.document.KnnByteVectorField;
18+
import org.apache.lucene.document.KnnFloatVectorField;
1719
import org.apache.lucene.index.IndexableField;
1820
import org.apache.lucene.index.VectorEncoding;
1921
import org.apache.lucene.search.FieldExistsQuery;
@@ -256,9 +258,9 @@ public void testIndexedVector() throws Exception {
256258

257259
List<IndexableField> fields = doc1.rootDoc().getFields("field");
258260
assertEquals(1, fields.size());
259-
assertThat(fields.get(0), instanceOf(XKnnFloatVectorField.class));
261+
assertThat(fields.get(0), instanceOf(KnnFloatVectorField.class));
260262

261-
XKnnFloatVectorField vectorField = (XKnnFloatVectorField) fields.get(0);
263+
KnnFloatVectorField vectorField = (KnnFloatVectorField) fields.get(0);
262264
assertArrayEquals("Parsed vector is not equal to original.", vector, vectorField.vectorValue(), 0.001f);
263265
assertEquals(similarity.function, vectorField.fieldType().vectorSimilarityFunction());
264266
}
@@ -280,9 +282,9 @@ public void testIndexedByteVector() throws Exception {
280282

281283
List<IndexableField> fields = doc1.rootDoc().getFields("field");
282284
assertEquals(1, fields.size());
283-
assertThat(fields.get(0), instanceOf(XKnnByteVectorField.class));
285+
assertThat(fields.get(0), instanceOf(KnnByteVectorField.class));
284286

285-
XKnnByteVectorField vectorField = (XKnnByteVectorField) fields.get(0);
287+
KnnByteVectorField vectorField = (KnnByteVectorField) fields.get(0);
286288
vectorField.vectorValue();
287289
assertArrayEquals(
288290
"Parsed vector is not equal to original.",
@@ -514,7 +516,7 @@ public void testDocumentsWithIncorrectDims() throws Exception {
514516

515517
/**
516518
* Test that max dimensions limit for float dense_vector field
517-
* is 2048 as defined by {@link XKnnFloatVectorField}
519+
* is 2048 as defined by {@link DenseVectorFieldMapper#MAX_DIMS_COUNT}
518520
*/
519521
public void testMaxDimsFloatVector() throws IOException {
520522
final int dims = 2048;
@@ -531,8 +533,8 @@ public void testMaxDimsFloatVector() throws IOException {
531533
List<IndexableField> fields = doc1.rootDoc().getFields("field");
532534

533535
assertEquals(1, fields.size());
534-
assertThat(fields.get(0), instanceOf(XKnnFloatVectorField.class));
535-
XKnnFloatVectorField vectorField = (XKnnFloatVectorField) fields.get(0);
536+
assertThat(fields.get(0), instanceOf(KnnFloatVectorField.class));
537+
KnnFloatVectorField vectorField = (KnnFloatVectorField) fields.get(0);
536538
assertEquals(dims, vectorField.fieldType().vectorDimension());
537539
assertEquals(VectorEncoding.FLOAT32, vectorField.fieldType().vectorEncoding());
538540
assertEquals(similarity.function, vectorField.fieldType().vectorSimilarityFunction());
@@ -541,7 +543,7 @@ public void testMaxDimsFloatVector() throws IOException {
541543

542544
/**
543545
* Test that max dimensions limit for byte dense_vector field
544-
* is 2048 as defined by {@link XKnnByteVectorField}
546+
* is 2048 as defined by {@link DenseVectorFieldMapper#MAX_DIMS_COUNT}
545547
*/
546548
public void testMaxDimsByteVector() throws IOException {
547549
final int dims = 2048;
@@ -565,8 +567,8 @@ public void testMaxDimsByteVector() throws IOException {
565567
List<IndexableField> fields = doc1.rootDoc().getFields("field");
566568

567569
assertEquals(1, fields.size());
568-
assertThat(fields.get(0), instanceOf(XKnnByteVectorField.class));
569-
XKnnByteVectorField vectorField = (XKnnByteVectorField) fields.get(0);
570+
assertThat(fields.get(0), instanceOf(KnnByteVectorField.class));
571+
KnnByteVectorField vectorField = (KnnByteVectorField) fields.get(0);
570572
assertEquals(dims, vectorField.fieldType().vectorDimension());
571573
assertEquals(VectorEncoding.BYTE, vectorField.fieldType().vectorEncoding());
572574
assertEquals(similarity.function, vectorField.fieldType().vectorSimilarityFunction());

server/src/test/java/org/elasticsearch/search/SearchCancellationTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.apache.lucene.document.Document;
1111
import org.apache.lucene.document.Field;
1212
import org.apache.lucene.document.IntPoint;
13+
import org.apache.lucene.document.KnnFloatVectorField;
1314
import org.apache.lucene.document.StringField;
1415
import org.apache.lucene.index.FloatVectorValues;
1516
import org.apache.lucene.index.IndexReader;
@@ -27,7 +28,6 @@
2728
import org.apache.lucene.util.automaton.CompiledAutomaton;
2829
import org.apache.lucene.util.automaton.RegExp;
2930
import org.elasticsearch.core.IOUtils;
30-
import org.elasticsearch.index.mapper.vectors.XKnnFloatVectorField;
3131
import org.elasticsearch.search.internal.ContextIndexSearcher;
3232
import org.elasticsearch.tasks.TaskCancelledException;
3333
import org.elasticsearch.test.ESTestCase;
@@ -67,7 +67,7 @@ private static void indexRandomDocuments(RandomIndexWriter w, int numDocs) throw
6767
Document doc = new Document();
6868
doc.add(new StringField(STRING_FIELD_NAME, "a".repeat(i), Field.Store.NO));
6969
doc.add(new IntPoint(POINT_FIELD_NAME, i, i + 1));
70-
doc.add(new XKnnFloatVectorField(KNN_FIELD_NAME, new float[] { 1.0f, 0.5f, 42.0f }));
70+
doc.add(new KnnFloatVectorField(KNN_FIELD_NAME, new float[] { 1.0f, 0.5f, 42.0f }));
7171
w.addDocument(doc);
7272
}
7373
}

server/src/test/java/org/elasticsearch/search/vectors/VectorSimilarityQueryTests.java

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import org.apache.lucene.search.Weight;
2626
import org.apache.lucene.store.Directory;
2727
import org.elasticsearch.common.lucene.LuceneTests;
28-
import org.elasticsearch.index.mapper.vectors.XKnnFloatVectorField;
2928
import org.elasticsearch.test.ESTestCase;
3029

3130
import java.io.IOException;
@@ -41,7 +40,7 @@ public void testSimpleEuclidean() throws Exception {
4140
try (Directory d = newDirectory()) {
4241
try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
4342
Document document = new Document();
44-
KnnFloatVectorField vectorField = new XKnnFloatVectorField("float_vector", new float[] { 1, 1, 1 });
43+
KnnFloatVectorField vectorField = new KnnFloatVectorField("float_vector", new float[] { 1, 1, 1 });
4544
document.add(vectorField);
4645
w.addDocument(document);
4746
vectorField.setVectorValue(new float[] { 2, 1, 1 });
@@ -82,7 +81,7 @@ public void testEuclideanInvariant() throws Exception {
8281
Supplier<float[]> vectorValue = () -> new float[] { randomFloat(), randomFloat(), randomFloat(), randomFloat() };
8382
try (Directory d = newDirectory()) {
8483
try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
85-
KnnFloatVectorField vectorField = new XKnnFloatVectorField(fieldName, vectorValue.get());
84+
KnnFloatVectorField vectorField = new KnnFloatVectorField(fieldName, vectorValue.get());
8685
Document document = new Document();
8786
document.add(vectorField);
8887
for (int i = 0; i < n; i++) {
@@ -113,7 +112,7 @@ public void testSimpleCosine() throws IOException {
113112
try (Directory d = newDirectory()) {
114113
try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
115114
Document document = new Document();
116-
KnnFloatVectorField vectorField = new XKnnFloatVectorField(
115+
KnnFloatVectorField vectorField = new KnnFloatVectorField(
117116
"float_vector",
118117
new float[] { 1, 1, 1 },
119118
VectorSimilarityFunction.COSINE
@@ -158,7 +157,7 @@ public void testCosineInvariant() throws Exception {
158157
Supplier<float[]> vectorValue = () -> new float[] { randomFloat(), randomFloat(), randomFloat(), randomFloat() };
159158
try (Directory d = newDirectory()) {
160159
try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
161-
KnnFloatVectorField vectorField = new XKnnFloatVectorField(fieldName, vectorValue.get(), VectorSimilarityFunction.COSINE);
160+
KnnFloatVectorField vectorField = new KnnFloatVectorField(fieldName, vectorValue.get(), VectorSimilarityFunction.COSINE);
162161
Document document = new Document();
163162
document.add(vectorField);
164163
for (int i = 0; i < n; i++) {
@@ -192,7 +191,7 @@ public void testExplain() throws IOException {
192191
try (Directory d = newDirectory()) {
193192
try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
194193
Document document = new Document();
195-
KnnFloatVectorField vectorField = new XKnnFloatVectorField("float_vector", new float[] { 1, 1, 1 });
194+
KnnFloatVectorField vectorField = new KnnFloatVectorField("float_vector", new float[] { 1, 1, 1 });
196195
document.add(vectorField);
197196
w.addDocument(document);
198197
vectorField.setVectorValue(new float[] { 2, 1, 1 });

0 commit comments

Comments
 (0)