Skip to content

Commit c7e48a0

Browse files
committed
Add first version of BlockDocValuesReader for dense_vector
1 parent 6439422 commit c7e48a0

File tree

3 files changed

+193
-0
lines changed

3 files changed

+193
-0
lines changed

server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import org.apache.lucene.index.BinaryDocValues;
1313
import org.apache.lucene.index.DocValues;
14+
import org.apache.lucene.index.FloatVectorValues;
1415
import org.apache.lucene.index.LeafReaderContext;
1516
import org.apache.lucene.index.NumericDocValues;
1617
import org.apache.lucene.index.SortedDocValues;
@@ -504,6 +505,80 @@ public String toString() {
504505
}
505506
}
506507

508+
public static class DenseVectorBlockLoader extends DocValuesBlockLoader {
509+
private final String fieldName;
510+
511+
public DenseVectorBlockLoader(String fieldName) {
512+
this.fieldName = fieldName;
513+
}
514+
515+
@Override
516+
public Builder builder(BlockFactory factory, int expectedCount) {
517+
return factory.doubles(expectedCount);
518+
}
519+
520+
@Override
521+
public AllReader reader(LeafReaderContext context) throws IOException {
522+
FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(fieldName);
523+
if (floatVectorValues != null) {
524+
return new FloatVectorValuesBlockReader(floatVectorValues);
525+
}
526+
return new ConstantNullsReader();
527+
}
528+
}
529+
530+
private static class FloatVectorValuesBlockReader extends BlockDocValuesReader {
531+
private final FloatVectorValues floatVectorValues;
532+
private int docId = -1;
533+
534+
FloatVectorValuesBlockReader(FloatVectorValues floatVectorValues) {
535+
this.floatVectorValues = floatVectorValues;
536+
}
537+
538+
@Override
539+
public BlockLoader.Block read(BlockFactory factory, Docs docs) throws IOException {
540+
try (BlockLoader.DoubleBuilder builder = factory.doubles(docs.count())) {
541+
for (int i = 0; i < docs.count(); i++) {
542+
int doc = docs.get(i);
543+
if (doc < docId) {
544+
throw new IllegalStateException("docs within same block must be in order");
545+
}
546+
read(doc, builder);
547+
}
548+
return builder.build();
549+
}
550+
}
551+
552+
@Override
553+
public void read(int docId, BlockLoader.StoredFields storedFields, Builder builder) throws IOException {
554+
read(docId, (DoubleBuilder) builder);
555+
}
556+
557+
private void read(int doc, DoubleBuilder builder) throws IOException {
558+
float[] floats = floatVectorValues.vectorValue(doc);
559+
if (floats != null) {
560+
builder.beginPositionEntry();
561+
for (float aFloat : floats) {
562+
builder.appendDouble(aFloat);
563+
}
564+
builder.endPositionEntry();
565+
} else {
566+
builder.appendNull();
567+
}
568+
docId = doc;
569+
}
570+
571+
@Override
572+
public int docId() {
573+
return docId;
574+
}
575+
576+
@Override
577+
public String toString() {
578+
return "BlockDocValuesReader.FloatVectorValuesBlockReader";
579+
}
580+
}
581+
507582
public static class BytesRefsFromOrdsBlockLoader extends DocValuesBlockLoader {
508583
private final String fieldName;
509584

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import org.apache.lucene.search.join.BitSetProducer;
3636
import org.apache.lucene.util.BitUtil;
3737
import org.apache.lucene.util.BytesRef;
38+
import org.apache.lucene.util.NumericUtils;
3839
import org.apache.lucene.util.VectorUtil;
3940
import org.elasticsearch.common.ParsingException;
4041
import org.elasticsearch.common.xcontent.support.XContentMapValues;
@@ -51,6 +52,9 @@
5152
import org.elasticsearch.index.fielddata.FieldDataContext;
5253
import org.elasticsearch.index.fielddata.IndexFieldData;
5354
import org.elasticsearch.index.mapper.ArraySourceValueFetcher;
55+
import org.elasticsearch.index.mapper.BlockDocValuesReader;
56+
import org.elasticsearch.index.mapper.BlockLoader;
57+
import org.elasticsearch.index.mapper.BlockSourceReader;
5458
import org.elasticsearch.index.mapper.DocumentParserContext;
5559
import org.elasticsearch.index.mapper.FieldMapper;
5660
import org.elasticsearch.index.mapper.MappedFieldType;
@@ -2306,6 +2310,13 @@ int getVectorDimensions() {
23062310
ElementType getElementType() {
23072311
return elementType;
23082312
}
2313+
2314+
@Override
2315+
public BlockLoader blockLoader(MappedFieldType.BlockLoaderContext blContext) {
2316+
// TODO Check synthetic source etc (see NumberFieldMapper)
2317+
2318+
return new BlockDocValuesReader.DenseVectorBlockLoader(name());
2319+
}
23092320
}
23102321

23112322
private final IndexOptions indexOptions;
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql;
9+
10+
import org.elasticsearch.action.bulk.BulkRequestBuilder;
11+
import org.elasticsearch.action.index.IndexRequest;
12+
import org.elasticsearch.action.support.WriteRequest;
13+
import org.elasticsearch.common.settings.Settings;
14+
import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase;
15+
import org.hamcrest.CoreMatchers;
16+
import org.junit.Before;
17+
18+
import java.util.HashMap;
19+
import java.util.List;
20+
import java.util.Map;
21+
22+
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
23+
24+
public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase {
25+
26+
private static Map<Integer, List<Float>> DOC_VALUES = new HashMap<>();
27+
static {
28+
DOC_VALUES.put(1, List.of(1.0f, 2.0f, 3.0f));
29+
DOC_VALUES.put(2, List.of(4.0f, 5.0f, 6.0f));
30+
DOC_VALUES.put(3, List.of(7.0f, 8.0f, 9.0f));
31+
DOC_VALUES.put(4, List.of(10.0f, 11.0f, 12.0f));
32+
DOC_VALUES.put(5, List.of(13.0f, 14.0f, 15.0f));
33+
DOC_VALUES.put(6, List.of(16.0f, 17.0f, 18.0f));
34+
}
35+
36+
public void testRetrieveFieldType() {
37+
var query = """
38+
FROM test
39+
""";
40+
41+
try (var resp = run(query)) {
42+
assertColumnNames(resp.columns(), List.of("id", "vector"));
43+
assertColumnTypes(resp.columns(), List.of("integer", "dense_vector"));
44+
}
45+
}
46+
47+
@SuppressWarnings("unchecked")
48+
public void testRetrieveDenseVectorFieldData() {
49+
var query = """
50+
FROM test
51+
| SORT id ASC
52+
""";
53+
54+
try (var resp = run(query)) {
55+
List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
56+
DOC_VALUES.forEach((id, vector) -> {
57+
var values = valuesList.get(id - 1);
58+
assertEquals(id.intValue(), ((Long) values.get(0)).intValue());
59+
List<Double> scores = (List<Double>) values.get(1);
60+
assertEquals(vector.size(), scores.size());
61+
for (int i = 0; i < vector.size(); i++) {
62+
assertEquals((float)vector.get(i), scores.get(i).floatValue(), 0F);
63+
}
64+
});
65+
}
66+
}
67+
68+
public void testSorted() {
69+
var query = """
70+
FROM test
71+
""";
72+
73+
try (var resp = run(query)) {
74+
assertColumnNames(resp.columns(), List.of("id", "vector"));
75+
assertColumnTypes(resp.columns(), List.of("integer", "dense_vector"));
76+
}
77+
}
78+
79+
@Before
80+
public void setup() {
81+
var indexName = "test";
82+
var client = client().admin().indices();
83+
var mapping = """
84+
"id": integer,
85+
"vector": {
86+
"type": "dense_vector",
87+
"index_options": {
88+
"type": "hnsw"
89+
}
90+
}
91+
""";
92+
var CreateRequest = client.prepareCreate(indexName)
93+
.setSettings(Settings.builder().put("index.number_of_shards", 1))
94+
.setMapping(mapping);
95+
assertAcked(CreateRequest);
96+
97+
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();
98+
for (var entry : DOC_VALUES.entrySet()) {
99+
bulkRequestBuilder.add(
100+
new IndexRequest(indexName).id(entry.getKey().toString()).source("id", entry.getKey(), "vector", entry.getValue())
101+
);
102+
}
103+
104+
bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE).get();
105+
ensureYellow(indexName);
106+
}
107+
}

0 commit comments

Comments
 (0)