Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.external.http.sender.TruncatingRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseEntity;
import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.external.openai.OpenAiEmbeddingsRequest;
import org.elasticsearch.xpack.inference.external.openai.OpenAiEmbeddingsResponseEntity;
import org.elasticsearch.xpack.inference.external.openai.OpenAiResponseHandler;
import org.elasticsearch.xpack.inference.external.request.openai.OpenAiEmbeddingsRequest;
import org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequest;
import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity;
import org.elasticsearch.xpack.inference.external.response.openai.OpenAiEmbeddingsResponseEntity;
import org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedChatCompletionRequest;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel;
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModel;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
import org.elasticsearch.xpack.inference.external.azureopenai.AzureOpenAiResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.openai.OpenAiEmbeddingsResponseEntity;
import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiEmbeddingsRequest;
import org.elasticsearch.xpack.inference.external.response.openai.OpenAiEmbeddingsResponseEntity;
import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel;

import java.util.List;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
import org.elasticsearch.xpack.inference.external.elastic.ElasticInferenceServiceUnifiedChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseEntity;
import org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceUnifiedChatCompletionRequest;
import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.response.openai;
package org.elasticsearch.xpack.inference.external.openai;

import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
Expand All @@ -19,14 +19,10 @@
import java.io.IOException;
import java.util.List;

import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField;
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;

public class OpenAiChatCompletionResponseEntity {

private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in OpenAI chat completions response";

/**
* Parses the OpenAI chat completion response.
* For a request like:
Expand Down Expand Up @@ -71,32 +67,51 @@ public class OpenAiChatCompletionResponseEntity {
*/

public static ChatCompletionResults fromResponse(Request request, HttpResult response) throws IOException {
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
moveToFirstToken(jsonParser);

XContentParser.Token token = jsonParser.currentToken();
ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser);

positionParserAtTokenAfterField(jsonParser, "choices", FAILED_TO_FIND_FIELD_TEMPLATE);
try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The major changes here are to use the ConstructingObjectParser instead of iterating by token.

return CompletionResult.PARSER.apply(p, null).toChatCompletionResults();
}
}

jsonParser.nextToken();
ensureExpectedToken(XContentParser.Token.START_OBJECT, jsonParser.currentToken(), jsonParser);
public record CompletionResult(List<Choice> choices) {
@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<CompletionResult, Void> PARSER = new ConstructingObjectParser<>(
CompletionResult.class.getSimpleName(),
true,
args -> new CompletionResult((List<Choice>) args[0])
);

positionParserAtTokenAfterField(jsonParser, "message", FAILED_TO_FIND_FIELD_TEMPLATE);
static {
PARSER.declareObjectArray(constructorArg(), Choice.PARSER::apply, new ParseField("choices"));
}

token = jsonParser.currentToken();
public ChatCompletionResults toChatCompletionResults() {
return new ChatCompletionResults(
choices.stream().map(choice -> new ChatCompletionResults.Result(choice.message.content)).toList()
);
}
}

ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser);
public record Choice(Message message) {
public static final ConstructingObjectParser<Choice, Void> PARSER = new ConstructingObjectParser<>(
Choice.class.getSimpleName(),
true,
args -> new Choice((Message) args[0])
);

positionParserAtTokenAfterField(jsonParser, "content", FAILED_TO_FIND_FIELD_TEMPLATE);
static {
PARSER.declareObject(constructorArg(), Message.PARSER::apply, new ParseField("message"));
}
}

