Skip to content

Commit fe65554

Browse files
authored
Merge branch 'main' into remove_resolved_issues
2 parents b158373 + 7a34391 commit fe65554

File tree

8 files changed

+241
-55
lines changed

8 files changed

+241
-55
lines changed

docs/changelog/132689.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 132689
2+
summary: Add support for dimensions in Google Vertex AI request
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838

3939
import java.io.IOException;
4040

41+
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.COSINE_MAGNITUDE_FIELD_SUFFIX;
4142
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.BYTE;
4243

4344
/**
@@ -540,6 +541,11 @@ public AllReader reader(LeafReaderContext context) throws IOException {
540541
case FLOAT -> {
541542
FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(fieldName);
542543
if (floatVectorValues != null) {
544+
if (fieldType.isNormalized()) {
545+
NumericDocValues magnitudeDocValues = context.reader()
546+
.getNumericDocValues(fieldType.name() + COSINE_MAGNITUDE_FIELD_SUFFIX);
547+
return new FloatDenseVectorNormalizedValuesBlockReader(floatVectorValues, dimensions, magnitudeDocValues);
548+
}
543549
return new FloatDenseVectorValuesBlockReader(floatVectorValues, dimensions);
544550
}
545551
}
@@ -584,6 +590,9 @@ public void read(int docId, BlockLoader.StoredFields storedFields, Builder build
584590
}
585591

586592
private void read(int doc, BlockLoader.FloatBuilder builder) throws IOException {
593+
assert vectorValues.dimension() == dimensions
594+
: "unexpected dimensions for vector value; expected " + dimensions + " but got " + vectorValues.dimension();
595+
587596
if (iterator.docID() > doc) {
588597
builder.appendNull();
589598
} else if (iterator.docID() == doc || iterator.advance(doc) == doc) {
@@ -611,8 +620,6 @@ private static class FloatDenseVectorValuesBlockReader extends DenseVectorValues
611620

612621
protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException {
613622
float[] floats = vectorValues.vectorValue(iterator.index());
614-
assert floats.length == dimensions
615-
: "unexpected dimensions for vector value; expected " + dimensions + " but got " + floats.length;
616623
for (float aFloat : floats) {
617624
builder.appendFloat(aFloat);
618625
}
@@ -624,15 +631,45 @@ public String toString() {
624631
}
625632
}
626633

634+
private static class FloatDenseVectorNormalizedValuesBlockReader extends DenseVectorValuesBlockReader<FloatVectorValues> {
635+
private final NumericDocValues magnitudeDocValues;
636+
637+
FloatDenseVectorNormalizedValuesBlockReader(
638+
FloatVectorValues floatVectorValues,
639+
int dimensions,
640+
NumericDocValues magnitudeDocValues
641+
) {
642+
super(floatVectorValues, dimensions);
643+
this.magnitudeDocValues = magnitudeDocValues;
644+
}
645+
646+
@Override
647+
protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException {
648+
float magnitude = 1.0f;
649+
// If all vectors are normalized, no doc values will be present. The vector may be normalized already, so we may not have a
650+
// stored magnitude for all docs
651+
if ((magnitudeDocValues != null) && magnitudeDocValues.advanceExact(iterator.docID())) {
652+
magnitude = Float.intBitsToFloat((int) magnitudeDocValues.longValue());
653+
}
654+
float[] floats = vectorValues.vectorValue(iterator.index());
655+
for (float aFloat : floats) {
656+
builder.appendFloat(aFloat * magnitude);
657+
}
658+
}
659+
660+
@Override
661+
public String toString() {
662+
return "BlockDocValuesReader.FloatDenseVectorNormalizedValuesBlockReader";
663+
}
664+
}
665+
627666
private static class ByteDenseVectorValuesBlockReader extends DenseVectorValuesBlockReader<ByteVectorValues> {
628667
ByteDenseVectorValuesBlockReader(ByteVectorValues floatVectorValues, int dimensions) {
629668
super(floatVectorValues, dimensions);
630669
}
631670

632671
protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException {
633672
byte[] bytes = vectorValues.vectorValue(iterator.index());
634-
assert bytes.length == dimensions
635-
: "unexpected dimensions for vector value; expected " + dimensions + " but got " + bytes.length;
636673
for (byte aFloat : bytes) {
637674
builder.appendFloat(aFloat);
638675
}

x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,29 +24,30 @@
2424

2525
import java.io.IOException;
2626
import java.util.ArrayList;
27+
import java.util.Arrays;
2728
import java.util.HashMap;
2829
import java.util.List;
2930
import java.util.Locale;
3031
import java.util.Map;
3132
import java.util.Set;
33+
import java.util.stream.Collectors;
3234

3335
import static org.elasticsearch.index.IndexSettings.INDEX_MAPPER_SOURCE_MODE_SETTING;
3436
import static org.elasticsearch.index.mapper.SourceFieldMapper.Mode.SYNTHETIC;
3537
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
3638

3739
public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase {
3840

39-
public static final Set<String> ALL_DENSE_VECTOR_INDEX_TYPES = Set.of(
40-
"int8_hnsw",
41-
"hnsw",
42-
"int4_hnsw",
43-
"bbq_hnsw",
44-
"int8_flat",
45-
"int4_flat",
46-
"bbq_flat",
47-
"flat"
48-
);
49-
public static final Set<String> NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES = Set.of("hnsw", "flat");
41+
public static final Set<String> ALL_DENSE_VECTOR_INDEX_TYPES = Arrays.stream(DenseVectorFieldMapper.VectorIndexType.values())
42+
.filter(DenseVectorFieldMapper.VectorIndexType::isEnabled)
43+
.map(v -> v.getName().toLowerCase(Locale.ROOT))
44+
.collect(Collectors.toSet());
45+
46+
public static final Set<String> NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES = Arrays.stream(DenseVectorFieldMapper.VectorIndexType.values())
47+
.filter(t -> t.isEnabled() && t.isQuantized() == false)
48+
.map(v -> v.getName().toLowerCase(Locale.ROOT))
49+
.collect(Collectors.toSet());
50+
5051
public static final float DELTA = 1e-7F;
5152

5253
private final ElementType elementType;
@@ -57,15 +58,10 @@ public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase {
5758
@ParametersFactory
5859
public static Iterable<Object[]> parameters() throws Exception {
5960
List<Object[]> params = new ArrayList<>();
60-
List<DenseVectorFieldMapper.VectorSimilarity> similarities = List.of(
61-
DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT,
62-
DenseVectorFieldMapper.VectorSimilarity.L2_NORM,
63-
DenseVectorFieldMapper.VectorSimilarity.MAX_INNER_PRODUCT
64-
);
6561

6662
for (ElementType elementType : List.of(ElementType.BYTE, ElementType.FLOAT)) {
67-
// Test all similarities for element types
68-
for (DenseVectorFieldMapper.VectorSimilarity similarity : similarities) {
63+
// Test all similarities
64+
for (DenseVectorFieldMapper.VectorSimilarity similarity : DenseVectorFieldMapper.VectorSimilarity.values()) {
6965
params.add(new Object[] { elementType, similarity, true, false });
7066
}
7167

@@ -74,6 +70,7 @@ public static Iterable<Object[]> parameters() throws Exception {
7470
// No indexing, synthetic source
7571
params.add(new Object[] { elementType, null, false, true });
7672
}
73+
7774
return params;
7875
}
7976

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ public GoogleVertexAiEmbeddingsModel(GoogleVertexAiEmbeddingsModel model, Google
6767
}
6868

6969
// Should only be used directly for testing
70-
GoogleVertexAiEmbeddingsModel(
70+
public GoogleVertexAiEmbeddingsModel(
7171
String inferenceEntityId,
7272
TaskType taskType,
7373
String service,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiEmbeddingsRequest.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,14 @@ public HttpRequest createHttpRequest() {
4949
HttpPost httpPost = new HttpPost(model.nonStreamingUri());
5050

5151
ByteArrayEntity byteEntity = new ByteArrayEntity(
52-
Strings.toString(new GoogleVertexAiEmbeddingsRequestEntity(truncationResult.input(), inputType, model.getTaskSettings()))
53-
.getBytes(StandardCharsets.UTF_8)
52+
Strings.toString(
53+
new GoogleVertexAiEmbeddingsRequestEntity(
54+
truncationResult.input(),
55+
inputType,
56+
model.getTaskSettings(),
57+
model.getServiceSettings()
58+
)
59+
).getBytes(StandardCharsets.UTF_8)
5460
);
5561

5662
httpPost.setEntity(byteEntity);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiEmbeddingsRequestEntity.java

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.elasticsearch.inference.InputType;
1111
import org.elasticsearch.xcontent.ToXContentObject;
1212
import org.elasticsearch.xcontent.XContentBuilder;
13+
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings;
1314
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings;
1415

1516
import java.io.IOException;
@@ -21,13 +22,15 @@
2122
public record GoogleVertexAiEmbeddingsRequestEntity(
2223
List<String> inputs,
2324
InputType inputType,
24-
GoogleVertexAiEmbeddingsTaskSettings taskSettings
25+
GoogleVertexAiEmbeddingsTaskSettings taskSettings,
26+
GoogleVertexAiEmbeddingsServiceSettings serviceSettings
2527
) implements ToXContentObject {
2628

2729
private static final String INSTANCES_FIELD = "instances";
2830
private static final String CONTENT_FIELD = "content";
2931
private static final String PARAMETERS_FIELD = "parameters";
3032
private static final String AUTO_TRUNCATE_FIELD = "autoTruncate";
33+
private static final String OUTPUT_DIMENSIONALITY_FIELD = "outputDimensionality";
3134
private static final String TASK_TYPE_FIELD = "task_type";
3235

3336
private static final String CLASSIFICATION_TASK_TYPE = "CLASSIFICATION";
@@ -38,6 +41,7 @@ public record GoogleVertexAiEmbeddingsRequestEntity(
3841
public GoogleVertexAiEmbeddingsRequestEntity {
3942
Objects.requireNonNull(inputs);
4043
Objects.requireNonNull(taskSettings);
44+
Objects.requireNonNull(serviceSettings);
4145
}
4246

4347
@Override
@@ -62,15 +66,19 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
6266

6367
builder.endArray();
6468

65-
if (taskSettings.autoTruncate() != null) {
66-
builder.startObject(PARAMETERS_FIELD);
67-
{
69+
builder.startObject(PARAMETERS_FIELD);
70+
{
71+
if (taskSettings.autoTruncate() != null) {
6872
builder.field(AUTO_TRUNCATE_FIELD, taskSettings.autoTruncate());
6973
}
70-
builder.endObject();
74+
if (serviceSettings.dimensionsSetByUser()) {
75+
builder.field(OUTPUT_DIMENSIONALITY_FIELD, serviceSettings.dimensions());
76+
}
7177
}
7278
builder.endObject();
7379

80+
builder.endObject();
81+
7482
return builder;
7583
}
7684

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiEmbeddingsRequestEntityTests.java

Lines changed: 54 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.elasticsearch.xcontent.XContentBuilder;
1414
import org.elasticsearch.xcontent.XContentFactory;
1515
import org.elasticsearch.xcontent.XContentType;
16+
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings;
1617
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings;
1718

1819
import java.io.IOException;
@@ -26,7 +27,8 @@ public void testToXContent_SingleEmbeddingRequest_WritesAllFields() throws IOExc
2627
var entity = new GoogleVertexAiEmbeddingsRequestEntity(
2728
List.of("abc"),
2829
null,
29-
new GoogleVertexAiEmbeddingsTaskSettings(true, InputType.CLUSTERING)
30+
new GoogleVertexAiEmbeddingsTaskSettings(true, InputType.CLUSTERING),
31+
new GoogleVertexAiEmbeddingsServiceSettings("location", "projectId", "modelId", true, null, 10, null, null)
3032
);
3133

3234
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
@@ -42,17 +44,19 @@ public void testToXContent_SingleEmbeddingRequest_WritesAllFields() throws IOExc
4244
}
4345
],
4446
"parameters": {
45-
"autoTruncate": true
47+
"autoTruncate": true,
48+
"outputDimensionality": 10
4649
}
4750
}
4851
"""));
4952
}
5053

51-
public void testToXContent_SingleEmbeddingRequest_DoesNotWriteAutoTruncationIfNotDefined() throws IOException {
54+
public void testToXContent_SingleEmbeddingRequest_DoesNotWriteUndefinedFields() throws IOException {
5255
var entity = new GoogleVertexAiEmbeddingsRequestEntity(
5356
List.of("abc"),
5457
InputType.INTERNAL_INGEST,
55-
new GoogleVertexAiEmbeddingsTaskSettings(null, null)
58+
new GoogleVertexAiEmbeddingsTaskSettings(null, null),
59+
new GoogleVertexAiEmbeddingsServiceSettings("location", "projectId", "modelId", false, null, null, null, null)
5660
);
5761

5862
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
@@ -66,13 +70,45 @@ public void testToXContent_SingleEmbeddingRequest_DoesNotWriteAutoTruncationIfNo
6670
"content": "abc",
6771
"task_type": "RETRIEVAL_DOCUMENT"
6872
}
69-
]
73+
],
74+
"parameters": {
75+
}
76+
}
77+
"""));
78+
}
79+
80+
public void testToXContent_SingleEmbeddingRequest_DoesNotWriteUndefinedFields_DimensionsSetByUserFalse() throws IOException {
81+
var entity = new GoogleVertexAiEmbeddingsRequestEntity(
82+
List.of("abc"),
83+
InputType.INTERNAL_INGEST,
84+
new GoogleVertexAiEmbeddingsTaskSettings(null, null),
85+
new GoogleVertexAiEmbeddingsServiceSettings("location", "projectId", "modelId", false, null, 10, null, null)
86+
);
87+
88+
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
89+
entity.toXContent(builder, null);
90+
String xContentResult = Strings.toString(builder);
91+
92+
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
93+
{
94+
"instances": [
95+
{
96+
"content": "abc",
97+
"task_type": "RETRIEVAL_DOCUMENT"
98+
}
99+
],
100+
"parameters": {}
70101
}
71102
"""));
72103
}
73104

74105
public void testToXContent_SingleEmbeddingRequest_DoesNotWriteInputTypeIfNotDefined() throws IOException {
75-
var entity = new GoogleVertexAiEmbeddingsRequestEntity(List.of("abc"), null, new GoogleVertexAiEmbeddingsTaskSettings(false, null));
106+
var entity = new GoogleVertexAiEmbeddingsRequestEntity(
107+
List.of("abc"),
108+
null,
109+
new GoogleVertexAiEmbeddingsTaskSettings(false, null),
110+
new GoogleVertexAiEmbeddingsServiceSettings("location", "projectId", "modelId", false, null, null, null, null)
111+
);
76112

77113
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
78114
entity.toXContent(builder, null);
@@ -96,7 +132,8 @@ public void testToXContent_MultipleEmbeddingsRequest_WritesAllFields() throws IO
96132
var entity = new GoogleVertexAiEmbeddingsRequestEntity(
97133
List.of("abc", "def"),
98134
InputType.INTERNAL_SEARCH,
99-
new GoogleVertexAiEmbeddingsTaskSettings(true, InputType.CLUSTERING)
135+
new GoogleVertexAiEmbeddingsTaskSettings(true, InputType.CLUSTERING),
136+
new GoogleVertexAiEmbeddingsServiceSettings("location", "projectId", "modelId", true, null, 10, null, null)
100137
);
101138

102139
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
@@ -116,7 +153,8 @@ public void testToXContent_MultipleEmbeddingsRequest_WritesAllFields() throws IO
116153
}
117154
],
118155
"parameters": {
119-
"autoTruncate": true
156+
"autoTruncate": true,
157+
"outputDimensionality": 10
120158
}
121159
}
122160
"""));
@@ -126,7 +164,8 @@ public void testToXContent_MultipleEmbeddingsRequest_DoesNotWriteInputTypeIfNotD
126164
var entity = new GoogleVertexAiEmbeddingsRequestEntity(
127165
List.of("abc", "def"),
128166
null,
129-
new GoogleVertexAiEmbeddingsTaskSettings(true, null)
167+
new GoogleVertexAiEmbeddingsTaskSettings(true, null),
168+
new GoogleVertexAiEmbeddingsServiceSettings("location", "projectId", "modelId", false, null, null, null, null)
130169
);
131170

132171
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
@@ -154,7 +193,8 @@ public void testToXContent_MultipleEmbeddingsRequest_DoesNotWriteAutoTruncationI
154193
var entity = new GoogleVertexAiEmbeddingsRequestEntity(
155194
List.of("abc", "def"),
156195
null,
157-
new GoogleVertexAiEmbeddingsTaskSettings(null, InputType.CLASSIFICATION)
196+
new GoogleVertexAiEmbeddingsTaskSettings(null, InputType.CLASSIFICATION),
197+
new GoogleVertexAiEmbeddingsServiceSettings("location", "projectId", "modelId", false, null, null, null, null)
158198
);
159199

160200
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
@@ -172,12 +212,14 @@ public void testToXContent_MultipleEmbeddingsRequest_DoesNotWriteAutoTruncationI
172212
"content": "def",
173213
"task_type": "CLASSIFICATION"
174214
}
175-
]
215+
],
216+
"parameters": {
217+
}
176218
}
177219
"""));
178220
}
179221

180222
public void testToXContent_ThrowsIfTaskSettingsIsNull() {
181-
expectThrows(NullPointerException.class, () -> new GoogleVertexAiEmbeddingsRequestEntity(List.of("abc", "def"), null, null));
223+
expectThrows(NullPointerException.class, () -> new GoogleVertexAiEmbeddingsRequestEntity(List.of("abc", "def"), null, null, null));
182224
}
183225
}

0 commit comments

Comments
 (0)