Skip to content

Commit f382246

Browse files
Fixing various tests
1 parent fa415d8 commit f382246

24 files changed

+145
-106
lines changed

server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,6 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
8181
new NamedWriteableRegistry.Entry(Content.class, ContentString.NAME, ContentString::new),
8282
new NamedWriteableRegistry.Entry(ToolChoice.class, ToolChoiceObject.NAME, ToolChoiceObject::new),
8383
new NamedWriteableRegistry.Entry(ToolChoice.class, ToolChoiceString.NAME, ToolChoiceString::new)
84-
// new NamedWriteableRegistry.Entry(Stop.class, StopValues.NAME, StopValues::new),
85-
// new NamedWriteableRegistry.Entry(Stop.class, StopString.NAME, StopString::new)
8684
);
8785
}
8886

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ChatCompletionInput.java

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,26 +12,20 @@
1212

1313
public class ChatCompletionInput extends InferenceInputs {
1414
private final List<String> input;
15-
private final boolean stream;
1615

1716
public ChatCompletionInput(List<String> input) {
1817
this(input, false);
1918
}
2019

2120
public ChatCompletionInput(List<String> input, boolean stream) {
22-
super();
21+
super(stream);
2322
this.input = Objects.requireNonNull(input);
24-
this.stream = stream;
2523
}
2624

2725
public List<String> getInputs() {
2826
return this.input;
2927
}
3028

31-
public boolean stream() {
32-
return stream;
33-
}
34-
3529
public int inputSize() {
3630
return input.size();
3731
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DocumentsOnlyInput.java

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,26 +21,20 @@ public static DocumentsOnlyInput of(InferenceInputs inferenceInputs) {
2121
}
2222

2323
private final List<String> input;
24-
private final boolean stream;
2524

2625
public DocumentsOnlyInput(List<String> input) {
2726
this(input, false);
2827
}
2928

3029
public DocumentsOnlyInput(List<String> input, boolean stream) {
31-
super();
30+
super(stream);
3231
this.input = Objects.requireNonNull(input);
33-
this.stream = stream;
3432
}
3533

3634
public List<String> getInputs() {
3735
return this.input;
3836
}
3937

40-
public boolean stream() {
41-
return stream;
42-
}
43-
4438
public int inputSize() {
4539
return input.size();
4640
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@
1010
import org.elasticsearch.common.Strings;
1111

1212
public abstract class InferenceInputs {
13+
private final boolean stream;
14+
15+
public InferenceInputs(boolean stream) {
16+
this.stream = stream;
17+
}
18+
1319
public static IllegalArgumentException createUnsupportedTypeException(InferenceInputs inferenceInputs, Class<?> clazz) {
1420
return new IllegalArgumentException(
1521
Strings.format("Unable to convert inference inputs type: [%s] to [%s]", inferenceInputs.getClass(), clazz)
@@ -24,5 +30,9 @@ public <T> T castTo(Class<T> clazz) {
2430
return clazz.cast(this);
2531
}
2632

33+
public boolean stream() {
34+
return stream;
35+
}
36+
2737
public abstract int inputSize();
2838
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,15 @@ public static QueryAndDocsInputs of(InferenceInputs inferenceInputs) {
2222

2323
private final String query;
2424
private final List<String> chunks;
25-
private final boolean stream;
2625

2726
public QueryAndDocsInputs(String query, List<String> chunks) {
2827
this(query, chunks, false);
2928
}
3029

3130
public QueryAndDocsInputs(String query, List<String> chunks, boolean stream) {
32-
super();
31+
super(stream);
3332
this.query = Objects.requireNonNull(query);
3433
this.chunks = Objects.requireNonNull(chunks);
35-
this.stream = stream;
3634
}
3735

3836
public String getQuery() {
@@ -43,10 +41,6 @@ public List<String> getChunks() {
4341
return chunks;
4442
}
4543

46-
public boolean stream() {
47-
return stream;
48-
}
49-
5044
public int inputSize() {
5145
return chunks.size();
5246
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,10 @@
1414

1515
public class UnifiedChatInput extends InferenceInputs {
1616
private final UnifiedCompletionRequest request;
17-
private final boolean stream;
1817

1918
public UnifiedChatInput(UnifiedCompletionRequest request, boolean stream) {
19+
super(stream);
2020
this.request = Objects.requireNonNull(request);
21-
this.stream = stream;
2221
}
2322

2423
public UnifiedChatInput(ChatCompletionInput completionInput, String roleValue) {
@@ -47,10 +46,6 @@ public UnifiedCompletionRequest getRequest() {
4746
return request;
4847
}
4948

50-
public boolean stream() {
51-
return stream;
52-
}
53-
5449
public int inputSize() {
5550
return request.messages().size();
5651
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ private static InferenceInputs createInput(Model model, List<String> input, @Nul
7575
return switch (model.getTaskType()) {
7676
case COMPLETION -> new ChatCompletionInput(input, stream);
7777
case RERANK -> new QueryAndDocsInputs(query, input, stream);
78-
case TEXT_EMBEDDING -> new DocumentsOnlyInput(input, stream);
78+
case TEXT_EMBEDDING, SPARSE_EMBEDDING -> new DocumentsOnlyInput(input, stream);
7979
default -> throw new ElasticsearchStatusException(
8080
Strings.format("Invalid task type received when determining input type: [%s]", model.getTaskType().toString()),
8181
RestStatus.BAD_REQUEST

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -284,9 +284,8 @@ protected void doInfer(
284284
) {
285285
if (model instanceof GoogleAiStudioCompletionModel completionModel) {
286286
var requestManager = new GoogleAiStudioCompletionRequestManager(completionModel, getServiceComponents().threadPool());
287-
var docsOnly = DocumentsOnlyInput.of(inputs);
288287
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(
289-
completionModel.uri(docsOnly.stream()),
288+
completionModel.uri(inputs.stream()),
290289
"Google AI Studio completion"
291290
);
292291
var action = new SingleInputSenderExecutableAction(

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import org.elasticsearch.inference.ModelConfigurations;
2020
import org.elasticsearch.inference.ModelSecrets;
2121
import org.elasticsearch.inference.SimilarityMeasure;
22+
import org.elasticsearch.inference.TaskType;
2223
import org.elasticsearch.threadpool.ScalingExecutorBuilder;
2324
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
2425
import org.elasticsearch.xpack.inference.common.Truncator;
@@ -160,9 +161,11 @@ public static Model getInvalidModel(String inferenceEntityId, String serviceName
160161
var mockConfigs = mock(ModelConfigurations.class);
161162
when(mockConfigs.getInferenceEntityId()).thenReturn(inferenceEntityId);
162163
when(mockConfigs.getService()).thenReturn(serviceName);
164+
when(mockConfigs.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
163165

164166
var mockModel = mock(Model.class);
165167
when(mockModel.getConfigurations()).thenReturn(mockConfigs);
168+
when(mockModel.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
166169

167170
return mockModel;
168171
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -61,25 +61,11 @@ public void testOneInputIsValid() {
6161
assertTrue("Test failed to call listener.", testRan.get());
6262
}
6363

64-
public void testInvalidInputType() {
65-
var badInput = mock(InferenceInputs.class);
66-
var actualException = new AtomicReference<Exception>();
67-
68-
executableAction.execute(
69-
badInput,
70-
mock(TimeValue.class),
71-
ActionListener.wrap(shouldNotSucceed -> fail("Test failed."), actualException::set)
72-
);
73-
74-
assertThat(actualException.get(), notNullValue());
75-
assertThat(actualException.get().getMessage(), is("Invalid inference input type"));
76-
assertThat(actualException.get(), instanceOf(ElasticsearchStatusException.class));
77-
assertThat(((ElasticsearchStatusException) actualException.get()).status(), is(RestStatus.INTERNAL_SERVER_ERROR));
78-
}
79-
8064
public void testMoreThanOneInput() {
8165
var badInput = mock(DocumentsOnlyInput.class);
82-
when(badInput.getInputs()).thenReturn(List.of("one", "two"));
66+
var input = List.of("one", "two");
67+
when(badInput.getInputs()).thenReturn(input);
68+
when(badInput.inputSize()).thenReturn(input.size());
8369
var actualException = new AtomicReference<Exception>();
8470

8571
executableAction.execute(

0 commit comments

Comments
 (0)