6060import org .elasticsearch .xpack .inference .services .openshiftai .embeddings .OpenShiftAiEmbeddingsServiceSettings ;
6161import org .elasticsearch .xpack .inference .services .settings .DefaultSecretSettings ;
6262import org .elasticsearch .xpack .inference .services .settings .RateLimitSettings ;
63- import org .hamcrest .CoreMatchers ;
6463import org .hamcrest .Matchers ;
6564import org .junit .After ;
6665import org .junit .Before ;
9493import static org .elasticsearch .xpack .inference .services .openshiftai .completion .OpenShiftAiChatCompletionModelTests .createChatCompletionModel ;
9594import static org .elasticsearch .xpack .inference .services .openshiftai .completion .OpenShiftAiChatCompletionServiceSettingsTests .getServiceSettingsMap ;
9695import static org .elasticsearch .xpack .inference .services .settings .DefaultSecretSettingsTests .getSecretSettingsMap ;
97- import static org .hamcrest .CoreMatchers .is ;
9896import static org .hamcrest .Matchers .equalTo ;
9997import static org .hamcrest .Matchers .hasSize ;
10098import static org .hamcrest .Matchers .instanceOf ;
99+ import static org .hamcrest .Matchers .is ;
101100import static org .hamcrest .Matchers .isA ;
102101import static org .mockito .Mockito .mock ;
103102
104103public class OpenShiftAiServiceTests extends AbstractInferenceServiceTests {
104+ private static final String URL = "http://www.abc.com" ;
105+ private static final String MODEL_ID = "model_id" ;
106+ private static final String USER_ROLE = "user" ;
107+ private static final String API_KEY = "secret" ;
108+ private static final String INFERENCE_ID = "id" ;
105109 private final MockWebServer webServer = new MockWebServer ();
106110 private ThreadPool threadPool ;
107111 private HttpClientManager clientManager ;
@@ -168,32 +172,32 @@ private static void assertModel(Model model, TaskType taskType, boolean modelInc
168172 private static void assertTextEmbeddingModel (Model model , boolean modelIncludesSecrets ) {
169173 var openShiftAiModel = assertCommonModelFields (model , modelIncludesSecrets );
170174
171- assertThat (openShiftAiModel .getTaskType (), Matchers . is (TaskType .TEXT_EMBEDDING ));
175+ assertThat (openShiftAiModel .getTaskType (), is (TaskType .TEXT_EMBEDDING ));
172176 }
173177
174178 private static OpenShiftAiModel assertCommonModelFields (Model model , boolean modelIncludesSecrets ) {
175179 assertThat (model , instanceOf (OpenShiftAiModel .class ));
176180
177181 var openShiftAiModel = (OpenShiftAiModel ) model ;
178- assertThat (openShiftAiModel .getServiceSettings ().modelId (), is ("model_id" ));
179- assertThat (openShiftAiModel .getServiceSettings ().uri .toString (), Matchers . is ("http://www.abc.com" ));
180- assertThat (openShiftAiModel .getTaskSettings (), Matchers . is (EmptyTaskSettings .INSTANCE ));
182+ assertThat (openShiftAiModel .getServiceSettings ().modelId (), is (MODEL_ID ));
183+ assertThat (openShiftAiModel .getServiceSettings ().uri .toString (), is (URL ));
184+ assertThat (openShiftAiModel .getTaskSettings (), is (EmptyTaskSettings .INSTANCE ));
181185
182186 if (modelIncludesSecrets ) {
183- assertThat (openShiftAiModel .getSecretSettings ().apiKey (), Matchers . is (new SecureString ("secret" .toCharArray ())));
187+ assertThat (openShiftAiModel .getSecretSettings ().apiKey (), is (new SecureString (API_KEY .toCharArray ())));
184188 }
185189
186190 return openShiftAiModel ;
187191 }
188192
189193 private static void assertCompletionModel (Model model , boolean modelIncludesSecrets ) {
190194 var openShiftAiModel = assertCommonModelFields (model , modelIncludesSecrets );
191- assertThat (openShiftAiModel .getTaskType (), Matchers . is (TaskType .COMPLETION ));
195+ assertThat (openShiftAiModel .getTaskType (), is (TaskType .COMPLETION ));
192196 }
193197
194198 private static void assertChatCompletionModel (Model model , boolean modelIncludesSecrets ) {
195199 var openShiftAiModel = assertCommonModelFields (model , modelIncludesSecrets );
196- assertThat (openShiftAiModel .getTaskType (), Matchers . is (TaskType .CHAT_COMPLETION ));
200+ assertThat (openShiftAiModel .getTaskType (), is (TaskType .CHAT_COMPLETION ));
197201 }
198202
199203 public static SenderService createService (ThreadPool threadPool , HttpClientManager clientManager ) {
@@ -202,9 +206,7 @@ public static SenderService createService(ThreadPool threadPool, HttpClientManag
202206 }
203207
204208 private static Map <String , Object > createServiceSettingsMap (TaskType taskType ) {
205- Map <String , Object > settingsMap = new HashMap <>(
206- Map .of (ServiceFields .URL , "http://www.abc.com" , ServiceFields .MODEL_ID , "model_id" )
207- );
209+ Map <String , Object > settingsMap = new HashMap <>(Map .of (ServiceFields .URL , URL , ServiceFields .MODEL_ID , MODEL_ID ));
208210
209211 if (taskType == TaskType .TEXT_EMBEDDING ) {
210212 settingsMap .putAll (
@@ -223,27 +225,17 @@ private static Map<String, Object> createServiceSettingsMap(TaskType taskType) {
223225 }
224226
225227 private static Map <String , Object > createSecretSettingsMap () {
226- return new HashMap <>(Map .of ("api_key" , "secret" ));
228+ return new HashMap <>(Map .of ("api_key" , API_KEY ));
227229 }
228230
229231 private static OpenShiftAiEmbeddingsModel createInternalEmbeddingModel (@ Nullable SimilarityMeasure similarityMeasure ) {
230- var inferenceId = "inference_id" ;
231-
232232 return new OpenShiftAiEmbeddingsModel (
233- inferenceId ,
233+ INFERENCE_ID ,
234234 TaskType .TEXT_EMBEDDING ,
235235 OpenShiftAiService .NAME ,
236- new OpenShiftAiEmbeddingsServiceSettings (
237- "model_id" ,
238- "http://www.abc.com" ,
239- 1536 ,
240- similarityMeasure ,
241- 512 ,
242- new RateLimitSettings (10_000 ),
243- true
244- ),
236+ new OpenShiftAiEmbeddingsServiceSettings (MODEL_ID , URL , 1536 , similarityMeasure , 512 , new RateLimitSettings (10_000 ), true ),
245237 createRandomChunkingSettings (),
246- new DefaultSecretSettings (new SecureString ("secret" .toCharArray ()))
238+ new DefaultSecretSettings (new SecureString (API_KEY .toCharArray ()))
247239 );
248240 }
249241
@@ -267,19 +259,15 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsP
267259 assertThat (model , instanceOf (OpenShiftAiEmbeddingsModel .class ));
268260
269261 var embeddingsModel = (OpenShiftAiEmbeddingsModel ) model ;
270- assertThat (embeddingsModel .getServiceSettings ().uri ().toString (), is ("url" ));
262+ assertThat (embeddingsModel .getServiceSettings ().uri ().toString (), is (URL ));
271263 assertThat (embeddingsModel .getConfigurations ().getChunkingSettings (), instanceOf (ChunkingSettings .class ));
272- assertThat (embeddingsModel .getSecretSettings ().apiKey ().toString (), is ("secret" ));
264+ assertThat (embeddingsModel .getSecretSettings ().apiKey ().toString (), is (API_KEY ));
273265 }, e -> fail ("parse request should not fail " + e .getMessage ()));
274266
275267 service .parseRequestConfig (
276- "id" ,
268+ INFERENCE_ID ,
277269 TaskType .TEXT_EMBEDDING ,
278- getRequestConfigMap (
279- getServiceSettingsMap ("model" , "url" ),
280- createRandomChunkingSettingsMap (),
281- getSecretSettingsMap ("secret" )
282- ),
270+ getRequestConfigMap (getServiceSettingsMap (MODEL_ID , URL ), createRandomChunkingSettingsMap (), getSecretSettingsMap (API_KEY )),
283271 modelVerificationActionListener
284272 );
285273 }
@@ -291,49 +279,43 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsN
291279 assertThat (model , instanceOf (OpenShiftAiEmbeddingsModel .class ));
292280
293281 var embeddingsModel = (OpenShiftAiEmbeddingsModel ) model ;
294- assertThat (embeddingsModel .getServiceSettings ().uri ().toString (), is ("url" ));
282+ assertThat (embeddingsModel .getServiceSettings ().uri ().toString (), is (URL ));
295283 assertThat (embeddingsModel .getConfigurations ().getChunkingSettings (), instanceOf (ChunkingSettings .class ));
296- assertThat (embeddingsModel .getSecretSettings ().apiKey ().toString (), is ("secret" ));
284+ assertThat (embeddingsModel .getSecretSettings ().apiKey ().toString (), is (API_KEY ));
297285 }, e -> fail ("parse request should not fail " + e .getMessage ()));
298286
299287 service .parseRequestConfig (
300- "id" ,
288+ INFERENCE_ID ,
301289 TaskType .TEXT_EMBEDDING ,
302- getRequestConfigMap (getServiceSettingsMap ("model" , "url" ), getSecretSettingsMap ("secret" )),
290+ getRequestConfigMap (getServiceSettingsMap (MODEL_ID , URL ), getSecretSettingsMap (API_KEY )),
303291 modelVerificationActionListener
304292 );
305293 }
306294 }
307295
308296 public void testParseRequestConfig_WithoutModelId_Success () throws IOException {
309- var url = "url" ;
310- var secret = "secret" ;
311-
312297 try (var service = createService ()) {
313298 ActionListener <Model > modelVerificationListener = ActionListener .wrap (m -> {
314299 assertThat (m , instanceOf (OpenShiftAiChatCompletionModel .class ));
315300
316301 var chatCompletionModel = (OpenShiftAiChatCompletionModel ) m ;
317302
318- assertThat (chatCompletionModel .getServiceSettings ().uri ().toString (), is (url ));
303+ assertThat (chatCompletionModel .getServiceSettings ().uri ().toString (), is (URL ));
319304 assertNull (chatCompletionModel .getServiceSettings ().modelId ());
320- assertThat (chatCompletionModel .getSecretSettings ().apiKey ().toString (), is ("secret" ));
305+ assertThat (chatCompletionModel .getSecretSettings ().apiKey ().toString (), is (API_KEY ));
321306
322307 }, e -> fail ("parse request should not fail " + e .getMessage ()));
323308
324309 service .parseRequestConfig (
325- "id" ,
310+ INFERENCE_ID ,
326311 TaskType .CHAT_COMPLETION ,
327- getRequestConfigMap (getServiceSettingsMap (null , url ), getSecretSettingsMap (secret )),
312+ getRequestConfigMap (getServiceSettingsMap (null , URL ), getSecretSettingsMap (API_KEY )),
328313 modelVerificationListener
329314 );
330315 }
331316 }
332317
333318 public void testParseRequestConfig_WithoutUrl_ThrowsException () throws IOException {
334- var model = "model" ;
335- var secret = "secret" ;
336-
337319 try (var service = createService ()) {
338320 ActionListener <Model > modelVerificationListener = ActionListener .wrap (
339321 m -> fail ("Expected exception, but got model: " + m ),
@@ -347,9 +329,9 @@ public void testParseRequestConfig_WithoutUrl_ThrowsException() throws IOExcepti
347329 );
348330
349331 service .parseRequestConfig (
350- "id" ,
332+ INFERENCE_ID ,
351333 TaskType .CHAT_COMPLETION ,
352- getRequestConfigMap (getServiceSettingsMap (model , null ), getSecretSettingsMap (secret )),
334+ getRequestConfigMap (getServiceSettingsMap (MODEL_ID , null ), getSecretSettingsMap (API_KEY )),
353335 modelVerificationListener
354336 );
355337 }
@@ -386,12 +368,14 @@ public void testUnifiedCompletionInfer() throws Exception {
386368
387369 var senderFactory = HttpRequestSenderTests .createSenderFactory (threadPool , clientManager );
388370 try (var service = new OpenShiftAiService (senderFactory , createWithEmptySettings (threadPool ), mockClusterServiceEmpty ())) {
389- var model = createChatCompletionModel (getUrl (webServer ), "secret" , "model" );
371+ var model = createChatCompletionModel (getUrl (webServer ), API_KEY , MODEL_ID );
390372 PlainActionFuture <InferenceServiceResults > listener = new PlainActionFuture <>();
391373 service .unifiedCompletionInfer (
392374 model ,
393375 UnifiedCompletionRequest .of (
394- List .of (new UnifiedCompletionRequest .Message (new UnifiedCompletionRequest .ContentString ("hello" ), "user" , null , null ))
376+ List .of (
377+ new UnifiedCompletionRequest .Message (new UnifiedCompletionRequest .ContentString ("hello" ), USER_ROLE , null , null )
378+ )
395379 ),
396380 InferenceAction .Request .DEFAULT_TIMEOUT ,
397381 listener
@@ -426,12 +410,14 @@ public void testUnifiedCompletionNonStreamingNotFoundError() throws Exception {
426410
427411 var senderFactory = HttpRequestSenderTests .createSenderFactory (threadPool , clientManager );
428412 try (var service = new OpenShiftAiService (senderFactory , createWithEmptySettings (threadPool ), mockClusterServiceEmpty ())) {
429- var model = OpenShiftAiChatCompletionModelTests .createChatCompletionModel (getUrl (webServer ), "secret" , "model" );
413+ var model = OpenShiftAiChatCompletionModelTests .createChatCompletionModel (getUrl (webServer ), API_KEY , MODEL_ID );
430414 var latch = new CountDownLatch (1 );
431415 service .unifiedCompletionInfer (
432416 model ,
433417 UnifiedCompletionRequest .of (
434- List .of (new UnifiedCompletionRequest .Message (new UnifiedCompletionRequest .ContentString ("hello" ), "user" , null , null ))
418+ List .of (
419+ new UnifiedCompletionRequest .Message (new UnifiedCompletionRequest .ContentString ("hello" ), USER_ROLE , null , null )
420+ )
435421 ),
436422 InferenceAction .Request .DEFAULT_TIMEOUT ,
437423 ActionListener .runAfter (ActionTestUtils .assertNoSuccessListener (e -> {
@@ -510,12 +496,14 @@ public void testInfer_StreamRequest() throws Exception {
510496 private void testStreamError (String expectedResponse ) throws Exception {
511497 var senderFactory = HttpRequestSenderTests .createSenderFactory (threadPool , clientManager );
512498 try (var service = new OpenShiftAiService (senderFactory , createWithEmptySettings (threadPool ), mockClusterServiceEmpty ())) {
513- var model = OpenShiftAiChatCompletionModelTests .createChatCompletionModel (getUrl (webServer ), "secret" , "model" );
499+ var model = OpenShiftAiChatCompletionModelTests .createChatCompletionModel (getUrl (webServer ), API_KEY , MODEL_ID );
514500 PlainActionFuture <InferenceServiceResults > listener = new PlainActionFuture <>();
515501 service .unifiedCompletionInfer (
516502 model ,
517503 UnifiedCompletionRequest .of (
518- List .of (new UnifiedCompletionRequest .Message (new UnifiedCompletionRequest .ContentString ("hello" ), "user" , null , null ))
504+ List .of (
505+ new UnifiedCompletionRequest .Message (new UnifiedCompletionRequest .ContentString ("hello" ), USER_ROLE , null , null )
506+ )
519507 ),
520508 InferenceAction .Request .DEFAULT_TIMEOUT ,
521509 listener
@@ -597,7 +585,7 @@ public void testSupportsStreaming() throws IOException {
597585
598586 public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingSecretSettingsMap () throws IOException {
599587 try (var service = createService ()) {
600- var secretSettings = getSecretSettingsMap ("secret" );
588+ var secretSettings = getSecretSettingsMap (API_KEY );
601589 secretSettings .put ("extra_key" , "value" );
602590
603591 var config = getRequestConfigMap (getEmbeddingsServiceSettingsMap (), secretSettings );
@@ -613,21 +601,21 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingSecretSe
613601 }
614602 );
615603
616- service .parseRequestConfig ("id" , TaskType .TEXT_EMBEDDING , config , modelVerificationListener );
604+ service .parseRequestConfig (INFERENCE_ID , TaskType .TEXT_EMBEDDING , config , modelVerificationListener );
617605 }
618606 }
619607
620608 public void testChunkedInfer_ChunkingSettingsNotSet () throws IOException {
621- var model = OpenShiftAiEmbeddingsModelTests .createModel (getUrl (webServer ), "api_key" , "model" , 1234 , false , 1536 , null );
609+ var model = OpenShiftAiEmbeddingsModelTests .createModel (getUrl (webServer ), API_KEY , MODEL_ID , 1234 , false , 1536 , null );
622610
623611 testChunkedInfer (model );
624612 }
625613
626614 public void testChunkedInfer_ChunkingSettingsSet () throws IOException {
627615 var model = OpenShiftAiEmbeddingsModelTests .createModel (
628616 getUrl (webServer ),
629- "api_key" ,
630- "model" ,
617+ API_KEY ,
618+ MODEL_ID ,
631619 1234 ,
632620 false ,
633621 1536 ,
@@ -691,7 +679,7 @@ public void testChunkedInfer(OpenShiftAiEmbeddingsModel model) throws IOExceptio
691679
692680 assertThat (results , hasSize (2 ));
693681 {
694- assertThat (results .get (0 ), CoreMatchers .instanceOf (ChunkedInferenceEmbedding .class ));
682+ assertThat (results .get (0 ), Matchers .instanceOf (ChunkedInferenceEmbedding .class ));
695683 var floatResult = (ChunkedInferenceEmbedding ) results .get (0 );
696684 assertThat (floatResult .chunks (), hasSize (1 ));
697685 assertThat (floatResult .chunks ().get (0 ).embedding (), Matchers .instanceOf (DenseEmbeddingFloatResults .Embedding .class ));
@@ -703,7 +691,7 @@ public void testChunkedInfer(OpenShiftAiEmbeddingsModel model) throws IOExceptio
703691 );
704692 }
705693 {
706- assertThat (results .get (1 ), CoreMatchers .instanceOf (ChunkedInferenceEmbedding .class ));
694+ assertThat (results .get (1 ), Matchers .instanceOf (ChunkedInferenceEmbedding .class ));
707695 var floatResult = (ChunkedInferenceEmbedding ) results .get (1 );
708696 assertThat (floatResult .chunks (), hasSize (1 ));
709697 assertThat (floatResult .chunks ().get (0 ).embedding (), Matchers .instanceOf (DenseEmbeddingFloatResults .Embedding .class ));
@@ -721,12 +709,12 @@ public void testChunkedInfer(OpenShiftAiEmbeddingsModel model) throws IOExceptio
721709 webServer .requests ().get (0 ).getHeader (HttpHeaders .CONTENT_TYPE ),
722710 equalTo (XContentType .JSON .mediaTypeWithoutParameters ())
723711 );
724- assertThat (webServer .requests ().get (0 ).getHeader (HttpHeaders .AUTHORIZATION ), equalTo ("Bearer api_key " ));
712+ assertThat (webServer .requests ().get (0 ).getHeader (HttpHeaders .AUTHORIZATION ), equalTo ("Bearer secret " ));
725713
726714 var requestMap = entityAsMap (webServer .requests ().get (0 ).getBody ());
727- assertThat (requestMap .size (), Matchers . is (2 ));
728- assertThat (requestMap .get ("input" ), Matchers . is (List .of ("abc" , "def" )));
729- assertThat (requestMap .get ("model" ), Matchers . is ("model" ));
715+ assertThat (requestMap .size (), is (2 ));
716+ assertThat (requestMap .get ("input" ), is (List .of ("abc" , "def" )));
717+ assertThat (requestMap .get ("model" ), is (MODEL_ID ));
730718 }
731719 }
732720
@@ -795,7 +783,7 @@ public void testGetConfiguration() throws Exception {
795783 private InferenceEventsAssertion streamCompletion () throws Exception {
796784 var senderFactory = HttpRequestSenderTests .createSenderFactory (threadPool , clientManager );
797785 try (var service = new OpenShiftAiService (senderFactory , createWithEmptySettings (threadPool ), mockClusterServiceEmpty ())) {
798- var model = OpenShiftAiChatCompletionModelTests .createCompletionModel (getUrl (webServer ), "secret" , "model" );
786+ var model = OpenShiftAiChatCompletionModelTests .createCompletionModel (getUrl (webServer ), API_KEY , MODEL_ID );
799787 PlainActionFuture <InferenceServiceResults > listener = new PlainActionFuture <>();
800788 service .infer (
801789 model ,
@@ -842,7 +830,7 @@ private Map<String, Object> getRequestConfigMap(Map<String, Object> serviceSetti
842830 }
843831
844832 private static Map <String , Object > getEmbeddingsServiceSettingsMap () {
845- return buildServiceSettingsMap ("id" , "url" , SimilarityMeasure .COSINE .toString (), null , null , null );
833+ return buildServiceSettingsMap (INFERENCE_ID , URL , SimilarityMeasure .COSINE .toString (), null , null , null );
846834 }
847835
848836 @ Override
0 commit comments