4747import  org .elasticsearch .test .ESIntegTestCase ;
4848import  org .elasticsearch .xcontent .ToXContent ;
4949import  org .elasticsearch .xpack .core .inference .action .InferenceAction ;
50+ import  org .elasticsearch .xpack .inference .external .response .streaming .ServerSentEvent ;
51+ import  org .elasticsearch .xpack .inference .external .response .streaming .ServerSentEventField ;
52+ import  org .elasticsearch .xpack .inference .external .response .streaming .ServerSentEventParser ;
5053
5154import  java .io .IOException ;
5255import  java .nio .charset .StandardCharsets ;
5962import  java .util .concurrent .Flow ;
6063import  java .util .concurrent .LinkedBlockingDeque ;
6164import  java .util .concurrent .TimeUnit ;
62- import  java .util .concurrent .atomic .AtomicBoolean ;
6365import  java .util .concurrent .atomic .AtomicInteger ;
6466import  java .util .concurrent .atomic .AtomicReference ;
6567import  java .util .function .Predicate ;
@@ -80,9 +82,7 @@ public class ServerSentEventsRestActionListenerTests extends ESIntegTestCase {
8082    private  static  final  String  NO_STREAM_ROUTE  = "/_inference_no_stream" ;
8183    private  static  final  Exception  expectedException  = new  IllegalStateException ("hello there" );
8284    private  static  final  String  expectedExceptionAsServerSentEvent  = """ 
83-         \uFEFF \ 
84-         event: error 
85-         data: {\ 
85+         {\ 
8686        "error":{"root_cause":[{"type":"illegal_state_exception","reason":"hello there",\ 
8787        "caused_by":{"type":"illegal_state_exception","reason":"hello there"}}],\ 
8888        "type":"illegal_state_exception","reason":"hello there"},"status":500\ 
@@ -323,30 +323,16 @@ protected void releaseResources() {}
323323    }
324324
325325    private  static  class  RandomStringCollector  {
326-         private  static  final  Pattern  jsonPattern  = Pattern .compile ("^\uFEFF event: message\n data: \\ {.*}$" );
327-         private  static  final  Pattern  endPattern  = Pattern .compile ("^\uFEFF event: message\n data: \\ [DONE\\ ]$" );
328-         private  final  AtomicBoolean  hasDoneChunk  = new  AtomicBoolean (false );
329326        private  final  Deque <String > stringsVerified  = new  LinkedBlockingDeque <>();
330-         private  volatile   String   previousTokens  = "" ;
327+         private  final   ServerSentEventParser   sseParser  = new   ServerSentEventParser () ;
331328
332329        private  void  collect (String  str ) throws  IOException  {
333-             str  = previousTokens  + str ;
334-             String [] events  = str .split ("\n \n " , -1 );
335-             for  (var  i  = 0 ; i  < events .length  - 1 ; i ++) {
336-                 var  line  = events [i ];
337-                 if  (jsonPattern .matcher (line ).matches () || expectedExceptionAsServerSentEvent .equals (line )) {
338-                     stringsVerified .offer (line );
339-                 } else  if  (endPattern .matcher (line ).matches ()) {
340-                     hasDoneChunk .set (true );
341-                 } else  {
342-                     throw  new  IOException ("Line does not match expected JSON message or DONE message. Line: "  + line );
343-                 }
344-             }
345- 
346-             previousTokens  = events [events .length  - 1 ];
347-             if  (endPattern .matcher (previousTokens .trim ()).matches ()) {
348-                 hasDoneChunk .set (true );
349-             }
330+             sseParser .parse (str .getBytes (StandardCharsets .UTF_8 ))
331+                 .stream ()
332+                 .filter (event  -> event .name () == ServerSentEventField .DATA )
333+                 .filter (ServerSentEvent ::hasValue )
334+                 .map (ServerSentEvent ::value )
335+                 .forEach (stringsVerified ::offer );
350336        }
351337    }
352338
@@ -363,8 +349,8 @@ public void testResponse() {
363349
364350        var  response  = callAsync (request );
365351        assertThat (response .getStatusLine ().getStatusCode (), is (HttpStatus .SC_OK ));
366-         assertThat (collector .stringsVerified .size (), equalTo (expectedTestCount ));
367-         assertThat (collector .hasDoneChunk . get (), equalTo (true ));
352+         assertThat (collector .stringsVerified .size (), equalTo (expectedTestCount  +  1 ));  // normal payload count + done byte 
353+         assertThat (collector .stringsVerified . peekLast (), equalTo ("[DONE]" ));
368354    }
369355
370356    private  Response  callAsync (Request  request ) {
@@ -409,10 +395,9 @@ public void testOnFailure() throws IOException {
409395        } catch  (ResponseException  e ) {
410396            var  response  = e .getResponse ();
411397            assertThat (response .getStatusLine ().getStatusCode (), is (HttpStatus .SC_INTERNAL_SERVER_ERROR ));
412-             assertThat (
413-                 EntityUtils .toString (response .getEntity (), StandardCharsets .UTF_8 ),
414-                 equalTo (expectedExceptionAsServerSentEvent  + "\n \n " )
415-             );
398+             assertThat (EntityUtils .toString (response .getEntity (), StandardCharsets .UTF_8 ), equalTo (""" 
399+                 \uFEFF event: error 
400+                 data:\s """  + expectedExceptionAsServerSentEvent  + "\n \n " ));
416401        }
417402    }
418403
@@ -431,7 +416,7 @@ public void testErrorMidStream() {
431416        var  response  = callAsync (request );
432417        assertThat (response .getStatusLine ().getStatusCode (), is (HttpStatus .SC_OK )); // error still starts with 200-OK 
433418        assertThat (collector .stringsVerified .size (), equalTo (expectedTestCount  + 1 )); // normal payload count + last error byte 
434-         assertThat ("DONE chunk is not sent on error" , collector .hasDoneChunk . get ( ), equalTo (false ));
419+         assertThat ("DONE chunk is not sent on error" , collector .stringsVerified . stream (). anyMatch ( "[DONE]" :: equals ), equalTo (false ));
435420        assertThat (collector .stringsVerified .getLast (), equalTo (expectedExceptionAsServerSentEvent ));
436421    }
437422
0 commit comments