Skip to content

Commit c211d83

Browse files
Fixing rerank and chat completions
1 parent 5f13d28 commit c211d83

File tree

11 files changed

145 additions and 94 deletions (+145 / -94 lines changed)

11 files changed

145 additions and 94 deletions (+145 / -94 lines changed)

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClient.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -137,7 +137,7 @@ public void cancelled() {
137137
private void respondUsingUtilityThread(HttpResponse response, HttpRequest request, ActionListener<HttpResult> listener) {
138138
threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> {
139139
try {
140-
listener.onResponse(HttpResult.create(settings.getMaxResponseSize(), response));
140+
listener.onResponse(HttpResult.create(settings.getMaxResponseSize(), response, request));
141141
} catch (Exception e) {
142142
throttlerManager.warn(
143143
logger,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpResult.java

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -12,16 +12,17 @@
1212
import org.elasticsearch.core.Streams;
1313
import org.elasticsearch.rest.RestStatus;
1414
import org.elasticsearch.xpack.inference.common.SizeLimitInputStream;
15+
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
1516

1617
import java.io.ByteArrayOutputStream;
1718
import java.io.IOException;
1819
import java.io.InputStream;
1920
import java.util.Objects;
2021

21-
public record HttpResult(HttpResponse response, byte[] body) {
22+
public record HttpResult(HttpResponse response, byte[] body, HttpRequest request) {
2223

23-
public static HttpResult create(ByteSizeValue maxResponseSize, HttpResponse response) throws IOException {
24-
return new HttpResult(response, limitBody(maxResponseSize, response));
24+
public static HttpResult create(ByteSizeValue maxResponseSize, HttpResponse response, HttpRequest request) throws IOException {
25+
return new HttpResult(response, limitBody(maxResponseSize, response), request);
2526
}
2627

2728
private static byte[] limitBody(ByteSizeValue maxResponseSize, HttpResponse response) throws IOException {
@@ -43,6 +44,7 @@ private static byte[] limitBody(ByteSizeValue maxResponseSize, HttpResponse resp
4344
public HttpResult {
4445
Objects.requireNonNull(response);
4546
Objects.requireNonNull(body);
47+
Objects.requireNonNull(request);
4648
}
4749

4850
public boolean isBodyEmpty() {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResult.java

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.ExceptionsHelper;
1212
import org.elasticsearch.action.ActionListener;
1313
import org.elasticsearch.rest.RestStatus;
14+
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
1415

1516
import java.io.ByteArrayOutputStream;
1617
import java.util.concurrent.Flow;
@@ -21,7 +22,7 @@ public boolean isSuccessfulResponse() {
2122
return RestStatus.isSuccessful(response.getStatusLine().getStatusCode());
2223
}
2324

24-
public Flow.Publisher<HttpResult> toHttpResult() {
25+
public Flow.Publisher<HttpResult> toHttpResult(HttpRequest httpRequest) {
2526
return subscriber -> body().subscribe(new Flow.Subscriber<>() {
2627
@Override
2728
public void onSubscribe(Flow.Subscription subscription) {
@@ -30,7 +31,7 @@ public void onSubscribe(Flow.Subscription subscription) {
3031

3132
@Override
3233
public void onNext(byte[] item) {
33-
subscriber.onNext(new HttpResult(response(), item));
34+
subscriber.onNext(new HttpResult(response(), item, httpRequest));
3435
}
3536

3637
@Override
@@ -45,7 +46,7 @@ public void onComplete() {
4546
});
4647
}
4748

48-
public void readFullResponse(ActionListener<HttpResult> fullResponse) {
49+
public void readFullResponse(HttpRequest httpRequest, ActionListener<HttpResult> fullResponse) {
4950
var stream = new ByteArrayOutputStream();
5051
AtomicReference<Flow.Subscription> upstream = new AtomicReference<>(null);
5152
body.subscribe(new Flow.Subscriber<>() {
@@ -69,7 +70,7 @@ public void onError(Throwable throwable) {
6970

7071
@Override
7172
public void onComplete() {
72-
fullResponse.onResponse(new HttpResult(response, stream.toByteArray()));
73+
fullResponse.onResponse(new HttpResult(response, stream.toByteArray(), httpRequest));
7374
}
7475
});
7576
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSender.java

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -115,11 +115,12 @@ public void tryAction(ActionListener<InferenceServiceResults> listener) {
115115

116116
try {
117117
if (request.isStreaming() && responseHandler.canHandleStreamingResponses()) {
118-
httpClient.stream(request.createHttpRequest(), context, retryableListener.delegateFailure((l, r) -> {
118+
var httpRequest = request.createHttpRequest();
119+
httpClient.stream(httpRequest, context, retryableListener.delegateFailure((l, r) -> {
119120
if (r.isSuccessfulResponse()) {
120-
l.onResponse(responseHandler.parseResult(request, r.toHttpResult()));
121+
l.onResponse(responseHandler.parseResult(request, r.toHttpResult(httpRequest)));
121122
} else {
122-
r.readFullResponse(l.delegateFailureAndWrap((ll, httpResult) -> {
123+
r.readFullResponse(httpRequest, l.delegateFailureAndWrap((ll, httpResult) -> {
123124
try {
124125
responseHandler.validateResponse(throttlerManager, logger, request, httpResult, true);
125126
InferenceServiceResults inferenceResults = responseHandler.parseResult(request, httpResult);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -238,9 +238,12 @@ private static CustomServiceSettings getCustomServiceSettings(CustomModel custom
238238
var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;
239239

240240
return new CustomServiceSettings(
241-
similarityToUse,
242-
embeddingSize,
243-
serviceSettings.getMaxInputTokens(),
241+
new CustomServiceSettings.TextEmbeddingSettings(
242+
similarityToUse,
243+
embeddingSize,
244+
serviceSettings.getMaxInputTokens(),
245+
serviceSettings.elementType()
246+
),
244247
serviceSettings.getUrl(),
245248
serviceSettings.getHeaders(),
246249
serviceSettings.getQueryParameters(),

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java

Lines changed: 68 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -13,12 +13,14 @@
1313
import org.elasticsearch.common.ValidationException;
1414
import org.elasticsearch.common.io.stream.StreamInput;
1515
import org.elasticsearch.common.io.stream.StreamOutput;
16+
import org.elasticsearch.common.io.stream.Writeable;
1617
import org.elasticsearch.core.Nullable;
1718
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
1819
import org.elasticsearch.inference.ModelConfigurations;
1920
import org.elasticsearch.inference.ServiceSettings;
2021
import org.elasticsearch.inference.SimilarityMeasure;
2122
import org.elasticsearch.inference.TaskType;
23+
import org.elasticsearch.xcontent.ToXContentFragment;
2224
import org.elasticsearch.xcontent.ToXContentObject;
2325
import org.elasticsearch.xcontent.XContentBuilder;
2426
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
@@ -66,9 +68,7 @@ public class CustomServiceSettings extends FilteredXContentObject implements Ser
6668
public static CustomServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context, TaskType taskType) {
6769
ValidationException validationException = new ValidationException();
6870

69-
SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
70-
Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
71-
Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
71+
var textEmbeddingSettings = TextEmbeddingSettings.fromMap(map, taskType, validationException);
7272

7373
String url = extractRequiredString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
7474

@@ -134,9 +134,7 @@ public static CustomServiceSettings fromMap(Map<String, Object> map, Configurati
134134
}
135135

136136
return new CustomServiceSettings(
137-
similarity,
138-
dims,
139-
maxInputTokens,
137+
textEmbeddingSettings,
140138
url,
141139
stringHeaders,
142140
queryParams,
@@ -147,9 +145,59 @@ public static CustomServiceSettings fromMap(Map<String, Object> map, Configurati
147145
);
148146
}
149147

150-
private final SimilarityMeasure similarity;
151-
private final Integer dimensions;
152-
private final Integer maxInputTokens;
148+
public record TextEmbeddingSettings(
149+
@Nullable SimilarityMeasure similarityMeasure,
150+
@Nullable Integer dimensions,
151+
@Nullable Integer maxInputTokens,
152+
@Nullable DenseVectorFieldMapper.ElementType elementType
153+
) implements ToXContentFragment, Writeable {
154+
155+
public static final TextEmbeddingSettings EMPTY = new TextEmbeddingSettings(null, null, null, null);
156+
157+
public static TextEmbeddingSettings fromMap(Map<String, Object> map, TaskType taskType, ValidationException validationException) {
158+
if (taskType != TaskType.TEXT_EMBEDDING) {
159+
return EMPTY;
160+
}
161+
162+
SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
163+
Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
164+
Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
165+
return new TextEmbeddingSettings(similarity, dims, maxInputTokens, DenseVectorFieldMapper.ElementType.FLOAT);
166+
}
167+
168+
public TextEmbeddingSettings(StreamInput in) throws IOException {
169+
this(
170+
in.readOptionalEnum(SimilarityMeasure.class),
171+
in.readOptionalVInt(),
172+
in.readOptionalVInt(),
173+
in.readOptionalEnum(DenseVectorFieldMapper.ElementType.class)
174+
);
175+
}
176+
177+
@Override
178+
public void writeTo(StreamOutput out) throws IOException {
179+
out.writeOptionalEnum(similarityMeasure);
180+
out.writeOptionalVInt(dimensions);
181+
out.writeOptionalVInt(maxInputTokens);
182+
out.writeOptionalEnum(elementType);
183+
}
184+
185+
@Override
186+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
187+
if (similarityMeasure != null) {
188+
builder.field(SIMILARITY, similarityMeasure);
189+
}
190+
if (dimensions != null) {
191+
builder.field(DIMENSIONS, dimensions);
192+
}
193+
if (maxInputTokens != null) {
194+
builder.field(MAX_INPUT_TOKENS, maxInputTokens);
195+
}
196+
return builder;
197+
}
198+
}
199+
200+
private final TextEmbeddingSettings textEmbeddingSettings;
153201
private final String url;
154202
private final Map<String, String> headers;
155203
private final QueryParameters queryParameters;
@@ -159,9 +207,7 @@ public static CustomServiceSettings fromMap(Map<String, Object> map, Configurati
159207
private final ErrorResponseParser errorParser;
160208

161209
public CustomServiceSettings(
162-
@Nullable SimilarityMeasure similarity,
163-
@Nullable Integer dimensions,
164-
@Nullable Integer maxInputTokens,
210+
@Nullable TextEmbeddingSettings textEmbeddingSettings,
165211
String url,
166212
@Nullable Map<String, String> headers,
167213
@Nullable QueryParameters queryParameters,
@@ -170,9 +216,7 @@ public CustomServiceSettings(
170216
@Nullable RateLimitSettings rateLimitSettings,
171217
ErrorResponseParser errorParser
172218
) {
173-
this.similarity = similarity;
174-
this.dimensions = dimensions;
175-
this.maxInputTokens = maxInputTokens;
219+
this.textEmbeddingSettings = textEmbeddingSettings == null ? TextEmbeddingSettings.EMPTY : textEmbeddingSettings;
176220
this.url = Objects.requireNonNull(url);
177221
this.headers = Collections.unmodifiableMap(Objects.requireNonNullElse(headers, Map.of()));
178222
this.queryParameters = Objects.requireNonNullElse(queryParameters, QueryParameters.EMPTY);
@@ -183,9 +227,7 @@ public CustomServiceSettings(
183227
}
184228

185229
public CustomServiceSettings(StreamInput in) throws IOException {
186-
similarity = in.readOptionalEnum(SimilarityMeasure.class);
187-
dimensions = in.readOptionalVInt();
188-
maxInputTokens = in.readOptionalVInt();
230+
textEmbeddingSettings = new TextEmbeddingSettings(in);
189231
url = in.readString();
190232
headers = in.readImmutableMap(StreamInput::readString);
191233
queryParameters = new QueryParameters(in);
@@ -203,21 +245,21 @@ public String modelId() {
203245

204246
@Override
205247
public SimilarityMeasure similarity() {
206-
return similarity;
248+
return textEmbeddingSettings.similarityMeasure;
207249
}
208250

209251
@Override
210252
public Integer dimensions() {
211-
return dimensions;
253+
return textEmbeddingSettings.dimensions;
212254
}
213255

214256
@Override
215257
public DenseVectorFieldMapper.ElementType elementType() {
216-
return DenseVectorFieldMapper.ElementType.FLOAT;
258+
return textEmbeddingSettings.elementType;
217259
}
218260

219261
public Integer getMaxInputTokens() {
220-
return maxInputTokens;
262+
return textEmbeddingSettings.maxInputTokens;
221263
}
222264

223265
public String getUrl() {
@@ -270,15 +312,7 @@ public XContentBuilder toXContentFragment(XContentBuilder builder, Params params
270312

271313
@Override
272314
public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
273-
if (similarity != null) {
274-
builder.field(SIMILARITY, similarity);
275-
}
276-
if (dimensions != null) {
277-
builder.field(DIMENSIONS, dimensions);
278-
}
279-
if (maxInputTokens != null) {
280-
builder.field(MAX_INPUT_TOKENS, maxInputTokens);
281-
}
315+
textEmbeddingSettings.toXContent(builder, params);
282316
builder.field(URL, url);
283317

284318
if (headers.isEmpty() == false) {
@@ -317,9 +351,7 @@ public TransportVersion getMinimalSupportedVersion() {
317351

318352
@Override
319353
public void writeTo(StreamOutput out) throws IOException {
320-
out.writeOptionalEnum(similarity);
321-
out.writeOptionalVInt(dimensions);
322-
out.writeOptionalVInt(maxInputTokens);
354+
textEmbeddingSettings.writeTo(out);
323355
out.writeString(url);
324356
out.writeMap(headers, StreamOutput::writeString, StreamOutput::writeString);
325357
queryParameters.writeTo(out);
@@ -334,9 +366,7 @@ public boolean equals(Object o) {
334366
if (this == o) return true;
335367
if (o == null || getClass() != o.getClass()) return false;
336368
CustomServiceSettings that = (CustomServiceSettings) o;
337-
return Objects.equals(similarity, that.similarity)
338-
&& Objects.equals(dimensions, that.dimensions)
339-
&& Objects.equals(maxInputTokens, that.maxInputTokens)
369+
return Objects.equals(textEmbeddingSettings, that.textEmbeddingSettings)
340370
&& Objects.equals(url, that.url)
341371
&& Objects.equals(headers, that.headers)
342372
&& Objects.equals(queryParameters, that.queryParameters)
@@ -349,9 +379,7 @@ public boolean equals(Object o) {
349379
@Override
350380
public int hashCode() {
351381
return Objects.hash(
352-
similarity,
353-
dimensions,
354-
maxInputTokens,
382+
textEmbeddingSettings,
355383
url,
356384
headers,
357385
queryParameters,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequest.java

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414
import org.apache.http.entity.StringEntity;
1515
import org.elasticsearch.common.Strings;
1616
import org.elasticsearch.common.settings.SecureString;
17+
import org.elasticsearch.inference.TaskType;
1718
import org.elasticsearch.xcontent.XContentType;
1819
import org.elasticsearch.xpack.inference.common.ValidatingSubstitutor;
1920
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
@@ -58,7 +59,7 @@ public CustomRequest(String query, List<String> input, CustomModel model) {
5859
jsonParams.put(QUERY, toJson(query, QUERY));
5960
}
6061

61-
jsonParams.put(INPUT, toJson(input, INPUT));
62+
addInputJsonParam(jsonParams, input, model.getTaskType());
6263

6364
jsonPlaceholderReplacer = new ValidatingSubstitutor(jsonParams, "${", "}");
6465
stringPlaceholderReplacer = new ValidatingSubstitutor(stringOnlyParams, "${", "}");
@@ -83,6 +84,14 @@ private static void addJsonStringParams(Map<String, String> jsonStringParams, Ma
8384
}
8485
}
8586

87+
private static void addInputJsonParam(Map<String, String> jsonParams, List<String> input, TaskType taskType) {
88+
if (taskType == TaskType.COMPLETION && input.isEmpty() == false) {
89+
jsonParams.put(INPUT, toJson(input.get(0), INPUT));
90+
} else {
91+
jsonParams.put(INPUT, toJson(input, INPUT));
92+
}
93+
}
94+
8695
private URI buildUri() {
8796
var replacedUrl = stringPlaceholderReplacer.replace(model.getServiceSettings().getUrl(), URL);
8897

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99

1010
import org.apache.http.HttpHeaders;
1111
import org.elasticsearch.core.Nullable;
12+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
1213
import org.elasticsearch.inference.SimilarityMeasure;
1314
import org.elasticsearch.inference.TaskType;
1415
import org.elasticsearch.test.ESTestCase;
@@ -107,9 +108,12 @@ public static CustomModel getTestModel(TaskType taskType, CustomResponseParser r
107108
String requestContentString = "\"input\":\"${input}\"";
108109

109110
CustomServiceSettings serviceSettings = new CustomServiceSettings(
110-
SimilarityMeasure.DOT_PRODUCT,
111-
dims,
112-
maxInputTokens,
111+
new CustomServiceSettings.TextEmbeddingSettings(
112+
SimilarityMeasure.DOT_PRODUCT,
113+
dims,
114+
maxInputTokens,
115+
DenseVectorFieldMapper.ElementType.FLOAT
116+
),
113117
url,
114118
headers,
115119
QueryParameters.EMPTY,

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManagerTests.java

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -56,9 +56,7 @@ public void testCreateRequest_ThrowsException_ForInvalidUrl() {
5656
""";
5757

5858
var serviceSettings = new CustomServiceSettings(
59-
null,
60-
null,
61-
null,
59+
CustomServiceSettings.TextEmbeddingSettings.EMPTY,
6260
"${url}",
6361
null,
6462
null,

0 commit comments

Comments (0)