Skip to content

Commit a530f02

Browse files
Task type and base inference action tests
1 parent 6bf3fcd commit a530f02

File tree

6 files changed

+158
-5
lines changed

6 files changed

+158
-5
lines changed

server/src/main/java/org/elasticsearch/inference/InferenceService.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,7 @@ void unifiedCompletionInfer(
127127
);
128128

129129
/**
130-
* Chunk long text according to {@code chunkingOptions} or the
131-
* model defaults if {@code chunkingOptions} contains unset
132-
* values.
130+
* Chunk long text.
133131
*
134132
* @param model The model
135133
* @param query Inference query, mainly for re-ranking

server/src/main/java/org/elasticsearch/inference/TaskType.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ public static TaskType fromString(String name) {
3838
}
3939

4040
public static TaskType fromStringOrStatusException(String name) {
41+
if (name == null) {
42+
throw new ElasticsearchStatusException("Task type must not be null", RestStatus.BAD_REQUEST);
43+
}
44+
4145
try {
4246
TaskType taskType = TaskType.fromString(name);
4347
return Objects.requireNonNull(taskType);

server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,6 @@ public sealed interface Content extends NamedWriteable permits ContentObjects, C
6363
PARSER.declareObjectArray(constructorArg(), Message.PARSER::apply, new ParseField("messages"));
6464
PARSER.declareString(optionalConstructorArg(), new ParseField("model"));
6565
PARSER.declareLong(optionalConstructorArg(), new ParseField("max_completion_tokens"));
66-
// PARSER.declareField(optionalConstructorArg(), (p, c) -> parseStop(p), new ParseField("stop"),
67-
// ObjectParser.ValueType.VALUE_ARRAY);
6866
PARSER.declareStringArray(optionalConstructorArg(), new ParseField("stop"));
6967
PARSER.declareFloat(optionalConstructorArg(), new ParseField("temperature"));
7068
PARSER.declareField(
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference;
9+
10+
import org.elasticsearch.ElasticsearchStatusException;
11+
import org.elasticsearch.inference.TaskType;
12+
import org.elasticsearch.test.ESTestCase;
13+
import org.hamcrest.Matchers;
14+
15+
public class TaskTypeTests extends ESTestCase {
16+
17+
public void testFromStringOrStatusException() {
18+
var exception = expectThrows(ElasticsearchStatusException.class, () -> TaskType.fromStringOrStatusException(null));
19+
assertThat(exception.getMessage(), Matchers.is("Task type must not be null"));
20+
21+
exception = expectThrows(ElasticsearchStatusException.class, () -> TaskType.fromStringOrStatusException("blah"));
22+
assertThat(exception.getMessage(), Matchers.is("Unknown task_type [blah]"));
23+
24+
assertThat(TaskType.fromStringOrStatusException("any"), Matchers.is(TaskType.ANY));
25+
}
26+
27+
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,14 @@
88
package org.elasticsearch.xpack.inference.rest;
99

1010
import org.apache.lucene.util.SetOnce;
11+
import org.elasticsearch.ElasticsearchStatusException;
1112
import org.elasticsearch.action.ActionListener;
1213
import org.elasticsearch.common.bytes.BytesArray;
1314
import org.elasticsearch.core.TimeValue;
15+
import org.elasticsearch.inference.TaskType;
1416
import org.elasticsearch.rest.RestChannel;
1517
import org.elasticsearch.rest.RestRequest;
18+
import org.elasticsearch.rest.RestRequestTests;
1619
import org.elasticsearch.rest.action.RestChunkedToXContentListener;
1720
import org.elasticsearch.test.rest.FakeRestRequest;
1821
import org.elasticsearch.test.rest.RestActionTestCase;
@@ -26,6 +29,10 @@
2629
import java.util.Map;
2730

2831
import static org.elasticsearch.rest.RestRequest.Method.POST;
32+
import static org.elasticsearch.xpack.inference.rest.BaseInferenceAction.parseParams;
33+
import static org.elasticsearch.xpack.inference.rest.BaseInferenceAction.parseTimeout;
34+
import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID;
35+
import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_OR_INFERENCE_ID;
2936
import static org.hamcrest.CoreMatchers.is;
3037
import static org.hamcrest.Matchers.equalTo;
3138
import static org.hamcrest.Matchers.instanceOf;
@@ -56,6 +63,42 @@ private static String route(String param) {
5663
return "_route/" + param;
5764
}
5865

66+
public void testParseParams_ExtractsInferenceIdAndTaskType() {
67+
var params = parseParams(
68+
RestRequestTests.contentRestRequest("{}", Map.of(INFERENCE_ID, "id", TASK_TYPE_OR_INFERENCE_ID, TaskType.COMPLETION.toString()))
69+
);
70+
assertThat(params, is(new BaseInferenceAction.Params("id", TaskType.COMPLETION)));
71+
}
72+
73+
public void testParseParams_DefaultsToTaskTypeAny_WhenInferenceId_IsMissing() {
74+
var params = parseParams(
75+
RestRequestTests.contentRestRequest("{}", Map.of(TASK_TYPE_OR_INFERENCE_ID, TaskType.COMPLETION.toString()))
76+
);
77+
assertThat(params, is(new BaseInferenceAction.Params(TASK_TYPE_OR_INFERENCE_ID, TaskType.ANY)));
78+
}
79+
80+
public void testParseParams_ThrowsStatusException_WhenTaskTypeIsMissing() {
81+
var e = expectThrows(
82+
ElasticsearchStatusException.class,
83+
() -> parseParams(RestRequestTests.contentRestRequest("{}", Map.of(INFERENCE_ID, "id")))
84+
);
85+
assertThat(e.getMessage(), is("Task type must not be null"));
86+
}
87+
88+
public void testParseTimeout_ReturnsTimeout() {
89+
var timeout = parseTimeout(
90+
RestRequestTests.contentRestRequest("{}", Map.of(InferenceAction.Request.TIMEOUT.getPreferredName(), "4s"))
91+
);
92+
93+
assertThat(timeout, is(TimeValue.timeValueSeconds(4)));
94+
}
95+
96+
public void testParseTimeout_ReturnsDefaultTimeout() {
97+
var timeout = parseTimeout(RestRequestTests.contentRestRequest("{}", Map.of()));
98+
99+
assertThat(timeout, is(TimeValue.timeValueSeconds(30)));
100+
}
101+
59102
public void testUsesDefaultTimeout() {
60103
SetOnce<Boolean> executeCalled = new SetOnce<>();
61104
verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.rest;
9+
10+
import org.apache.lucene.util.SetOnce;
11+
import org.elasticsearch.common.bytes.BytesArray;
12+
import org.elasticsearch.common.settings.Settings;
13+
import org.elasticsearch.common.util.concurrent.ThreadContext;
14+
import org.elasticsearch.rest.AbstractRestChannel;
15+
import org.elasticsearch.rest.RestChannel;
16+
import org.elasticsearch.rest.RestRequest;
17+
import org.elasticsearch.rest.RestResponse;
18+
import org.elasticsearch.test.rest.FakeRestRequest;
19+
import org.elasticsearch.test.rest.RestActionTestCase;
20+
import org.elasticsearch.xcontent.XContentType;
21+
import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;
22+
import org.junit.Before;
23+
24+
import static org.elasticsearch.xpack.inference.rest.BaseInferenceActionTests.createResponse;
25+
import static org.hamcrest.CoreMatchers.is;
26+
import static org.hamcrest.Matchers.equalTo;
27+
import static org.hamcrest.Matchers.instanceOf;
28+
29+
public class RestUnifiedCompletionInferenceActionTests extends RestActionTestCase {
30+
31+
@Before
32+
public void setUpAction() {
33+
controller().registerHandler(new RestUnifiedCompletionInferenceAction());
34+
}
35+
36+
public void testStreamIsTrue() {
37+
SetOnce<Boolean> executeCalled = new SetOnce<>();
38+
verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> {
39+
assertThat(actionRequest, instanceOf(UnifiedCompletionAction.Request.class));
40+
41+
var request = (UnifiedCompletionAction.Request) actionRequest;
42+
assertThat(request.isStreaming(), is(true));
43+
44+
executeCalled.set(true);
45+
return createResponse();
46+
}));
47+
48+
var requestBody = """
49+
{
50+
"messages": [
51+
{
52+
"content": "abc",
53+
"role": "user"
54+
}
55+
]
56+
}
57+
""";
58+
59+
RestRequest inferenceRequest = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST)
60+
.withPath("_inference/completion/test/_unified")
61+
.withContent(new BytesArray(requestBody), XContentType.JSON)
62+
.build();
63+
64+
final SetOnce<RestResponse> responseSetOnce = new SetOnce<>();
65+
dispatchRequest(inferenceRequest, new AbstractRestChannel(inferenceRequest, true) {
66+
@Override
67+
public void sendResponse(RestResponse response) {
68+
responseSetOnce.set(response);
69+
}
70+
});
71+
72+
// the response content will be null when there is no error
73+
assertNull(responseSetOnce.get().content());
74+
// var responseBody = responseSetOnce.get().content().utf8ToString();
75+
// assertThat(Objects.requireNonNull(responseSetOnce.get().content()).utf8ToString(), equalTo(createResponse()));
76+
assertThat(executeCalled.get(), equalTo(true));
77+
}
78+
79+
private void dispatchRequest(final RestRequest request, final RestChannel channel) {
80+
ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
81+
controller().dispatchRequest(request, channel, threadContext);
82+
}
83+
}

0 commit comments

Comments
 (0)