Skip to content

Commit 0165233

Browse files
[ML] Add support for dimensions in google vertex ai request (elastic#132689)
* Add support for dimensions in request * Update docs/changelog/132689.yaml
1 parent 26ffd7f commit 0165233

File tree

6 files changed

+185
-33
lines changed

6 files changed

+185
-33
lines changed

docs/changelog/132689.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 132689
2+
summary: Add support for dimensions in google vertex ai request
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ public GoogleVertexAiEmbeddingsModel(GoogleVertexAiEmbeddingsModel model, Google
6767
}
6868

6969
// Should only be used directly for testing
70-
GoogleVertexAiEmbeddingsModel(
70+
public GoogleVertexAiEmbeddingsModel(
7171
String inferenceEntityId,
7272
TaskType taskType,
7373
String service,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiEmbeddingsRequest.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,14 @@ public HttpRequest createHttpRequest() {
4949
HttpPost httpPost = new HttpPost(model.nonStreamingUri());
5050

5151
ByteArrayEntity byteEntity = new ByteArrayEntity(
52-
Strings.toString(new GoogleVertexAiEmbeddingsRequestEntity(truncationResult.input(), inputType, model.getTaskSettings()))
53-
.getBytes(StandardCharsets.UTF_8)
52+
Strings.toString(
53+
new GoogleVertexAiEmbeddingsRequestEntity(
54+
truncationResult.input(),
55+
inputType,
56+
model.getTaskSettings(),
57+
model.getServiceSettings()
58+
)
59+
).getBytes(StandardCharsets.UTF_8)
5460
);
5561

5662
httpPost.setEntity(byteEntity);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiEmbeddingsRequestEntity.java

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.elasticsearch.inference.InputType;
1111
import org.elasticsearch.xcontent.ToXContentObject;
1212
import org.elasticsearch.xcontent.XContentBuilder;
13+
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings;
1314
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings;
1415

1516
import java.io.IOException;
@@ -21,13 +22,15 @@
2122
public record GoogleVertexAiEmbeddingsRequestEntity(
2223
List<String> inputs,
2324
InputType inputType,
24-
GoogleVertexAiEmbeddingsTaskSettings taskSettings
25+
GoogleVertexAiEmbeddingsTaskSettings taskSettings,
26+
GoogleVertexAiEmbeddingsServiceSettings serviceSettings
2527
) implements ToXContentObject {
2628

2729
private static final String INSTANCES_FIELD = "instances";
2830
private static final String CONTENT_FIELD = "content";
2931
private static final String PARAMETERS_FIELD = "parameters";
3032
private static final String AUTO_TRUNCATE_FIELD = "autoTruncate";
33+
private static final String OUTPUT_DIMENSIONALITY_FIELD = "outputDimensionality";
3134
private static final String TASK_TYPE_FIELD = "task_type";
3235

3336
private static final String CLASSIFICATION_TASK_TYPE = "CLASSIFICATION";
@@ -38,6 +41,7 @@ public record GoogleVertexAiEmbeddingsRequestEntity(
3841
public GoogleVertexAiEmbeddingsRequestEntity {
3942
Objects.requireNonNull(inputs);
4043
Objects.requireNonNull(taskSettings);
44+
Objects.requireNonNull(serviceSettings);
4145
}
4246

4347
@Override
@@ -62,15 +66,19 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
6266

6367
builder.endArray();
6468

65-
if (taskSettings.autoTruncate() != null) {
66-
builder.startObject(PARAMETERS_FIELD);
67-
{
69+
builder.startObject(PARAMETERS_FIELD);
70+
{
71+
if (taskSettings.autoTruncate() != null) {
6872
builder.field(AUTO_TRUNCATE_FIELD, taskSettings.autoTruncate());
6973
}
70-
builder.endObject();
74+
if (serviceSettings.dimensionsSetByUser()) {
75+
builder.field(OUTPUT_DIMENSIONALITY_FIELD, serviceSettings.dimensions());
76+
}
7177
}
7278
builder.endObject();
7379

80+
builder.endObject();
81+
7482
return builder;
7583
}
7684

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiEmbeddingsRequestEntityTests.java

Lines changed: 54 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.elasticsearch.xcontent.XContentBuilder;
1414
import org.elasticsearch.xcontent.XContentFactory;
1515
import org.elasticsearch.xcontent.XContentType;
16+
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings;
1617
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings;
1718

1819
import java.io.IOException;
@@ -26,7 +27,8 @@ public void testToXContent_SingleEmbeddingRequest_WritesAllFields() throws IOExc
2627
var entity = new GoogleVertexAiEmbeddingsRequestEntity(
2728
List.of("abc"),
2829
null,
29-
new GoogleVertexAiEmbeddingsTaskSettings(true, InputType.CLUSTERING)
30+
new GoogleVertexAiEmbeddingsTaskSettings(true, InputType.CLUSTERING),
31+
new GoogleVertexAiEmbeddingsServiceSettings("location", "projectId", "modelId", true, null, 10, null, null)
3032
);
3133

3234
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
@@ -42,17 +44,19 @@ public void testToXContent_SingleEmbeddingRequest_WritesAllFields() throws IOExc
4244
}
4345
],
4446
"parameters": {
45-
"autoTruncate": true
47+
"autoTruncate": true,
48+
"outputDimensionality": 10
4649
}
4750
}
4851
"""));
4952
}
5053

51-
public void testToXContent_SingleEmbeddingRequest_DoesNotWriteAutoTruncationIfNotDefined() throws IOException {
54+
public void testToXContent_SingleEmbeddingRequest_DoesNotWriteUndefinedFields() throws IOException {
5255
var entity = new GoogleVertexAiEmbeddingsRequestEntity(
5356
List.of("abc"),
5457
InputType.INTERNAL_INGEST,
55-
new GoogleVertexAiEmbeddingsTaskSettings(null, null)
58+
new GoogleVertexAiEmbeddingsTaskSettings(null, null),
59+
new GoogleVertexAiEmbeddingsServiceSettings("location", "projectId", "modelId", false, null, null, null, null)
5660
);
5761

5862
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
@@ -66,13 +70,45 @@ public void testToXContent_SingleEmbeddingRequest_DoesNotWriteAutoTruncationIfNo
6670
"content": "abc",
6771
"task_type": "RETRIEVAL_DOCUMENT"
6872
}
69-
]
73+
],
74+
"parameters": {
75+
}
76+
}
77+
"""));
78+
}
79+
80+
public void testToXContent_SingleEmbeddingRequest_DoesNotWriteUndefinedFields_DimensionsSetByUserFalse() throws IOException {
81+
var entity = new GoogleVertexAiEmbeddingsRequestEntity(
82+
List.of("abc"),
83+
InputType.INTERNAL_INGEST,
84+
new GoogleVertexAiEmbeddingsTaskSettings(null, null),
85+
new GoogleVertexAiEmbeddingsServiceSettings("location", "projectId", "modelId", false, null, 10, null, null)
86+
);
87+
88+
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
89+
entity.toXContent(builder, null);
90+
String xContentResult = Strings.toString(builder);
91+
92+
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
93+
{
94+
"instances": [
95+
{
96+
"content": "abc",
97+
"task_type": "RETRIEVAL_DOCUMENT"
98+
}
99+
],
100+
"parameters": {}
70101
}
71102
"""));
72103
}
73104

74105
public void testToXContent_SingleEmbeddingRequest_DoesNotWriteInputTypeIfNotDefined() throws IOException {
75-
var entity = new GoogleVertexAiEmbeddingsRequestEntity(List.of("abc"), null, new GoogleVertexAiEmbeddingsTaskSettings(false, null));
106+
var entity = new GoogleVertexAiEmbeddingsRequestEntity(
107+
List.of("abc"),
108+
null,
109+
new GoogleVertexAiEmbeddingsTaskSettings(false, null),
110+
new GoogleVertexAiEmbeddingsServiceSettings("location", "projectId", "modelId", false, null, null, null, null)
111+
);
76112

77113
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
78114
entity.toXContent(builder, null);
@@ -96,7 +132,8 @@ public void testToXContent_MultipleEmbeddingsRequest_WritesAllFields() throws IO
96132
var entity = new GoogleVertexAiEmbeddingsRequestEntity(
97133
List.of("abc", "def"),
98134
InputType.INTERNAL_SEARCH,
99-
new GoogleVertexAiEmbeddingsTaskSettings(true, InputType.CLUSTERING)
135+
new GoogleVertexAiEmbeddingsTaskSettings(true, InputType.CLUSTERING),
136+
new GoogleVertexAiEmbeddingsServiceSettings("location", "projectId", "modelId", true, null, 10, null, null)
100137
);
101138

102139
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
@@ -116,7 +153,8 @@ public void testToXContent_MultipleEmbeddingsRequest_WritesAllFields() throws IO
116153
}
117154
],
118155
"parameters": {
119-
"autoTruncate": true
156+
"autoTruncate": true,
157+
"outputDimensionality": 10
120158
}
121159
}
122160
"""));
@@ -126,7 +164,8 @@ public void testToXContent_MultipleEmbeddingsRequest_DoesNotWriteInputTypeIfNotD
126164
var entity = new GoogleVertexAiEmbeddingsRequestEntity(
127165
List.of("abc", "def"),
128166
null,
129-
new GoogleVertexAiEmbeddingsTaskSettings(true, null)
167+
new GoogleVertexAiEmbeddingsTaskSettings(true, null),
168+
new GoogleVertexAiEmbeddingsServiceSettings("location", "projectId", "modelId", false, null, null, null, null)
130169
);
131170

