Skip to content

Commit d6c2464

Browse files
committed
google ai studio
1 parent 7f20d32 commit d6c2464

File tree

9 files changed

+376
-17
lines changed

9 files changed

+376
-17
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioEmbeddingsRequest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ public HttpRequest createHttpRequest() {
4444
new GoogleAiStudioEmbeddingsRequestEntity(
4545
truncationResult.input(),
4646
model.getServiceSettings().modelId(),
47-
model.getServiceSettings().dimensions()
47+
model.getServiceSettings().dimensions(),
48+
model.inputType()
4849
)
4950
).getBytes(StandardCharsets.UTF_8)
5051
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioEmbeddingsRequestEntity.java

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.inference.external.request.googleaistudio;
99

1010
import org.elasticsearch.core.Nullable;
11+
import org.elasticsearch.inference.InputType;
1112
import org.elasticsearch.xcontent.ToXContentObject;
1213
import org.elasticsearch.xcontent.XContentBuilder;
1314

@@ -16,10 +17,14 @@
1617
import java.util.Objects;
1718

1819
import static org.elasticsearch.core.Strings.format;
20+
import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings.invalidInputTypeMessage;
1921

20-
public record GoogleAiStudioEmbeddingsRequestEntity(List<String> inputs, String model, @Nullable Integer dimensions)
21-
implements
22-
ToXContentObject {
22+
public record GoogleAiStudioEmbeddingsRequestEntity(
23+
List<String> inputs,
24+
String model,
25+
@Nullable Integer dimensions,
26+
@Nullable InputType inputType
27+
) implements ToXContentObject {
2328

2429
private static final String REQUESTS_FIELD = "requests";
2530
private static final String MODEL_FIELD = "model";
@@ -29,6 +34,12 @@ public record GoogleAiStudioEmbeddingsRequestEntity(List<String> inputs, String
2934
private static final String PARTS_FIELD = "parts";
3035
private static final String TEXT_FIELD = "text";
3136

37+
public static final String TASK_TYPE_FIELD = "taskType";
38+
private static final String CLASSIFICATION_TASK_TYPE = "CLASSIFICATION";
39+
private static final String CLUSTERING_TASK_TYPE = "CLUSTERING";
40+
private static final String RETRIEVAL_DOCUMENT_TASK_TYPE = "RETRIEVAL_DOCUMENT";
41+
private static final String RETRIEVAL_QUERY_TASK_TYPE = "RETRIEVAL_QUERY";
42+
3243
private static final String OUTPUT_DIMENSIONALITY_FIELD = "outputDimensionality";
3344

3445
public GoogleAiStudioEmbeddingsRequestEntity {
@@ -67,12 +78,31 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
6778
builder.field(OUTPUT_DIMENSIONALITY_FIELD, dimensions);
6879
}
6980

81+
if (inputType != null) {
82+
builder.field(TASK_TYPE_FIELD, convertToString(inputType));
83+
}
84+
7085
builder.endObject();
7186
}
7287

7388
builder.endArray();
89+
7490
builder.endObject();
7591

7692
return builder;
7793
}
94+
95+
// default for testing
96+
static String convertToString(InputType inputType) {
97+
return switch (inputType) {
98+
case INGEST, INTERNAL_INGEST -> RETRIEVAL_DOCUMENT_TASK_TYPE;
99+
case SEARCH, INTERNAL_SEARCH -> RETRIEVAL_QUERY_TASK_TYPE;
100+
case CLASSIFICATION -> CLASSIFICATION_TASK_TYPE;
101+
case CLUSTERING -> CLUSTERING_TASK_TYPE;
102+
default -> {
103+
assert false : invalidInputTypeMessage(inputType);
104+
yield null;
105+
}
106+
};
107+
}
78108
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -800,10 +800,13 @@ public static String useChatCompletionUrlMessage(Model model) {
800800
);
801801
}
802802

803-
static final EnumSet<InputType> VALID_INPUT_TYPE_VALUES = EnumSet.of(InputType.INTERNAL_INGEST, InputType.INTERNAL_SEARCH);
803+
public static final EnumSet<InputType> VALID_INTERNAL_INPUT_TYPE_VALUES = EnumSet.of(
804+
InputType.INTERNAL_INGEST,
805+
InputType.INTERNAL_SEARCH
806+
);
804807

