Skip to content

Commit 69c5ff6

Browse files
authored
[Inference API] Support multimodal inputs for chat completion (#142736)
This commit allows "image_url" and "file" to be specified as types of "content" in chat completion requests for ElasticInferenceService. The specification matches the OpenAI specification for multimodal chat completion inputs. Other services will throw an UnsupportedOperationException if they receive a chat completion request with multimodal inputs. Content objects with the type "image_url" are specified as: { "type": "image_url", "image_url": { "url": "base64 encoded image data or URL (if supported)", "detail": "optional detail value" } } Content objects with the type "file" are specified as: { "type": "file", "file": { "file_data": "base64 encoded file data", "filename": "file name" } } Support for the "file_id" field in the "file" object is not added in this commit. When sending the chat completion request to the Elastic Inference Service, the same OpenAI-compatible schema is used as when sending the request to the inference API, so no additional translation logic between the two is needed. Other changes in this commit: - Convert ContentObject class to an abstract class which is extended by ContentObjectText, ContentObjectImage and ContentObjectFile - Modify serialization for ContentObject to throw an exception if attempting to send non-text content to an older version node - Add backward compatibility tests for the above serialization change - Require that "type" is one of "text", "image_url" or "file" for "content" objects. Previously any arbitrary value was allowed. - Update existing tests to use ContentObjectText - Add tests for default behaviour of embedding task and unified chat completion with multimodal inputs to SenderServiceTests
1 parent 8492b41 commit 69c5ff6

File tree

19 files changed

+1144
-117
lines changed

19 files changed

+1144
-117
lines changed

docs/changelog/142736.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
area: Inference
2+
issues: []
3+
pr: 142736
4+
summary: "[Inference API] Support multimodal inputs for chat completion"
5+
type: enhancement

server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java

Lines changed: 427 additions & 22 deletions
Large diffs are not rendered by default.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
9290000
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
esql_topn_min_competitive_updates,9289000
1+
inference_api_multimodal_chat_completion,9290000

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,24 @@
77

88
package org.elasticsearch.xpack.core.inference.action;
99

10+
import org.elasticsearch.ElasticsearchStatusException;
1011
import org.elasticsearch.TransportVersion;
1112
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
1213
import org.elasticsearch.common.io.stream.Writeable;
1314
import org.elasticsearch.core.TimeValue;
1415
import org.elasticsearch.inference.TaskType;
1516
import org.elasticsearch.inference.UnifiedCompletionRequest;
17+
import org.elasticsearch.rest.RestStatus;
1618
import org.elasticsearch.xpack.core.inference.InferenceContext;
1719
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
1820

1921
import java.io.IOException;
22+
import java.util.Collection;
2023
import java.util.List;
24+
import java.util.function.Predicate;
2125

26+
import static org.elasticsearch.inference.UnifiedCompletionRequest.MULTIMODAL_CHAT_COMPLETION_SUPPORT_ADDED;
27+
import static org.elasticsearch.test.BWCVersions.DEFAULT_BWC_VERSIONS;
2228
import static org.hamcrest.Matchers.is;
2329

2430
public class UnifiedCompletionActionRequestTests extends AbstractBWCWireSerializationTestCase<UnifiedCompletionAction.Request> {
@@ -70,17 +76,53 @@ public void testValidation_ReturnsNull_When_TaskType_IsAny() {
7076

7177
@Override
7278
protected UnifiedCompletionAction.Request mutateInstanceForVersion(UnifiedCompletionAction.Request instance, TransportVersion version) {
79+
InferenceContext context = instance.getContext();
7380
if (version.supports(INFERENCE_CONTEXT) == false) {
74-
return new UnifiedCompletionAction.Request(
75-
instance.getInferenceEntityId(),
76-
instance.getTaskType(),
77-
instance.getUnifiedCompletionRequest(),
78-
InferenceContext.EMPTY_INSTANCE,
79-
instance.getTimeout()
80-
);
81+
context = InferenceContext.EMPTY_INSTANCE;
8182
}
8283

83-
return instance;
84+
return new UnifiedCompletionAction.Request(
85+
instance.getInferenceEntityId(),
86+
instance.getTaskType(),
87+
instance.getUnifiedCompletionRequest(),
88+
context,
89+
instance.getTimeout()
90+
);
91+
}
92+
93+
// Versions before MULTIMODAL_CHAT_COMPLETION_SUPPORT_ADDED throw an exception when serializing non-text content
94+
// Those are tested in testMultimodalContentIsNotBackwardsCompatible
95+
@Override
96+
protected Collection<TransportVersion> bwcVersions() {
97+
return super.bwcVersions().stream().filter(version -> version.supports(MULTIMODAL_CHAT_COMPLETION_SUPPORT_ADDED)).toList();
98+
}
99+
100+
public void testMultimodalContentIsNotBackwardsCompatible() throws IOException {
101+
var unsupportedVersions = DEFAULT_BWC_VERSIONS.stream()
102+
.filter(Predicate.not(version -> version.supports(MULTIMODAL_CHAT_COMPLETION_SUPPORT_ADDED)))
103+
.toList();
104+
for (int runs = 0; runs < NUMBER_OF_TEST_RUNS; runs++) {
105+
var testInstance = createTestInstance();
106+
for (var unsupportedVersion : unsupportedVersions) {
107+
if (testInstance.getUnifiedCompletionRequest().containsMultimodalContent()) {
108+
var statusException = assertThrows(
109+
ElasticsearchStatusException.class,
110+
() -> copyWriteable(testInstance, getNamedWriteableRegistry(), instanceReader(), unsupportedVersion)
111+
);
112+
assertThat(statusException.status(), is(RestStatus.BAD_REQUEST));
113+
assertThat(
114+
statusException.getMessage(),
115+
is(
116+
"Cannot send a multimodal chat completion request to an older node. "
117+
+ "Please wait until all nodes are upgraded before using multimodal chat completion inputs"
118+
)
119+
);
120+
} else {
121+
// If the instance doesn't contain multimodal content, assert that it can still be serialized
122+
assertBwcSerialization(testInstance, unsupportedVersion);
123+
}
124+
}
125+
}
84126
}
85127

86128
@Override

0 commit comments

Comments
 (0)