From 0b268bcac91bf79cd165e5fea65cbcc49d5d77d2 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Mon, 11 Aug 2025 16:02:20 -0400 Subject: [PATCH 1/2] Add support for dimensions in request --- .../GoogleVertexAiEmbeddingsModel.java | 2 +- .../GoogleVertexAiEmbeddingsRequest.java | 10 +- ...GoogleVertexAiEmbeddingsRequestEntity.java | 18 ++- ...eVertexAiEmbeddingsRequestEntityTests.java | 66 ++++++++-- .../GoogleVertexAiEmbeddingsRequestTests.java | 117 ++++++++++++++++-- 5 files changed, 180 insertions(+), 33 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModel.java index 66031f7e5475d..010be42dd872f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModel.java @@ -67,7 +67,7 @@ public GoogleVertexAiEmbeddingsModel(GoogleVertexAiEmbeddingsModel model, Google } // Should only be used directly for testing - GoogleVertexAiEmbeddingsModel( + public GoogleVertexAiEmbeddingsModel( String inferenceEntityId, TaskType taskType, String service, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiEmbeddingsRequest.java index bf506a08d8268..c11d5b15ee3c5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiEmbeddingsRequest.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiEmbeddingsRequest.java @@ -49,8 +49,14 @@ public HttpRequest createHttpRequest() { HttpPost httpPost = new HttpPost(model.nonStreamingUri()); ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString(new GoogleVertexAiEmbeddingsRequestEntity(truncationResult.input(), inputType, model.getTaskSettings())) - .getBytes(StandardCharsets.UTF_8) + Strings.toString( + new GoogleVertexAiEmbeddingsRequestEntity( + truncationResult.input(), + inputType, + model.getTaskSettings(), + model.getServiceSettings() + ) + ).getBytes(StandardCharsets.UTF_8) ); httpPost.setEntity(byteEntity); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiEmbeddingsRequestEntity.java index d3aa6688faa45..508fc45ada5b4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiEmbeddingsRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiEmbeddingsRequestEntity.java @@ -10,6 +10,7 @@ import org.elasticsearch.inference.InputType; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings; import java.io.IOException; @@ -21,13 +22,15 @@ public record GoogleVertexAiEmbeddingsRequestEntity( List inputs, InputType inputType, - GoogleVertexAiEmbeddingsTaskSettings taskSettings + GoogleVertexAiEmbeddingsTaskSettings 
taskSettings, + GoogleVertexAiEmbeddingsServiceSettings serviceSettings ) implements ToXContentObject { private static final String INSTANCES_FIELD = "instances"; private static final String CONTENT_FIELD = "content"; private static final String PARAMETERS_FIELD = "parameters"; private static final String AUTO_TRUNCATE_FIELD = "autoTruncate"; + private static final String OUTPUT_DIMENSIONALITY_FIELD = "outputDimensionality"; private static final String TASK_TYPE_FIELD = "task_type"; private static final String CLASSIFICATION_TASK_TYPE = "CLASSIFICATION"; @@ -38,6 +41,7 @@ public record GoogleVertexAiEmbeddingsRequestEntity( public GoogleVertexAiEmbeddingsRequestEntity { Objects.requireNonNull(inputs); Objects.requireNonNull(taskSettings); + Objects.requireNonNull(serviceSettings); } @Override @@ -62,15 +66,19 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.endArray(); - if (taskSettings.autoTruncate() != null) { - builder.startObject(PARAMETERS_FIELD); - { + builder.startObject(PARAMETERS_FIELD); + { + if (taskSettings.autoTruncate() != null) { builder.field(AUTO_TRUNCATE_FIELD, taskSettings.autoTruncate()); } - builder.endObject(); + if (serviceSettings.dimensionsSetByUser()) { + builder.field(OUTPUT_DIMENSIONALITY_FIELD, serviceSettings.dimensions()); + } } builder.endObject(); + builder.endObject(); + return builder; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiEmbeddingsRequestEntityTests.java index 41d26b4464b01..e5540310f7806 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiEmbeddingsRequestEntityTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiEmbeddingsRequestEntityTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings; import java.io.IOException; @@ -26,7 +27,8 @@ public void testToXContent_SingleEmbeddingRequest_WritesAllFields() throws IOExc var entity = new GoogleVertexAiEmbeddingsRequestEntity( List.of("abc"), null, - new GoogleVertexAiEmbeddingsTaskSettings(true, InputType.CLUSTERING) + new GoogleVertexAiEmbeddingsTaskSettings(true, InputType.CLUSTERING), + new GoogleVertexAiEmbeddingsServiceSettings("location", "projectId", "modelId", true, null, 10, null, null) ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); @@ -42,17 +44,19 @@ public void testToXContent_SingleEmbeddingRequest_WritesAllFields() throws IOExc } ], "parameters": { - "autoTruncate": true + "autoTruncate": true, + "outputDimensionality": 10 } } """)); } - public void testToXContent_SingleEmbeddingRequest_DoesNotWriteAutoTruncationIfNotDefined() throws IOException { + public void testToXContent_SingleEmbeddingRequest_DoesNotWriteUndefinedFields() throws IOException { var entity = new GoogleVertexAiEmbeddingsRequestEntity( List.of("abc"), InputType.INTERNAL_INGEST, - new GoogleVertexAiEmbeddingsTaskSettings(null, null) + new GoogleVertexAiEmbeddingsTaskSettings(null, null), + new GoogleVertexAiEmbeddingsServiceSettings("location", "projectId", "modelId", false, null, null, null, null) ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); @@ -66,13 +70,45 @@ public void 
testToXContent_SingleEmbeddingRequest_DoesNotWriteAutoTruncationIfNo "content": "abc", "task_type": "RETRIEVAL_DOCUMENT" } - ] + ], + "parameters": { + } + } + """)); + } + + public void testToXContent_SingleEmbeddingRequest_DoesNotWriteUndefinedFields_DimensionsSetByUserFalse() throws IOException { + var entity = new GoogleVertexAiEmbeddingsRequestEntity( + List.of("abc"), + InputType.INTERNAL_INGEST, + new GoogleVertexAiEmbeddingsTaskSettings(null, null), + new GoogleVertexAiEmbeddingsServiceSettings("location", "projectId", "modelId", false, null, 10, null, null) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "instances": [ + { + "content": "abc", + "task_type": "RETRIEVAL_DOCUMENT" + } + ], + "parameters": {} } """)); } public void testToXContent_SingleEmbeddingRequest_DoesNotWriteInputTypeIfNotDefined() throws IOException { - var entity = new GoogleVertexAiEmbeddingsRequestEntity(List.of("abc"), null, new GoogleVertexAiEmbeddingsTaskSettings(false, null)); + var entity = new GoogleVertexAiEmbeddingsRequestEntity( + List.of("abc"), + null, + new GoogleVertexAiEmbeddingsTaskSettings(false, null), + new GoogleVertexAiEmbeddingsServiceSettings("location", "projectId", "modelId", false, null, null, null, null) + ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -96,7 +132,8 @@ public void testToXContent_MultipleEmbeddingsRequest_WritesAllFields() throws IO var entity = new GoogleVertexAiEmbeddingsRequestEntity( List.of("abc", "def"), InputType.INTERNAL_SEARCH, - new GoogleVertexAiEmbeddingsTaskSettings(true, InputType.CLUSTERING) + new GoogleVertexAiEmbeddingsTaskSettings(true, InputType.CLUSTERING), + new GoogleVertexAiEmbeddingsServiceSettings("location", "projectId", "modelId", true, null, 10, 
null, null) ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); @@ -116,7 +153,8 @@ public void testToXContent_MultipleEmbeddingsRequest_WritesAllFields() throws IO } ], "parameters": { - "autoTruncate": true + "autoTruncate": true, + "outputDimensionality": 10 } } """)); @@ -126,7 +164,8 @@ public void testToXContent_MultipleEmbeddingsRequest_DoesNotWriteInputTypeIfNotD var entity = new GoogleVertexAiEmbeddingsRequestEntity( List.of("abc", "def"), null, - new GoogleVertexAiEmbeddingsTaskSettings(true, null) + new GoogleVertexAiEmbeddingsTaskSettings(true, null), + new GoogleVertexAiEmbeddingsServiceSettings("location", "projectId", "modelId", false, null, null, null, null) ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); @@ -154,7 +193,8 @@ public void testToXContent_MultipleEmbeddingsRequest_DoesNotWriteAutoTruncationI var entity = new GoogleVertexAiEmbeddingsRequestEntity( List.of("abc", "def"), null, - new GoogleVertexAiEmbeddingsTaskSettings(null, InputType.CLASSIFICATION) + new GoogleVertexAiEmbeddingsTaskSettings(null, InputType.CLASSIFICATION), + new GoogleVertexAiEmbeddingsServiceSettings("location", "projectId", "modelId", false, null, null, null, null) ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); @@ -172,12 +212,14 @@ public void testToXContent_MultipleEmbeddingsRequest_DoesNotWriteAutoTruncationI "content": "def", "task_type": "CLASSIFICATION" } - ] + ], + "parameters": { + } } """)); } public void testToXContent_ThrowsIfTaskSettingsIsNull() { - expectThrows(NullPointerException.class, () -> new GoogleVertexAiEmbeddingsRequestEntity(List.of("abc", "def"), null, null)); + expectThrows(NullPointerException.class, () -> new GoogleVertexAiEmbeddingsRequestEntity(List.of("abc", "def"), null, null, null)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiEmbeddingsRequestTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiEmbeddingsRequestTests.java index 89e2490329687..7b11d593b0007 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiEmbeddingsRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiEmbeddingsRequestTests.java @@ -9,16 +9,21 @@ import org.apache.http.HttpHeaders; import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.InputTypeTests; import org.elasticsearch.xpack.inference.common.Truncator; import org.elasticsearch.xpack.inference.common.TruncatorTests; import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings; import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel; import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings; import java.io.IOException; import java.util.List; @@ -49,12 +54,15 @@ public void testCreateRequest_WithoutDimensionsSet_And_WithoutAutoTruncateSet_An assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE)); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - 
assertThat(requestMap, aMapWithSize(1)); + assertThat(requestMap, aMapWithSize(2)); if (InputType.isSpecified(inputType)) { var convertedInputType = convertToString(inputType); - assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "input", "task_type", convertedInputType))))); + assertThat( + requestMap, + is(Map.of("instances", List.of(Map.of("content", "input", "task_type", convertedInputType)), "parameters", Map.of())) + ); } else { - assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "input"))))); + assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "input")), "parameters", Map.of()))); } } @@ -96,6 +104,43 @@ public void testCreateRequest_WithAutoTruncateSet() throws IOException { } } + public void testCreateRequest_WithDimensions() throws IOException { + var model = "model"; + var input = "input"; + var inputType = InputTypeTests.randomWithNull(); + + var request = createRequestWithDimensions(model, input, 10, inputType); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE)); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(2)); + if (InputType.isSpecified(inputType)) { + var convertedInputType = convertToString(inputType); + assertThat( + requestMap, + is( + Map.of( + "instances", + List.of(Map.of("content", "input", "task_type", convertedInputType)), + "parameters", + Map.of("outputDimensionality", 10) + ) + ) + ); + } else { + assertThat( + requestMap, + is(Map.of("instances", List.of(Map.of("content", "input")), "parameters", Map.of("outputDimensionality", 10))) + ); + } + } + public void 
testCreateRequest_WithTaskSettingsInputTypeSet() throws IOException { var model = "model"; var input = "input"; @@ -111,12 +156,15 @@ public void testCreateRequest_WithTaskSettingsInputTypeSet() throws IOException assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE)); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap, aMapWithSize(1)); + assertThat(requestMap, aMapWithSize(2)); if (InputType.isSpecified(inputType)) { var convertedInputType = convertToString(inputType); - assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "input", "task_type", convertedInputType))))); + assertThat( + requestMap, + is(Map.of("instances", List.of(Map.of("content", "input", "task_type", convertedInputType)), "parameters", Map.of())) + ); } else { - assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "input"))))); + assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "input")), "parameters", Map.of()))); } } @@ -136,15 +184,21 @@ public void testCreateRequest_RequestInputTypeTakesPrecedence() throws IOExcepti assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE)); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap, aMapWithSize(1)); + assertThat(requestMap, aMapWithSize(2)); if (InputType.isSpecified(requestInputType)) { var convertedInputType = convertToString(requestInputType); - assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "input", "task_type", convertedInputType))))); + assertThat( + requestMap, + is(Map.of("instances", List.of(Map.of("content", "input", "task_type", convertedInputType)), "parameters", Map.of())) + ); } else if (InputType.isSpecified(taskSettingsInputType)) { var convertedInputType = convertToString(taskSettingsInputType); - assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "input", "task_type", 
convertedInputType))))); + assertThat( + requestMap, + is(Map.of("instances", List.of(Map.of("content", "input", "task_type", convertedInputType)), "parameters", Map.of())) + ); } else { - assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "input"))))); + assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "input")), "parameters", Map.of()))); } } @@ -164,13 +218,16 @@ public void testTruncate_ReducesInputTextSizeByHalf() throws IOException { assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE)); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap, aMapWithSize(1)); + assertThat(requestMap, aMapWithSize(2)); if (InputType.isSpecified(inputType)) { var convertedInputType = convertToString(inputType); - assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "ab", "task_type", convertedInputType))))); + assertThat( + requestMap, + is(Map.of("instances", List.of(Map.of("content", "ab", "task_type", convertedInputType)), "parameters", Map.of())) + ); } else { - assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "ab"))))); + assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "ab")), "parameters", Map.of()))); } } @@ -191,6 +248,40 @@ private static GoogleVertexAiEmbeddingsRequest createRequest( ); } + private static GoogleVertexAiEmbeddingsRequest createRequestWithDimensions( + String modelId, + String input, + int dimensions, + @Nullable InputType requestInputType + ) { + + var embeddingsModel = new GoogleVertexAiEmbeddingsModel( + "id", + TaskType.TEXT_EMBEDDING, + "service", + new GoogleVertexAiEmbeddingsServiceSettings( + randomAlphaOfLength(8), + randomAlphaOfLength(8), + modelId, + true, + null, + dimensions, + null, + null + ), + new GoogleVertexAiEmbeddingsTaskSettings(null, null), + null, + new GoogleVertexAiSecretSettings(new SecureString(randomAlphaOfLength(8).toCharArray())) + ); + + 
return new GoogleVertexAiEmbeddingsWithoutAuthRequest( + TruncatorTests.createTruncator(), + new Truncator.TruncationResult(List.of(input), new boolean[] { false }), + requestInputType, + embeddingsModel + ); + } + /** * We use this class to fake the auth implementation to avoid static mocking of {@link GoogleVertexAiRequest} */ From e1fc5face704bc4588ae5b5d47bedd22857858c7 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner <56361221+jonathan-buttner@users.noreply.github.com> Date: Mon, 11 Aug 2025 16:05:54 -0400 Subject: [PATCH 2/2] Update docs/changelog/132689.yaml --- docs/changelog/132689.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/132689.yaml diff --git a/docs/changelog/132689.yaml b/docs/changelog/132689.yaml new file mode 100644 index 0000000000000..80e65644abe7a --- /dev/null +++ b/docs/changelog/132689.yaml @@ -0,0 +1,5 @@ +pr: 132689 +summary: Add support for dimensions in Google Vertex AI request +area: Machine Learning +type: enhancement +issues: []