Skip to content

Commit 059114f

Browse files
Cleaning up comments, updating validation input type, and moving model deployment starting to model validator
1 parent 4274a40 commit 059114f

File tree

10 files changed

+253
-10
lines changed

10 files changed

+253
-10
lines changed

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ public void infer(
131131
)
132132
);
133133
} else {
134+
// Return text embedding results when creating a sparse_embedding inference endpoint to allow creation validation to
135+
// pass. This is required to test that streaming fails for a sparse_embedding endpoint.
134136
listener.onResponse(makeTextEmbeddingResults(input));
135137
}
136138
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,8 @@ private void parseAndStoreModel(
207207
delegate.onFailure(e);
208208
}
209209
}
210-
)
210+
),
211+
timeout
211212
)
212213
);
213214

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.elasticsearch.core.Nullable;
1515
import org.elasticsearch.core.Strings;
1616
import org.elasticsearch.core.TimeValue;
17+
import org.elasticsearch.inference.InputType;
1718
import org.elasticsearch.inference.Model;
1819
import org.elasticsearch.inference.SimilarityMeasure;
1920
import org.elasticsearch.inference.TaskType;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ protected void maybeStartDeployment(
298298
InferModelAction.Request request,
299299
ActionListener<InferModelAction.Response> listener
300300
) {
301-
if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
301+
if (isDefaultId(model.getInferenceEntityId()) && ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
302302
this.start(model, request.getInferenceTimeout(), listener.delegateFailureAndWrap((l, started) -> {
303303
client.execute(InferModelAction.INSTANCE, request, listener);
304304
}));

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
5959
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
6060
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
61+
import org.elasticsearch.xpack.inference.services.ServiceUtils;
6162

6263
import java.util.ArrayList;
6364
import java.util.Collections;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@
77

88
package org.elasticsearch.xpack.inference.services.validation;
99

10+
import org.elasticsearch.ElasticsearchStatusException;
1011
import org.elasticsearch.action.ActionListener;
1112
import org.elasticsearch.core.TimeValue;
1213
import org.elasticsearch.inference.InferenceService;
1314
import org.elasticsearch.inference.Model;
15+
import org.elasticsearch.rest.RestStatus;
1416

1517
public class ElasticsearchInternalServiceModelValidator implements ModelValidator {
1618

@@ -22,9 +24,36 @@ public ElasticsearchInternalServiceModelValidator(ModelValidator modelValidator)
2224

2325
@Override
2426
public void validate(InferenceService service, Model model, TimeValue timeout, ActionListener<Model> listener) {
25-
modelValidator.validate(service, model, timeout, listener.delegateResponse((l, exception) -> {
26-
// TODO: Cleanup the below code
27-
service.stop(model, ActionListener.wrap((v) -> listener.onFailure(exception), (e) -> listener.onFailure(exception)));
28-
}));
27+
service.start(model, timeout, ActionListener.wrap((modelDeploymentStarted) -> {
28+
if (modelDeploymentStarted) {
29+
try {
30+
modelValidator.validate(service, model, timeout, listener.delegateResponse((l, exception) -> {
31+
stopModelDeployment(service, model, l, exception);
32+
}));
33+
} catch (Exception e) {
34+
stopModelDeployment(service, model, listener, e);
35+
}
36+
} else {
37+
listener.onFailure(
38+
new ElasticsearchStatusException("Could not deploy model for inference endpoint", RestStatus.INTERNAL_SERVER_ERROR)
39+
);
40+
}
41+
}, listener::onFailure));
42+
}
43+
44+
private void stopModelDeployment(InferenceService service, Model model, ActionListener<Model> listener, Exception e) {
45+
service.stop(
46+
model,
47+
ActionListener.wrap(
48+
(v) -> listener.onFailure(e),
49+
(ex) -> listener.onFailure(
50+
new ElasticsearchStatusException(
51+
"Model validation failed and model deployment could not be stopped",
52+
RestStatus.INTERNAL_SERVER_ERROR,
53+
ex
54+
)
55+
)
56+
)
57+
);
2958
}
3059
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public void validate(InferenceService service, Model model, TimeValue timeout, A
3535
TEST_INPUT,
3636
false,
3737
Map.of(),
38-
InputType.INGEST,
38+
InputType.INTERNAL_INGEST,
3939
timeout,
4040
ActionListener.wrap(r -> {
4141
if (r != null) {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.elasticsearch.inference.ChunkedInference;
2121
import org.elasticsearch.inference.ChunkingSettings;
2222
import org.elasticsearch.inference.InferenceServiceConfiguration;
23+
import org.elasticsearch.inference.InferenceServiceResults;
2324
import org.elasticsearch.inference.InputType;
2425
import org.elasticsearch.inference.Model;
2526
import org.elasticsearch.inference.ModelConfigurations;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.validation;
9+
10+
import org.elasticsearch.ElasticsearchStatusException;
11+
import org.elasticsearch.action.ActionListener;
12+
import org.elasticsearch.core.TimeValue;
13+
import org.elasticsearch.inference.InferenceService;
14+
import org.elasticsearch.inference.Model;
15+
import org.elasticsearch.rest.RestStatus;
16+
import org.elasticsearch.test.ESTestCase;
17+
import org.junit.Before;
18+
import org.mockito.ArgumentCaptor;
19+
import org.mockito.Mock;
20+
21+
import static org.mockito.ArgumentMatchers.any;
22+
import static org.mockito.ArgumentMatchers.eq;
23+
import static org.mockito.Mockito.doAnswer;
24+
import static org.mockito.Mockito.doThrow;
25+
import static org.mockito.Mockito.verify;
26+
import static org.mockito.Mockito.verifyNoMoreInteractions;
27+
import static org.mockito.Mockito.when;
28+
import static org.mockito.MockitoAnnotations.openMocks;
29+
30+
public class ElasticsearchInternalServiceModelValidatorTests extends ESTestCase {
31+
32+
private static final TimeValue TIMEOUT = TimeValue.ONE_MINUTE;
33+
private static final String MODEL_VALIDATION_AND_STOP_FAILED_MESSAGE =
34+
"Model validation failed and model deployment could not be stopped";
35+
36+
@Mock
37+
private ModelValidator mockModelValidator;
38+
@Mock
39+
private InferenceService mockInferenceService;
40+
@Mock
41+
private Model mockModel;
42+
@Mock
43+
private ActionListener<Model> mockActionListener;
44+
45+
private ElasticsearchInternalServiceModelValidator underTest;
46+
47+
@Before
48+
public void setup() {
49+
openMocks(this);
50+
51+
underTest = new ElasticsearchInternalServiceModelValidator(mockModelValidator);
52+
53+
when(mockActionListener.delegateResponse(any())).thenCallRealMethod();
54+
}
55+
56+
public void testValidate_ModelDeploymentThrowsException() {
57+
doThrow(ElasticsearchStatusException.class).when(mockInferenceService).start(eq(mockModel), eq(TIMEOUT), any());
58+
59+
assertThrows(
60+
ElasticsearchStatusException.class,
61+
() -> { underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener); }
62+
);
63+
64+
verify(mockInferenceService).start(eq(mockModel), eq(TIMEOUT), any());
65+
verifyNoMoreInteractions(mockModelValidator, mockInferenceService, mockModel, mockActionListener);
66+
}
67+
68+
public void testValidate_ModelDeploymentReturnsFalse() {
69+
mockModelDeployment(false);
70+
71+
underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener);
72+
73+
verify(mockInferenceService).start(eq(mockModel), eq(TIMEOUT), any());
74+
verify(mockActionListener).onFailure(any(ElasticsearchStatusException.class));
75+
verifyNoMoreInteractions(mockModelValidator, mockInferenceService, mockModel, mockActionListener);
76+
}
77+
78+
public void testValidate_ModelValidatorThrowsExceptionAndModelDeploymentIsStopped() {
79+
mockModelDeployment(true);
80+
doThrow(new ElasticsearchStatusException("Model Validator Exception", RestStatus.INTERNAL_SERVER_ERROR)).when(mockModelValidator)
81+
.validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any());
82+
mockModelStop(true);
83+
84+
underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener);
85+
86+
verify(mockInferenceService).start(eq(mockModel), eq(TIMEOUT), any());
87+
verify(mockInferenceService).stop(eq(mockModel), any());
88+
verify(mockModelValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any());
89+
verify(mockActionListener).delegateResponse(any());
90+
verifyMockActionListenerAfterStopModelDeployment(true);
91+
verifyNoMoreInteractions(mockModelValidator, mockInferenceService, mockModel, mockActionListener);
92+
}
93+
94+
public void testValidate_ModelValidatorThrowsExceptionAndModelDeploymentIsNotStopped() {
95+
mockModelDeployment(true);
96+
doThrow(new ElasticsearchStatusException("Model Validator Exception", RestStatus.INTERNAL_SERVER_ERROR)).when(mockModelValidator)
97+
.validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any());
98+
mockModelStop(false);
99+
100+
underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener);
101+
102+
verify(mockInferenceService).start(eq(mockModel), eq(TIMEOUT), any());
103+
verify(mockInferenceService).stop(eq(mockModel), any());
104+
verify(mockModelValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any());
105+
verify(mockActionListener).delegateResponse(any());
106+
verifyMockActionListenerAfterStopModelDeployment(false);
107+
verifyNoMoreInteractions(mockModelValidator, mockInferenceService, mockModel, mockActionListener);
108+
}
109+
110+
public void testValidate_ModelValidationFailsAndModelDeploymentIsStopped() {
111+
mockModelDeployment(true);
112+
doAnswer(ans -> {
113+
ActionListener<Model> responseListener = ans.getArgument(3);
114+
responseListener.onFailure(new ElasticsearchStatusException("Model validation failed", RestStatus.INTERNAL_SERVER_ERROR));
115+
return null;
116+
}).when(mockModelValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any());
117+
mockModelStop(true);
118+
119+
underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener);
120+
121+
verify(mockInferenceService).start(eq(mockModel), eq(TIMEOUT), any());
122+
verify(mockInferenceService).stop(eq(mockModel), any());
123+
verify(mockModelValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any());
124+
verify(mockActionListener).delegateResponse(any());
125+
verifyMockActionListenerAfterStopModelDeployment(true);
126+
verifyNoMoreInteractions(mockModelValidator, mockInferenceService, mockModel, mockActionListener);
127+
}
128+
129+
public void testValidate_ModelValidationFailsAndModelDeploymentIsNotStopped() {
130+
mockModelDeployment(true);
131+
doAnswer(ans -> {
132+
ActionListener<Model> responseListener = ans.getArgument(3);
133+
responseListener.onFailure(new ElasticsearchStatusException("Model validation failed", RestStatus.INTERNAL_SERVER_ERROR));
134+
return null;
135+
}).when(mockModelValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any());
136+
mockModelStop(false);
137+
138+
underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener);
139+
140+
verify(mockInferenceService).start(eq(mockModel), eq(TIMEOUT), any());
141+
verify(mockInferenceService).stop(eq(mockModel), any());
142+
verify(mockModelValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any());
143+
verify(mockActionListener).delegateResponse(any());
144+
verifyMockActionListenerAfterStopModelDeployment(false);
145+
verifyNoMoreInteractions(mockModelValidator, mockInferenceService, mockModel, mockActionListener);
146+
}
147+
148+
public void testValidate_ModelValidationSucceeds() {
149+
mockModelDeployment(true);
150+
mockModelStop(true);
151+
152+
underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener);
153+
154+
verify(mockInferenceService).start(eq(mockModel), eq(TIMEOUT), any());
155+
verify(mockModelValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any());
156+
verify(mockActionListener).delegateResponse(any());
157+
verifyNoMoreInteractions(mockModelValidator, mockInferenceService, mockModel, mockActionListener);
158+
}
159+
160+
private void mockModelDeployment(boolean modelDeploymentStarted) {
161+
doAnswer(ans -> {
162+
ActionListener<Boolean> responseListener = ans.getArgument(2);
163+
responseListener.onResponse(modelDeploymentStarted);
164+
return null;
165+
}).when(mockInferenceService).start(eq(mockModel), eq(TIMEOUT), any());
166+
}
167+
168+
private void mockModelStop(boolean modelDeploymentStopped) {
169+
if (modelDeploymentStopped) {
170+
doAnswer(ans -> {
171+
ActionListener<Void> responseListener = ans.getArgument(1);
172+
responseListener.onResponse(null);
173+
return null;
174+
}).when(mockInferenceService).stop(eq(mockModel), any());
175+
} else {
176+
doAnswer(ans -> {
177+
ActionListener<Void> responseListener = ans.getArgument(1);
178+
responseListener.onFailure(new ElasticsearchStatusException("Model stop failed", RestStatus.INTERNAL_SERVER_ERROR));
179+
return null;
180+
}).when(mockInferenceService).stop(eq(mockModel), any());
181+
}
182+
}
183+
184+
private void verifyMockActionListenerAfterStopModelDeployment(boolean modelDeploymentStopped) {
185+
verify(mockInferenceService).stop(eq(mockModel), any());
186+
ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
187+
verify(mockActionListener).onFailure(exceptionCaptor.capture());
188+
assertTrue(exceptionCaptor.getValue() instanceof ElasticsearchStatusException);
189+
assertEquals(RestStatus.INTERNAL_SERVER_ERROR, ((ElasticsearchStatusException) exceptionCaptor.getValue()).status());
190+
191+
if (modelDeploymentStopped) {
192+
assertFalse(exceptionCaptor.getValue().getMessage().contains(MODEL_VALIDATION_AND_STOP_FAILED_MESSAGE));
193+
} else {
194+
assertTrue(exceptionCaptor.getValue().getMessage().contains(MODEL_VALIDATION_AND_STOP_FAILED_MESSAGE));
195+
}
196+
}
197+
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidatorTests.java

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ public void testValidate_ServiceThrowsException() {
6969
eq(TEST_INPUT),
7070
eq(false),
7171
eq(Map.of()),
72-
eq(InputType.INGEST),
72+
eq(InputType.INTERNAL_INGEST),
7373
eq(TIMEOUT),
7474
any()
7575
);
@@ -104,7 +104,18 @@ private void mockSuccessfulCallToService(String query, InferenceServiceResults r
104104
responseListener.onResponse(result);
105105
return null;
106106
}).when(mockInferenceService)
107-
.infer(eq(mockModel), eq(query), eq(TEST_INPUT), eq(false), eq(Map.of()), eq(InputType.INGEST), eq(TIMEOUT), any());
107+
.infer(
108+
eq(mockModel),
109+
eq(query),
110+
eq(null),
111+
eq(null),
112+
eq(TEST_INPUT),
113+
eq(false),
114+
eq(Map.of()),
115+
eq(InputType.INTERNAL_INGEST),
116+
eq(TIMEOUT),
117+
any()
118+
);
108119

109120
underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener);
110121
}
@@ -119,7 +130,7 @@ private void verifyCallToService(boolean withQuery) {
119130
eq(TEST_INPUT),
120131
eq(false),
121132
eq(Map.of()),
122-
eq(InputType.INGEST),
133+
eq(InputType.INTERNAL_INGEST),
123134
eq(TIMEOUT),
124135
any()
125136
);

0 commit comments

Comments
 (0)