XContentParser.Token contentToken = jsonParser.currentToken();
ensureExpectedToken(XContentParser.Token.VALUE_STRING, contentToken, jsonParser);
String content = jsonParser.text();
public record Message(String content) {
public static final ConstructingObjectParser<Message, Void> PARSER = new ConstructingObjectParser<>(
Message.class.getSimpleName(),
true,
args -> new Message((String) args[0])
);

return new ChatCompletionResults(List.of(new ChatCompletionResults.Result(content)));
static {
PARSER.declareString(constructorArg(), new ParseField("content"));
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.request.openai;
package org.elasticsearch.xpack.inference.external.openai;

import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
Expand All @@ -21,8 +21,8 @@
import java.nio.charset.StandardCharsets;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.external.openai.OpenAiUtils.createOrgHeader;
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.createOrgHeader;

public class OpenAiEmbeddingsRequest implements Request {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.request.openai;
package org.elasticsearch.xpack.inference.external.openai;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ToXContentObject;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This class was moved and transitioned to use the ConstructingObjectParser.

* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.openai;

import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;

import java.io.IOException;
import java.util.List;

import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;

public class OpenAiEmbeddingsResponseEntity {
/**
* Parses the OpenAI json response.
* For a request like:
*
* <pre>
* <code>
* {
* "inputs": ["hello this is my name", "I wish I was there!"]
* }
* </code>
* </pre>
*
* The response would look like:
*
* <pre>
* <code>
* {
* "object": "list",
* "data": [
* {
* "object": "embedding",
* "embedding": [
* -0.009327292,
* .... (1536 floats total for ada-002)
* -0.0028842222,
* ],
* "index": 0
* },
* {
* "object": "embedding",
* "embedding": [ ... ],
* "index": 1
* }
* ],
* "model": "text-embedding-ada-002",
* "usage": {
* "prompt_tokens": 8,
* "total_tokens": 8
* }
* }
* </code>
* </pre>
*/
public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())) {
return EmbeddingFloatResult.PARSER.apply(p, null).toTextEmbeddingFloatResults();
}
}

public record EmbeddingFloatResult(List<EmbeddingFloatResultEntry> embeddingResults) {
@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<EmbeddingFloatResult, Void> PARSER = new ConstructingObjectParser<>(
EmbeddingFloatResult.class.getSimpleName(),
true,
args -> new EmbeddingFloatResult((List<EmbeddingFloatResultEntry>) args[0])
);

static {
PARSER.declareObjectArray(constructorArg(), EmbeddingFloatResultEntry.PARSER::apply, new ParseField("data"));
}

public TextEmbeddingFloatResults toTextEmbeddingFloatResults() {
return new TextEmbeddingFloatResults(
embeddingResults.stream().map(entry -> TextEmbeddingFloatResults.Embedding.of(entry.embedding)).toList()
);
}
}

public record EmbeddingFloatResultEntry(List<Float> embedding) {
@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<EmbeddingFloatResultEntry, Void> PARSER = new ConstructingObjectParser<>(
EmbeddingFloatResultEntry.class.getSimpleName(),
true,
args -> new EmbeddingFloatResultEntry((List<Float>) args[0])
);

static {
PARSER.declareFloatArray(constructorArg(), new ParseField("embedding"));
}
}

private OpenAiEmbeddingsResponseEntity() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.request.openai;
package org.elasticsearch.xpack.inference.external.openai;

import org.elasticsearch.xpack.inference.external.request.Request;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.request.openai;
package org.elasticsearch.xpack.inference.external.openai;

import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
Expand All @@ -21,8 +21,8 @@
import java.nio.charset.StandardCharsets;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.external.openai.OpenAiUtils.createOrgHeader;
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.createOrgHeader;

public class OpenAiUnifiedChatCompletionRequest implements Request {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.request.openai;
package org.elasticsearch.xpack.inference.external.openai;

import org.elasticsearch.common.Strings;
import org.elasticsearch.xcontent.ToXContentObject;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.request.openai;
package org.elasticsearch.xpack.inference.external.openai;

import org.apache.http.Header;
import org.apache.http.message.BasicHeader;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseEntity;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioChatCompletionRequest;
import org.elasticsearch.xpack.inference.external.response.BaseResponseEntity;
import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity;

import java.io.IOException;
import java.util.List;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@

import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.openai.OpenAiEmbeddingsResponseEntity;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.BaseResponseEntity;
import org.elasticsearch.xpack.inference.external.response.openai.OpenAiEmbeddingsResponseEntity;

import java.io.IOException;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@

import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.openai.OpenAiEmbeddingsResponseEntity;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.BaseResponseEntity;
import org.elasticsearch.xpack.inference.external.response.openai.OpenAiEmbeddingsResponseEntity;

import java.io.IOException;

Expand Down
Loading