@@ -50,11 +50,13 @@ public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase {
5050 private ModelRegistry modelRegistry ;
5151 private final MockWebServer webServer = new MockWebServer ();
5252 private ThreadPool threadPool ;
53+ private String gatewayUrl ;
5354
5455 @ Before
5556 public void createComponents () throws Exception {
5657 threadPool = createThreadPool (inferenceUtilityPool ());
5758 webServer .start ();
59+ gatewayUrl = getUrl (webServer );
5860 modelRegistry = new ModelRegistry (client ());
5961 }
6062
@@ -64,16 +66,17 @@ public void shutdown() {
6466 webServer .close ();
6567 }
6668
69+ @ Override
70+ protected boolean resetNodeAfterTest () {
71+ return true ;
72+ }
73+
6774 @ Override
6875 protected Collection <Class <? extends Plugin >> getPlugins () {
6976 return pluginList (ReindexPlugin .class );
7077 }
7178
7279 public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCorrect () throws Exception {
73- var clientManager = HttpClientManager .create (Settings .EMPTY , threadPool , mockClusterServiceEmpty (), mock (ThrottlerManager .class ));
74- var senderFactory = HttpRequestSenderTests .createSenderFactory (threadPool , clientManager );
75- var gatewayUrl = getUrl (webServer );
76-
7780 String responseJson = """
7881 {
7982 "models": [
@@ -87,15 +90,7 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCo
8790
8891 webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (responseJson ));
8992
90- try (
91- var service = new ElasticInferenceService (
92- senderFactory ,
93- createWithEmptySettings (threadPool ),
94- new ElasticInferenceServiceComponents (gatewayUrl ),
95- modelRegistry ,
96- new ElasticInferenceServiceAuthorizationHandler (gatewayUrl , threadPool )
97- )
98- ) {
93+ try (var service = createElasticInferenceService ()) {
9994 service .waitForAuthorizationToComplete (TIMEOUT );
10095 assertThat (service .supportedStreamingTasks (), is (EnumSet .of (TaskType .CHAT_COMPLETION , TaskType .ANY )));
10196 assertThat (
@@ -115,17 +110,7 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCo
115110 }
116111
117112 public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty () throws Exception {
118- var gatewayUrl = getUrl (webServer );
119-
120113 {
121- var clientManager = HttpClientManager .create (
122- Settings .EMPTY ,
123- threadPool ,
124- mockClusterServiceEmpty (),
125- mock (ThrottlerManager .class )
126- );
127- var senderFactory = HttpRequestSenderTests .createSenderFactory (threadPool , clientManager );
128-
129114 String responseJson = """
130115 {
131116 "models": [
@@ -139,15 +124,7 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty()
139124
140125 webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (responseJson ));
141126
142- try (
143- var service = new ElasticInferenceService (
144- senderFactory ,
145- createWithEmptySettings (threadPool ),
146- new ElasticInferenceServiceComponents (gatewayUrl ),
147- modelRegistry ,
148- new ElasticInferenceServiceAuthorizationHandler (gatewayUrl , threadPool )
149- )
150- ) {
127+ try (var service = createElasticInferenceService ()) {
151128 service .waitForAuthorizationToComplete (TIMEOUT );
152129 assertThat (service .supportedStreamingTasks (), is (EnumSet .of (TaskType .CHAT_COMPLETION , TaskType .ANY )));
153130 assertThat (
@@ -186,18 +163,7 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty()
186163
187164 webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (noAuthorizationResponseJson ));
188165
189- var httpManager = HttpClientManager .create (Settings .EMPTY , threadPool , mockClusterServiceEmpty (), mock (ThrottlerManager .class ));
190- var senderFactory = HttpRequestSenderTests .createSenderFactory (threadPool , httpManager );
191-
192- try (
193- var service = new ElasticInferenceService (
194- senderFactory ,
195- createWithEmptySettings (threadPool ),
196- new ElasticInferenceServiceComponents (gatewayUrl ),
197- modelRegistry ,
198- new ElasticInferenceServiceAuthorizationHandler (gatewayUrl , threadPool )
199- )
200- ) {
166+ try (var service = createElasticInferenceService ()) {
201167 service .waitForAuthorizationToComplete (TIMEOUT );
202168 assertThat (service .supportedStreamingTasks (), is (EnumSet .noneOf (TaskType .class )));
203169 assertTrue (service .defaultConfigIds ().isEmpty ());
@@ -211,4 +177,95 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty()
211177 }
212178 }
213179 }
180+
181+ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnAuthForIt () throws Exception {
182+ {
183+ String responseJson = """
184+ {
185+ "models": [
186+ {
187+ "model_name": "rainbow-sprinkles",
188+ "task_types": ["chat"]
189+ },
190+ {
191+ "model_name": "elser-v2",
192+ "task_types": ["embed/text/sparse"]
193+ }
194+ ]
195+ }
196+ """ ;
197+
198+ webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (responseJson ));
199+
200+ try (var service = createElasticInferenceService ()) {
201+ service .waitForAuthorizationToComplete (TIMEOUT );
202+ assertThat (service .supportedStreamingTasks (), is (EnumSet .of (TaskType .CHAT_COMPLETION , TaskType .ANY )));
203+ assertThat (
204+ service .defaultConfigIds (),
205+ is (
206+ List .of (
207+ new InferenceService .DefaultConfigId (
208+ ".rainbow-sprinkles-elastic" ,
209+ MinimalServiceSettings .chatCompletion (),
210+ service
211+ )
212+ )
213+ )
214+ );
215+ assertThat (service .supportedTaskTypes (), is (EnumSet .of (TaskType .CHAT_COMPLETION , TaskType .SPARSE_EMBEDDING )));
216+
217+ PlainActionFuture <List <Model >> listener = new PlainActionFuture <>();
218+ service .defaultConfigs (listener );
219+ assertThat (listener .actionGet (TIMEOUT ).get (0 ).getConfigurations ().getInferenceEntityId (), is (".rainbow-sprinkles-elastic" ));
220+
221+ var getModelListener = new PlainActionFuture <UnparsedModel >();
222+ // persists the default endpoints
223+ modelRegistry .getModel (".rainbow-sprinkles-elastic" , getModelListener );
224+
225+ var inferenceEntity = getModelListener .actionGet (TIMEOUT );
226+ assertThat (inferenceEntity .inferenceEntityId (), is (".rainbow-sprinkles-elastic" ));
227+ assertThat (inferenceEntity .taskType (), is (TaskType .CHAT_COMPLETION ));
228+ }
229+ }
230+ {
231+ String noAuthorizationResponseJson = """
232+ {
233+ "models": [
234+ {
235+ "model_name": "elser-v2",
236+ "task_types": ["embed/text/sparse"]
237+ }
238+ ]
239+ }
240+ """ ;
241+
242+ webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (noAuthorizationResponseJson ));
243+
244+ try (var service = createElasticInferenceService ()) {
245+ service .waitForAuthorizationToComplete (TIMEOUT );
246+ assertThat (service .supportedStreamingTasks (), is (EnumSet .noneOf (TaskType .class )));
247+ assertTrue (service .defaultConfigIds ().isEmpty ());
248+ assertThat (service .supportedTaskTypes (), is (EnumSet .of (TaskType .SPARSE_EMBEDDING )));
249+
250+ var getModelListener = new PlainActionFuture <UnparsedModel >();
251+ modelRegistry .getModel (".rainbow-sprinkles-elastic" , getModelListener );
252+
253+ var exception = expectThrows (ResourceNotFoundException .class , () -> getModelListener .actionGet (TIMEOUT ));
254+ assertThat (exception .getMessage (), is ("Inference endpoint not found [.rainbow-sprinkles-elastic]" ));
255+ }
256+ }
257+ }
258+
259+ private ElasticInferenceService createElasticInferenceService () {
260+ var httpManager = HttpClientManager .create (Settings .EMPTY , threadPool , mockClusterServiceEmpty (), mock (ThrottlerManager .class ));
261+ var senderFactory = HttpRequestSenderTests .createSenderFactory (threadPool , httpManager );
262+
263+ return new ElasticInferenceService (
264+ senderFactory ,
265+ createWithEmptySettings (threadPool ),
266+ new ElasticInferenceServiceComponents (gatewayUrl ),
267+ modelRegistry ,
268+ new ElasticInferenceServiceAuthorizationHandler (gatewayUrl , threadPool )
269+ );
270+ }
214271}