2727import org .elasticsearch .xpack .inference .external .http .sender .ChatCompletionInput ;
2828import org .elasticsearch .xpack .inference .external .http .sender .EmbeddingsInput ;
2929import org .elasticsearch .xpack .inference .external .http .sender .HttpRequestSenderTests ;
30+ import org .elasticsearch .xpack .inference .external .http .sender .Sender ;
3031import org .elasticsearch .xpack .inference .logging .ThrottlerManager ;
3132import org .elasticsearch .xpack .inference .services .ServiceComponents ;
3233import org .elasticsearch .xpack .inference .services .huggingface .completion .HuggingFaceChatCompletionModelTests ;
@@ -462,31 +463,13 @@ public void testExecute_ReturnsSuccessfulResponse_ForChatCompletionAction() thro
462463 """ ;
463464 webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (responseJson ));
464465
465- var model = HuggingFaceChatCompletionModelTests .createCompletionModel (getUrl (webServer ), "secret" , "model" );
466- var actionCreator = new HuggingFaceActionCreator (sender , createWithEmptySettings (threadPool ));
467- var action = actionCreator .create (model );
468-
469- PlainActionFuture <InferenceServiceResults > listener = new PlainActionFuture <>();
470- action .execute (new ChatCompletionInput (List .of ("Hello" ), false ), InferenceAction .Request .DEFAULT_TIMEOUT , listener );
466+ PlainActionFuture <InferenceServiceResults > listener = createChatCompletionFuture (sender , createWithEmptySettings (threadPool ));
471467
472468 var result = listener .actionGet (TIMEOUT );
473469
474470 assertThat (result .asMap (), is (buildExpectationCompletion (List .of ("Hello there, how may I assist you today?" ))));
475471
476- assertThat (webServer .requests (), hasSize (1 ));
477- assertNull (webServer .requests ().get (0 ).getUri ().getQuery ());
478- assertThat (
479- webServer .requests ().get (0 ).getHeader (HttpHeaders .CONTENT_TYPE ),
480- equalTo (XContentType .JSON .mediaTypeWithoutParameters ())
481- );
482- assertThat (webServer .requests ().get (0 ).getHeader (HttpHeaders .AUTHORIZATION ), equalTo ("Bearer secret" ));
483-
484- var requestMap = entityAsMap (webServer .requests ().get (0 ).getBody ());
485- assertThat (requestMap .size (), is (4 ));
486- assertThat (requestMap .get ("messages" ), is (List .of (Map .of ("role" , "user" , "content" , "Hello" ))));
487- assertThat (requestMap .get ("model" ), is ("model" ));
488- assertThat (requestMap .get ("n" ), is (1 ));
489- assertThat (requestMap .get ("stream" ), is (false ));
472+ assertChatCompletionRequest ();
490473 }
491474 }
492475
@@ -508,36 +491,45 @@ public void testSend_FailsFromInvalidResponseFormat_ForChatCompletionAction() th
508491 """ ;
509492 webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (responseJson ));
510493
511- var model = HuggingFaceChatCompletionModelTests .createCompletionModel (getUrl (webServer ), "secret" , "model" );
512- var actionCreator = new HuggingFaceActionCreator (
494+ PlainActionFuture <InferenceServiceResults > listener = createChatCompletionFuture (
513495 sender ,
514496 new ServiceComponents (threadPool , mockThrottlerManager (), settings , TruncatorTests .createTruncator ())
515497 );
516- var action = actionCreator .create (model );
517-
518- PlainActionFuture <InferenceServiceResults > listener = new PlainActionFuture <>();
519- action .execute (new ChatCompletionInput (List .of ("Hello" ), false ), InferenceAction .Request .DEFAULT_TIMEOUT , listener );
520498
521499 var thrownException = expectThrows (ElasticsearchException .class , () -> listener .actionGet (TIMEOUT ));
522500 assertThat (
523501 thrownException .getMessage (),
524502 is ("Failed to send Hugging Face completion request from inference entity id " + "[id]. Cause: Required [choices]" )
525503 );
526504
527- assertThat (webServer .requests (), hasSize (1 ));
528- assertNull (webServer .requests ().get (0 ).getUri ().getQuery ());
529- assertThat (
530- webServer .requests ().get (0 ).getHeader (HttpHeaders .CONTENT_TYPE ),
531- equalTo (XContentType .JSON .mediaTypeWithoutParameters ())
532- );
533- assertThat (webServer .requests ().get (0 ).getHeader (HttpHeaders .AUTHORIZATION ), equalTo ("Bearer secret" ));
534-
535- var requestMap = entityAsMap (webServer .requests ().get (0 ).getBody ());
536- assertThat (requestMap .size (), is (4 ));
537- assertThat (requestMap .get ("messages" ), is (List .of (Map .of ("role" , "user" , "content" , "Hello" ))));
538- assertThat (requestMap .get ("model" ), is ("model" ));
539- assertThat (requestMap .get ("n" ), is (1 ));
540- assertThat (requestMap .get ("stream" ), is (false ));
505+ assertChatCompletionRequest ();
541506 }
542507 }
508+
509+ private PlainActionFuture <InferenceServiceResults > createChatCompletionFuture (Sender sender , ServiceComponents threadPool ) {
510+ var model = HuggingFaceChatCompletionModelTests .createCompletionModel (getUrl (webServer ), "secret" , "model" );
511+ var actionCreator = new HuggingFaceActionCreator (sender , threadPool );
512+ var action = actionCreator .create (model );
513+
514+ PlainActionFuture <InferenceServiceResults > listener = new PlainActionFuture <>();
515+ action .execute (new ChatCompletionInput (List .of ("Hello" ), false ), InferenceAction .Request .DEFAULT_TIMEOUT , listener );
516+ return listener ;
517+ }
518+
519+ private void assertChatCompletionRequest () throws IOException {
520+ assertThat (webServer .requests (), hasSize (1 ));
521+ assertNull (webServer .requests ().get (0 ).getUri ().getQuery ());
522+ assertThat (
523+ webServer .requests ().get (0 ).getHeader (HttpHeaders .CONTENT_TYPE ),
524+ equalTo (XContentType .JSON .mediaTypeWithoutParameters ())
525+ );
526+ assertThat (webServer .requests ().get (0 ).getHeader (HttpHeaders .AUTHORIZATION ), equalTo ("Bearer secret" ));
527+
528+ var requestMap = entityAsMap (webServer .requests ().get (0 ).getBody ());
529+ assertThat (requestMap .size (), is (4 ));
530+ assertThat (requestMap .get ("messages" ), is (List .of (Map .of ("role" , "user" , "content" , "Hello" ))));
531+ assertThat (requestMap .get ("model" ), is ("model" ));
532+ assertThat (requestMap .get ("n" ), is (1 ));
533+ assertThat (requestMap .get ("stream" ), is (false ));
534+ }
543535}
0 commit comments