132171
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
@@ -154,7 +193,8 @@ public void testToXContent_MultipleEmbeddingsRequest_DoesNotWriteAutoTruncationI
154193
var entity = new GoogleVertexAiEmbeddingsRequestEntity(
155194
List.of("abc", "def"),
156195
null,
157-
new GoogleVertexAiEmbeddingsTaskSettings(null, InputType.CLASSIFICATION)
196+
new GoogleVertexAiEmbeddingsTaskSettings(null, InputType.CLASSIFICATION),
197+
new GoogleVertexAiEmbeddingsServiceSettings("location", "projectId", "modelId", false, null, null, null, null)
158198
);
159199

160200
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
@@ -172,12 +212,14 @@ public void testToXContent_MultipleEmbeddingsRequest_DoesNotWriteAutoTruncationI
172212
"content": "def",
173213
"task_type": "CLASSIFICATION"
174214
}
175-
]
215+
],
216+
"parameters": {
217+
}
176218
}
177219
"""));
178220
}
179221

180222
public void testToXContent_ThrowsIfTaskSettingsIsNull() {
181-
expectThrows(NullPointerException.class, () -> new GoogleVertexAiEmbeddingsRequestEntity(List.of("abc", "def"), null, null));
223+
expectThrows(NullPointerException.class, () -> new GoogleVertexAiEmbeddingsRequestEntity(List.of("abc", "def"), null, null, null));
182224
}
183225
}

0 commit comments

Comments
 (0)