Skip to content

Commit 52e076c

Browse files
Add Google Model Garden Integration
1 parent 43841a5 commit 52e076c

File tree

6 files changed

+129
-22
lines changed

6 files changed

+129
-22
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,8 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
256256
serviceSettings.location(),
257257
serviceSettings.projectId(),
258258
serviceSettings.modelId(),
259+
serviceSettings.endpointId(),
260+
serviceSettings.isDedicatedEndpoint(),
259261
serviceSettings.dimensionsSetByUser(),
260262
serviceSettings.maxInputTokens(),
261263
embeddingSize,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceFields.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ public class GoogleVertexAiServiceFields {
1313

1414
public static final String PROJECT_ID = "project_id";
1515

16+
public static final String ENDPOINT_ID = "endpoint_id";
17+
18+
public static final String IS_DEDICATED_ENDPOINT = "is_dedicated_endpoint";
19+
1620
/**
1721
* In `us-central-1` the max input size is `250`, but in every other region it's `5` according
1822
* to these docs: https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/get-text-embeddings.

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModel.java

Lines changed: 53 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,13 @@ public GoogleVertexAiEmbeddingsModel(GoogleVertexAiEmbeddingsModel model, Google
8181
serviceSettings
8282
);
8383
try {
84-
this.uri = buildUri(serviceSettings.location(), serviceSettings.projectId(), serviceSettings.modelId());
84+
this.uri = buildUri(
85+
serviceSettings.location(),
86+
serviceSettings.projectId(),
87+
serviceSettings.modelId(),
88+
serviceSettings.endpointId(),
89+
serviceSettings.isDedicatedEndpoint()
90+
);
8591
} catch (URISyntaxException e) {
8692
throw new RuntimeException(e);
8793
}
@@ -134,20 +140,51 @@ public ExecutableAction accept(GoogleVertexAiActionVisitor visitor, Map<String,
134140
return visitor.create(this, taskSettings);
135141
}
136142

137-
public static URI buildUri(String location, String projectId, String modelId) throws URISyntaxException {
138-
return new URIBuilder().setScheme("https")
139-
.setHost(format("%s%s", location, GoogleVertexAiUtils.GOOGLE_VERTEX_AI_HOST_SUFFIX))
140-
.setPathSegments(
141-
GoogleVertexAiUtils.V1,
142-
GoogleVertexAiUtils.PROJECTS,
143-
projectId,
144-
GoogleVertexAiUtils.LOCATIONS,
145-
location,
146-
GoogleVertexAiUtils.PUBLISHERS,
147-
GoogleVertexAiUtils.PUBLISHER_GOOGLE,
148-
GoogleVertexAiUtils.MODELS,
149-
format("%s:%s", modelId, GoogleVertexAiUtils.PREDICT)
150-
)
151-
.build();
143+
public static URI buildUri(String location, String projectId, String modelId, String endpointId, Boolean isDedicatedEndpoint)
144+
throws URISyntaxException {
145+
if (modelId != null) {
146+
return new URIBuilder().setScheme("https")
147+
.setHost(format("%s%s", location, GoogleVertexAiUtils.GOOGLE_VERTEX_AI_HOST_SUFFIX))
148+
.setPathSegments(
149+
GoogleVertexAiUtils.V1,
150+
GoogleVertexAiUtils.PROJECTS,
151+
projectId,
152+
GoogleVertexAiUtils.LOCATIONS,
153+
location,
154+
GoogleVertexAiUtils.PUBLISHERS,
155+
GoogleVertexAiUtils.PUBLISHER_GOOGLE,
156+
GoogleVertexAiUtils.MODELS,
157+
format("%s:%s", modelId, GoogleVertexAiUtils.PREDICT)
158+
)
159+
.build();
160+
} else if (endpointId != null) {
161+
// TODO: Decide if we should require isDedicatedEndpoint or default it to true
162+
if (isDedicatedEndpoint == null || isDedicatedEndpoint) {
163+
return new URI(
164+
format(
165+
"https://%s.%s-%s.prediction.vertexai.goog/v1/projects/%s/locations/%s/endpoints/%s:predict",
166+
endpointId,
167+
location,
168+
projectId,
169+
projectId,
170+
location,
171+
endpointId
172+
)
173+
);
174+
} else {
175+
return new URI(
176+
format(
177+
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/endpoints/%s:predict",
178+
location,
179+
projectId,
180+
location,
181+
endpointId
182+
)
183+
);
184+
}
185+
} else {
186+
throw new IllegalArgumentException("Either modelId or endpointId must be provided");
187+
}
188+
152189
}
153190
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsServiceSettings.java

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,11 @@
3434
import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY;
3535
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
3636
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
37+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
3738
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
3839
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity;
40+
import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.ENDPOINT_ID;
41+
import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.IS_DEDICATED_ENDPOINT;
3942
import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.LOCATION;
4043
import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.PROJECT_ID;
4144

@@ -56,7 +59,9 @@ public static GoogleVertexAiEmbeddingsServiceSettings fromMap(Map<String, Object
5659

5760
String location = extractRequiredString(map, LOCATION, ModelConfigurations.SERVICE_SETTINGS, validationException);
5861
String projectId = extractRequiredString(map, PROJECT_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
59-
String model = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
62+
String model = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
63+
String endpointId = extractOptionalString(map, ENDPOINT_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
64+
Boolean isDedicatedEndpoint = extractOptionalBoolean(map, IS_DEDICATED_ENDPOINT, validationException);
6065
Integer maxInputTokens = extractOptionalPositiveInteger(
6166
map,
6267
MAX_INPUT_TOKENS,
@@ -93,6 +98,14 @@ public static GoogleVertexAiEmbeddingsServiceSettings fromMap(Map<String, Object
9398
}
9499
}
95100

101+
if ((model == null && endpointId == null) || (model != null && endpointId != null)) {
102+
validationException.addValidationError("Either model or endpoint_id must be set, but not both.");
103+
}
104+
105+
if (endpointId == null && isDedicatedEndpoint != null) {
106+
validationException.addValidationError("is_dedicated_endpoint can only be set when endpoint_id is set.");
107+
}
108+
96109
if (validationException.validationErrors().isEmpty() == false) {
97110
throw validationException;
98111
}
@@ -101,6 +114,8 @@ public static GoogleVertexAiEmbeddingsServiceSettings fromMap(Map<String, Object
101114
location,
102115
projectId,
103116
model,
117+
endpointId,
118+
isDedicatedEndpoint,
104119
dimensionsSetByUser,
105120
maxInputTokens,
106121
dims,
@@ -115,6 +130,10 @@ public static GoogleVertexAiEmbeddingsServiceSettings fromMap(Map<String, Object
115130

116131
private final String modelId;
117132

133+
private final String endpointId;
134+
135+
private final Boolean isDedicatedEndpoint;
136+
118137
private final Integer dims;
119138

120139
private final SimilarityMeasure similarity;
@@ -128,6 +147,8 @@ public GoogleVertexAiEmbeddingsServiceSettings(
128147
String location,
129148
String projectId,
130149
String modelId,
150+
String endpointId,
151+
Boolean isDedicatedEndpoint,
131152
Boolean dimensionsSetByUser,
132153
@Nullable Integer maxInputTokens,
133154
@Nullable Integer dims,
@@ -137,6 +158,8 @@ public GoogleVertexAiEmbeddingsServiceSettings(
137158
this.location = location;
138159
this.projectId = projectId;
139160
this.modelId = modelId;
161+
this.endpointId = endpointId;
162+
this.isDedicatedEndpoint = isDedicatedEndpoint;
140163
this.dimensionsSetByUser = dimensionsSetByUser;
141164
this.maxInputTokens = maxInputTokens;
142165
this.dims = dims;
@@ -147,7 +170,9 @@ public GoogleVertexAiEmbeddingsServiceSettings(
147170
public GoogleVertexAiEmbeddingsServiceSettings(StreamInput in) throws IOException {
148171
this.location = in.readString();
149172
this.projectId = in.readString();
150-
this.modelId = in.readString();
173+
this.modelId = in.readOptionalString();
174+
this.endpointId = in.readOptionalString();
175+
this.isDedicatedEndpoint = in.readOptionalBoolean();
151176
this.dimensionsSetByUser = in.readBoolean();
152177
this.maxInputTokens = in.readOptionalVInt();
153178
this.dims = in.readOptionalVInt();
@@ -169,6 +194,14 @@ public String modelId() {
169194
return modelId;
170195
}
171196

197+
public String endpointId() {
198+
return endpointId;
199+
}
200+
201+
public Boolean isDedicatedEndpoint() {
202+
return isDedicatedEndpoint;
203+
}
204+
172205
public Boolean dimensionsSetByUser() {
173206
return dimensionsSetByUser;
174207
}
@@ -222,7 +255,9 @@ public TransportVersion getMinimalSupportedVersion() {
222255
public void writeTo(StreamOutput out) throws IOException {
223256
out.writeString(location);
224257
out.writeString(projectId);
225-
out.writeString(modelId);
258+
out.writeOptionalString(modelId);
259+
out.writeOptionalString(endpointId);
260+
out.writeOptionalBoolean(isDedicatedEndpoint);
226261
out.writeBoolean(dimensionsSetByUser);
227262
out.writeOptionalVInt(maxInputTokens);
228263
out.writeOptionalVInt(dims);
@@ -235,6 +270,8 @@ protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder buil
235270
builder.field(LOCATION, location);
236271
builder.field(PROJECT_ID, projectId);
237272
builder.field(MODEL_ID, modelId);
273+
builder.field(ENDPOINT_ID, endpointId); // TODO: Transport verison?
274+
builder.field(IS_DEDICATED_ENDPOINT, isDedicatedEndpoint);
238275

239276
if (maxInputTokens != null) {
240277
builder.field(MAX_INPUT_TOKENS, maxInputTokens);
@@ -261,6 +298,8 @@ public boolean equals(Object object) {
261298
return Objects.equals(location, that.location)
262299
&& Objects.equals(projectId, that.projectId)
263300
&& Objects.equals(modelId, that.modelId)
301+
&& Objects.equals(endpointId, that.endpointId)
302+
&& Objects.equals(isDedicatedEndpoint, that.isDedicatedEndpoint)
264303
&& Objects.equals(dims, that.dims)
265304
&& similarity == that.similarity
266305
&& Objects.equals(maxInputTokens, that.maxInputTokens)
@@ -270,6 +309,17 @@ public boolean equals(Object object) {
270309

271310
@Override
272311
public int hashCode() {
273-
return Objects.hash(location, projectId, modelId, dims, similarity, maxInputTokens, rateLimitSettings, dimensionsSetByUser);
312+
return Objects.hash(
313+
location,
314+
projectId,
315+
modelId,
316+
endpointId,
317+
isDedicatedEndpoint,
318+
dims,
319+
similarity,
320+
maxInputTokens,
321+
rateLimitSettings,
322+
dimensionsSetByUser
323+
);
274324
}
275325
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModelTests.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ public void testBuildUri() throws URISyntaxException {
3131
var projectId = "project";
3232
var modelId = "model";
3333

34-
URI uri = GoogleVertexAiEmbeddingsModel.buildUri(location, projectId, modelId);
34+
URI uri = GoogleVertexAiEmbeddingsModel.buildUri(location, projectId, modelId, null, null);
3535

3636
assertThat(
3737
uri,
@@ -98,7 +98,7 @@ public static GoogleVertexAiEmbeddingsModel createModel(
9898
TaskType.TEXT_EMBEDDING,
9999
"service",
100100
uri,
101-
new GoogleVertexAiEmbeddingsServiceSettings(location, projectId, modelId, false, null, null, null, null),
101+
new GoogleVertexAiEmbeddingsServiceSettings(location, projectId, modelId, null, null, false, null, null, null, null),
102102
new GoogleVertexAiEmbeddingsTaskSettings(Boolean.FALSE, null),
103103
new GoogleVertexAiSecretSettings(new SecureString(serviceAccountJson.toCharArray()))
104104
);
@@ -117,6 +117,8 @@ public static GoogleVertexAiEmbeddingsModel createModel(
117117
randomAlphaOfLength(8),
118118
randomAlphaOfLength(8),
119119
modelId,
120+
null,
121+
null,
120122
false,
121123
null,
122124
null,
@@ -138,6 +140,8 @@ public static GoogleVertexAiEmbeddingsModel createModel(String modelId, @Nullabl
138140
"location",
139141
"projectId",
140142
modelId,
143+
null,
144+
null,
141145
false,
142146
null,
143147
null,
@@ -163,6 +167,8 @@ public static GoogleVertexAiEmbeddingsModel createRandomizedModel(
163167
randomAlphaOfLength(8),
164168
randomAlphaOfLength(8),
165169
modelId,
170+
null,
171+
null,
166172
false,
167173
null,
168174
null,

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsServiceSettingsTests.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ public void testFromMap_Request_CreatesSettingsCorrectly() {
6060
location,
6161
projectId,
6262
model,
63+
null, // TODO: Randomize this
64+
null, // TODO: Randomize this
6365
dimensionsSetByUser,
6466
maxInputTokens,
6567
dims,
@@ -75,6 +77,8 @@ public void testToXContent_WritesAllValues() throws IOException {
7577
"location",
7678
"projectId",
7779
"modelId",
80+
null, // TODO: Set this value
81+
null, // TODO: Set this value
7882
true,
7983
10,
8084
10,
@@ -107,6 +111,8 @@ public void testFilteredXContentObject_WritesAllValues_ExceptDimensionsSetByUser
107111
"location",
108112
"projectId",
109113
"modelId",
114+
null, // TODO: Set this value
115+
null, // TODO: Set this value
110116
true,
111117
10,
112118
10,
@@ -162,6 +168,8 @@ private static GoogleVertexAiEmbeddingsServiceSettings createRandom() {
162168
randomAlphaOfLength(10),
163169
randomAlphaOfLength(10),
164170
randomAlphaOfLength(10),
171+
null, // TODO: Randomize this value
172+
null, // TODO: Randomize this value
165173
randomBoolean(),
166174
randomFrom(new Integer[] { null, randomNonNegativeInt() }),
167175
randomFrom(new Integer[] { null, randomNonNegativeInt() }),

0 commit comments

Comments
 (0)