805808
public static void validateInputTypeIsUnspecifiedOrInternal(InputType inputType, ValidationException validationException) {
806-
if (inputType != null && inputType != InputType.UNSPECIFIED && VALID_INPUT_TYPE_VALUES.contains(inputType) == false) {
809+
if (inputType != null && inputType != InputType.UNSPECIFIED && VALID_INTERNAL_INPUT_TYPE_VALUES.contains(inputType) == false) {
807810
validationException.addValidationError(
808811
Strings.format("Invalid value [%s] received. [%s] is not allowed", inputType, "input_type")
809812
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,8 +291,9 @@ protected void doInfer(
291291
);
292292
action.execute(inputs, timeout, listener);
293293
} else if (model instanceof GoogleAiStudioEmbeddingsModel embeddingsModel) {
294+
var overriddenModel = GoogleAiStudioEmbeddingsModel.of(embeddingsModel, inputType);
294295
var requestManager = new GoogleAiStudioEmbeddingsRequestManager(
295-
embeddingsModel,
296+
overriddenModel,
296297
getServiceComponents().truncator(),
297298
getServiceComponents().threadPool()
298299
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsModel.java

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,12 @@
88
package org.elasticsearch.xpack.inference.services.googleaistudio.embeddings;
99

1010
import org.apache.http.client.utils.URIBuilder;
11+
import org.elasticsearch.common.ValidationException;
1112
import org.elasticsearch.core.Nullable;
13+
import org.elasticsearch.core.Strings;
1214
import org.elasticsearch.inference.ChunkingSettings;
1315
import org.elasticsearch.inference.EmptyTaskSettings;
16+
import org.elasticsearch.inference.InputType;
1417
import org.elasticsearch.inference.ModelConfigurations;
1518
import org.elasticsearch.inference.ModelSecrets;
1619
import org.elasticsearch.inference.TaskSettings;
@@ -22,13 +25,54 @@
2225

2326
import java.net.URI;
2427
import java.net.URISyntaxException;
28+
import java.util.EnumSet;
2529
import java.util.Map;
30+
import java.util.Objects;
2631

2732
import static org.elasticsearch.core.Strings.format;
33+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.VALID_INTERNAL_INPUT_TYPE_VALUES;
2834

2935
public class GoogleAiStudioEmbeddingsModel extends GoogleAiStudioModel {
36+
static final String MODEL_ID_WITH_TASK_TYPE = "embedding-001";
37+
static final EnumSet<InputType> VALID_REQUEST_VALUES = EnumSet.of(
38+
InputType.INGEST,
39+
InputType.SEARCH,
40+
InputType.CLASSIFICATION,
41+
InputType.CLUSTERING,
42+
InputType.INTERNAL_INGEST,
43+
InputType.INTERNAL_SEARCH
44+
);
45+
46+
public static GoogleAiStudioEmbeddingsModel of(GoogleAiStudioEmbeddingsModel model, InputType inputType) {
47+
var modelId = model.getServiceSettings().modelId();
48+
// InputType is only allowed when model=embedding-001 https://ai.google.dev/api/embeddings?authuser=5#EmbedContentRequest
49+
ValidationException validationException = new ValidationException();
50+
if (Objects.equals(model.getServiceSettings().modelId(), MODEL_ID_WITH_TASK_TYPE) == false) {
51+
// this model does not accept input type parameter
52+
if (inputType != null && inputType != InputType.UNSPECIFIED && VALID_INTERNAL_INPUT_TYPE_VALUES.contains(inputType) == false) {
53+
// throw validation exception if ingest type is specified
54+
validationException.addValidationError(
55+
Strings.format("Invalid value [%s] received. [%s] is not allowed for model [%s]", inputType, "input_type", modelId)
56+
);
57+
} else {
58+
return model;
59+
}
60+
} else {
61+
if (inputType != null && inputType != InputType.UNSPECIFIED && VALID_REQUEST_VALUES.contains(inputType) == false) {
62+
validationException.addValidationError(
63+
Strings.format("Invalid value [%s] received. [%s] is not allowed", inputType, "input_type")
64+
);
65+
}
66+
}
67+
68+
if (validationException.validationErrors().isEmpty() == false) {
69+
throw validationException;
70+
}
71+
return new GoogleAiStudioEmbeddingsModel(model, model.getServiceSettings(), inputType == InputType.UNSPECIFIED ? null : inputType);
72+
}
3073

3174
private URI uri;
75+
private InputType inputType;
3276

3377
public GoogleAiStudioEmbeddingsModel(
3478
String inferenceEntityId,
@@ -55,6 +99,20 @@ public GoogleAiStudioEmbeddingsModel(GoogleAiStudioEmbeddingsModel model, Google
5599
super(model, serviceSettings);
56100
}
57101

102+
public GoogleAiStudioEmbeddingsModel(
103+
GoogleAiStudioEmbeddingsModel model,
104+
GoogleAiStudioEmbeddingsServiceSettings serviceSettings,
105+
InputType inputType
106+
) {
107+
super(model, serviceSettings);
108+
this.inputType = inputType;
109+
try {
110+
this.uri = buildUri(serviceSettings.modelId());
111+
} catch (URISyntaxException e) {
112+
throw new RuntimeException(e);
113+
}
114+
}
115+
58116
// Should only be used directly for testing
59117
GoogleAiStudioEmbeddingsModel(
60118
String inferenceEntityId,
@@ -77,6 +135,30 @@ public GoogleAiStudioEmbeddingsModel(GoogleAiStudioEmbeddingsModel model, Google
77135
}
78136
}
79137

138+
// Should only be used directly for testing
139+
GoogleAiStudioEmbeddingsModel(
140+
String inferenceEntityId,
141+
TaskType taskType,
142+
String service,
143+
GoogleAiStudioEmbeddingsServiceSettings serviceSettings,
144+
TaskSettings taskSettings,
145+
ChunkingSettings chunkingSettings,
146+
@Nullable DefaultSecretSettings secrets,
147+
@Nullable InputType inputType
148+
) {
149+
super(
150+
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingSettings),
151+
new ModelSecrets(secrets),
152+
serviceSettings
153+
);
154+
this.inputType = inputType;
155+
try {
156+
this.uri = buildUri(serviceSettings.modelId());
157+
} catch (URISyntaxException e) {
158+
throw new RuntimeException(e);
159+
}
160+
}
161+
80162
// Should only be used directly for testing
81163
GoogleAiStudioEmbeddingsModel(
82164
String inferenceEntityId,
@@ -136,6 +218,10 @@ public URI uri() {
136218
return uri;
137219
}
138220

221+
public InputType inputType() {
222+
return inputType;
223+
}
224+
139225
public static URI buildUri(String model) throws URISyntaxException {
140226
return new URIBuilder().setScheme("https")
141227
.setHost(GoogleAiStudioUtils.HOST_SUFFIX)

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/embeddings/GoogleAiStudioEmbeddingsRequestEntityTests.java

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.inference.external.request.googleaistudio.embeddings;
99

1010
import org.elasticsearch.common.Strings;
11+
import org.elasticsearch.inference.InputType;
1112
import org.elasticsearch.test.ESTestCase;
1213
import org.elasticsearch.xcontent.XContentBuilder;
1314
import org.elasticsearch.xcontent.XContentFactory;
@@ -22,7 +23,7 @@
2223
public class GoogleAiStudioEmbeddingsRequestEntityTests extends ESTestCase {
2324

2425
public void testXContent_SingleRequest_WritesDimensionsIfDefined() throws IOException {
25-
var entity = new GoogleAiStudioEmbeddingsRequestEntity(List.of("abc"), "model", 8);
26+
var entity = new GoogleAiStudioEmbeddingsRequestEntity(List.of("abc"), "model", 8, null);
2627

2728
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
2829
entity.toXContent(builder, null);
@@ -48,7 +49,7 @@ public void testXContent_SingleRequest_WritesDimensionsIfDefined() throws IOExce
4849
}
4950

5051
public void testXContent_SingleRequest_DoesNotWriteDimensionsIfNull() throws IOException {
51-
var entity = new GoogleAiStudioEmbeddingsRequestEntity(List.of("abc"), "model", null);
52+
var entity = new GoogleAiStudioEmbeddingsRequestEntity(List.of("abc"), "model", null, null);
5253

5354
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
5455
entity.toXContent(builder, null);
@@ -73,7 +74,7 @@ public void testXContent_SingleRequest_DoesNotWriteDimensionsIfNull() throws IOE
7374
}
7475

7576
public void testXContent_MultipleRequests_WritesDimensionsIfDefined() throws IOException {
76-
var entity = new GoogleAiStudioEmbeddingsRequestEntity(List.of("abc", "def"), "model", 8);
77+
var entity = new GoogleAiStudioEmbeddingsRequestEntity(List.of("abc", "def"), "model", 8, null);
7778

7879
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
7980
entity.toXContent(builder, null);
@@ -110,7 +111,7 @@ public void testXContent_MultipleRequests_WritesDimensionsIfDefined() throws IOE
110111
}
111112

112113
public void testXContent_MultipleRequests_DoesNotWriteDimensionsIfNull() throws IOException {
113-
var entity = new GoogleAiStudioEmbeddingsRequestEntity(List.of("abc", "def"), "model", null);
114+
var entity = new GoogleAiStudioEmbeddingsRequestEntity(List.of("abc", "def"), "model", null, null);
114115

115116
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
116117
entity.toXContent(builder, null);
@@ -143,4 +144,30 @@ public void testXContent_MultipleRequests_DoesNotWriteDimensionsIfNull() throws
143144
}
144145
"""));
145146
}
147+
148+
public void testXContent_SingleRequest_WritesInputTypeIfDefined() throws IOException {
149+
var entity = new GoogleAiStudioEmbeddingsRequestEntity(List.of("abc"), "model", 8, InputType.INTERNAL_INGEST);
150+
151+
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
152+
entity.toXContent(builder, null);
153+
String xContentResult = Strings.toString(builder);
154+
155+
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
156+
{
157+
"requests": [
158+
{
159+
"model": "models/model",
160+
"content": {
161+
"parts": [
162+
{
163+
"text": "abc"
164+
}
165+
]
166+
},
167+
"taskType": "RETRIEVAL_DOCUMENT"
168+
}
169+
]
170+
}
171+
"""));
172+
}
146173
}

0 commit comments

Comments
 (0)