Skip to content

Commit 10ac1ae

Browse files
Adding tests for UnifiedCompletionAction Request
1 parent f983d6a commit 10ac1ae

File tree

2 files changed

+87
-1
lines changed

2 files changed

+87
-1
lines changed

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ protected InferenceAction.Request createTestInstance() {
4141
return new InferenceAction.Request(
4242
randomFrom(TaskType.values()),
4343
randomAlphaOfLength(6),
44-
// null,
4544
randomAlphaOfLengthOrNull(10),
4645
randomList(1, 5, () -> randomAlphaOfLength(8)),
4746
randomMap(0, 3, () -> new Tuple<>(randomAlphaOfLength(4), randomAlphaOfLength(4))),
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.inference.action;
9+
10+
import org.elasticsearch.TransportVersion;
11+
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
12+
import org.elasticsearch.common.io.stream.Writeable;
13+
import org.elasticsearch.core.TimeValue;
14+
import org.elasticsearch.inference.TaskType;
15+
import org.elasticsearch.inference.UnifiedCompletionRequest;
16+
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
17+
18+
import java.io.IOException;
19+
import java.util.List;
20+
21+
import static org.hamcrest.Matchers.is;
22+
23+
public class UnifiedCompletionActionRequestTests extends AbstractBWCWireSerializationTestCase<UnifiedCompletionAction.Request> {
24+
25+
public void testValidation_ReturnsException_When_UnifiedCompletionRequestMessage_Is_Null() {
26+
var request = new UnifiedCompletionAction.Request(
27+
"inference_id",
28+
TaskType.COMPLETION,
29+
UnifiedCompletionRequest.of(null),
30+
TimeValue.timeValueSeconds(10)
31+
);
32+
var exception = request.validate();
33+
assertThat(exception.getMessage(), is("Validation Failed: 1: Field [messages] cannot be null;"));
34+
}
35+
36+
public void testValidation_ReturnsException_When_UnifiedCompletionRequest_Is_EmptyArray() {
37+
var request = new UnifiedCompletionAction.Request(
38+
"inference_id",
39+
TaskType.COMPLETION,
40+
UnifiedCompletionRequest.of(List.of()),
41+
TimeValue.timeValueSeconds(10)
42+
);
43+
var exception = request.validate();
44+
assertThat(exception.getMessage(), is("Validation Failed: 1: Field [messages] cannot be an empty array;"));
45+
}
46+
47+
public void testValidation_ReturnsException_When_TaskType_IsNot_Completion() {
48+
var request = new UnifiedCompletionAction.Request(
49+
"inference_id",
50+
TaskType.SPARSE_EMBEDDING,
51+
UnifiedCompletionRequest.of(List.of(UnifiedCompletionRequestTests.randomMessage())),
52+
TimeValue.timeValueSeconds(10)
53+
);
54+
var exception = request.validate();
55+
assertThat(exception.getMessage(), is("Validation Failed: 1: Field [taskType] must be [completion];"));
56+
}
57+
58+
@Override
59+
protected UnifiedCompletionAction.Request mutateInstanceForVersion(UnifiedCompletionAction.Request instance, TransportVersion version) {
60+
return instance;
61+
}
62+
63+
@Override
64+
protected Writeable.Reader<UnifiedCompletionAction.Request> instanceReader() {
65+
return UnifiedCompletionAction.Request::new;
66+
}
67+
68+
@Override
69+
protected UnifiedCompletionAction.Request createTestInstance() {
70+
return new UnifiedCompletionAction.Request(
71+
randomAlphaOfLength(10),
72+
randomFrom(TaskType.values()),
73+
UnifiedCompletionRequestTests.randomUnifiedCompletionRequest(),
74+
TimeValue.timeValueMillis(randomLongBetween(1, 2048))
75+
);
76+
}
77+
78+
@Override
79+
protected UnifiedCompletionAction.Request mutateInstance(UnifiedCompletionAction.Request instance) throws IOException {
80+
return randomValueOtherThan(instance, this::createTestInstance);
81+
}
82+
83+
@Override
84+
protected NamedWriteableRegistry getNamedWriteableRegistry() {
85+
return new NamedWriteableRegistry(UnifiedCompletionRequest.getNamedWriteables());
86+
}
87+
}

0 commit comments

Comments
 (0)