1616import org .springframework .web .util .UriBuilderFactory ;
1717
1818import javax .websocket .ClientEndpointConfig ;
19+ import javax .websocket .CloseReason ;
1920import javax .websocket .ContainerProvider ;
2021import javax .websocket .Endpoint ;
2122import javax .websocket .EndpointConfig ;
3334import java .util .Map ;
3435import java .util .Optional ;
3536import java .util .Queue ;
36- import java .util .concurrent .ConcurrentLinkedQueue ;
3737import java .util .concurrent .atomic .AtomicInteger ;
38+ import java .util .function .Predicate ;
3839
3940import static org .assertj .core .api .Assertions .assertThat ;
4041import static org .junit .jupiter .api .Assertions .fail ;
4647@ Slf4j
4748public class GraphQLTestSubscription {
4849
50+ private static final WebSocketContainer WEB_SOCKET_CONTAINER = ContainerProvider .getWebSocketContainer ();
4951 private static final int SLEEP_INTERVAL_MS = 100 ;
50- private static final int ACKNOWLEDGEMENT_AND_CONNECTION_TIMEOUT = 6000000 ;
52+ private static final int ACKNOWLEDGEMENT_AND_CONNECTION_TIMEOUT = 60000 ;
5153 private static final AtomicInteger ID_COUNTER = new AtomicInteger (1 );
5254 private static final UriBuilderFactory URI_BUILDER_FACTORY = new DefaultUriBuilderFactory ();
55+ private static final Object STATE_LOCK = new Object ();
5356
5457 @ Getter
5558 private Session session ;
56-
57- @ Getter
58- private boolean initialized = false ;
59- @ Getter
60- private boolean acknowledged = false ;
61- @ Getter
62- private boolean started = false ;
63- @ Getter
64- private boolean stopped = false ;
59+ private SubscriptionState state = SubscriptionState .builder ()
60+ .id (ID_COUNTER .incrementAndGet ())
61+ .build ();
6562
6663 private final Environment environment ;
6764 private final ObjectMapper objectMapper ;
6865 private final String subscriptionPath ;
6966
70- private final Queue <GraphQLResponse > responses = new ConcurrentLinkedQueue <>();
71- private int id = ID_COUNTER .getAndIncrement ();
67+ public boolean isInitialized () {
68+ return state .isInitialized ();
69+ }
70+
71+ public boolean isAcknowledged () {
72+ return state .isAcknowledged ();
73+ }
74+
75+ public boolean isStarted () {
76+ return state .isStarted ();
77+ }
78+
79+ public boolean isStopped () {
80+ return state .isStopped ();
81+ }
82+
83+ public boolean isCompleted () {
84+ return state .isCompleted ();
85+ }
7286
7387 /**
7488 * Sends the "connection_init" message to the GraphQL server without a payload.
@@ -85,7 +99,7 @@ public GraphQLTestSubscription init() {
8599 * @return self reference
86100 */
87101 public GraphQLTestSubscription init (@ Nullable final Object payload ) {
88- if (initialized ) {
102+ if (isInitialized () ) {
89103 fail ("Subscription already initialized." );
90104 }
91105 try {
@@ -97,8 +111,9 @@ public GraphQLTestSubscription init(@Nullable final Object payload) {
97111 message .put ("type" , "connection_init" );
98112 message .set ("payload" , getFinalPayload (payload ));
99113 sendMessage (message );
100- initialized = true ;
114+ state . setInitialized ( true ) ;
101115 awaitAcknowledgement ();
116+ log .debug ("Subscription successfully initialized." );
102117 return this ;
103118 }
104119
@@ -120,20 +135,21 @@ public GraphQLTestSubscription start(@NonNull final String graphQLResource) {
120135 * @return self reference
121136 */
122137 public GraphQLTestSubscription start (@ NonNull final String graphGLResource , @ Nullable final Object variables ) {
123- if (!initialized ) {
138+ if (!isInitialized () ) {
124139 init ();
125140 }
126- if (started ) {
141+ if (isStarted () ) {
127142 fail ("Start message already sent. To start a new subscription, please call reset first." );
128143 }
129- started = true ;
144+ state . setStarted ( true ) ;
130145 ObjectNode payload = objectMapper .createObjectNode ();
131146 payload .put ("query" , loadQuery (graphGLResource ));
132147 payload .set ("variables" , getFinalPayload (variables ));
133148 ObjectNode message = objectMapper .createObjectNode ();
134149 message .put ("type" , "start" );
135- message .put ("id" , id );
150+ message .put ("id" , state . getId () );
136151 message .set ("payload" , payload );
152+ log .debug ("Sending start message." );
137153 sendMessage (message );
138154 return this ;
139155 }
@@ -143,24 +159,25 @@ public GraphQLTestSubscription start(@NonNull final String graphGLResource, @Nul
143159 * @return self reference
144160 */
145161 public GraphQLTestSubscription stop () {
146- if (!initialized ) {
162+ if (!isInitialized () ) {
147163 fail ("Subscription not yet initialized." );
148164 }
149- if (stopped ) {
165+ if (isStopped () ) {
150166 fail ("Subscription already stopped." );
151167 }
152168 final ObjectNode message = objectMapper .createObjectNode ();
153169 message .put ("type" , "stop" );
154- message .put ("id" , id );
170+ message .put ("id" , state .getId ());
171+ log .debug ("Sending stop message." );
155172 sendMessage (message );
156- stopped = true ;
157173 try {
174+ log .debug ("Closing web socket session." );
158175 session .close ();
159- session = null ;
176+ awaitStop ();
177+ log .debug ("Web socket session closed." );
160178 } catch (IOException e ) {
161179 fail ("Could not close web socket session" , e );
162180 }
163- log .debug ("Subscription stopped." );
164181 return this ;
165182 }
166183
@@ -169,20 +186,12 @@ public GraphQLTestSubscription stop() {
169186 * ensure that the bean is reusable between tests.
170187 */
171188 public void reset () {
172- if (initialized && !stopped ) {
189+ if (isInitialized () && !isStopped () ) {
173190 stop ();
174191 }
175- if (stopped ) {
176- id = ID_COUNTER .getAndIncrement ();
177- }
178- initialized = false ;
179- started = false ;
180- stopped = false ;
181- acknowledged = false ;
192+ state = SubscriptionState .builder ().id (ID_COUNTER .incrementAndGet ()).build ();
182193 session = null ;
183- synchronized (responses ) {
184- responses .clear ();
185- }
194+ log .debug ("Test subscription client reset." );
186195 }
187196
188197 /**
@@ -264,15 +273,15 @@ public List<GraphQLResponse> awaitAndGetNextResponses(
264273 final int numExpectedResponses ,
265274 final boolean stopAfter
266275 ) {
267- if (!started ) {
276+ if (!isStarted () ) {
268277 fail ("Start message not sent. Please send start message first." );
269278 }
270- if (stopped ) {
279+ if (isStopped () ) {
271280 fail ("Subscription already stopped. Forgot to call reset after test case?" );
272281 }
273282 int elapsedTime = 0 ;
274283 while (
275- ((responses .size () < numExpectedResponses ) || numExpectedResponses <= 0 )
284+ ((state . getResponses () .size () < numExpectedResponses ) || numExpectedResponses <= 0 )
276285 && elapsedTime < timeout
277286 ) {
278287 try {
@@ -282,10 +291,11 @@ public List<GraphQLResponse> awaitAndGetNextResponses(
282291 fail ("Test execution error - Thread.sleep failed." , e );
283292 }
284293 }
285- synchronized (responses ) {
286- if (stopAfter ) {
287- stop ();
288- }
294+ if (stopAfter ) {
295+ stop ();
296+ }
297+ synchronized (STATE_LOCK ) {
298+ final Queue <GraphQLResponse > responses = state .getResponses ();
289299 int responsesToPoll = responses .size ();
290300 if (numExpectedResponses == 0 ) {
291301 assertThat (responses )
@@ -336,16 +346,15 @@ public GraphQLTestSubscription waitAndExpectNoResponse(final int timeToWait) {
336346 * @return the remaining responses.
337347 */
338348 public List <GraphQLResponse > getRemainingResponses () {
339- if (!stopped ) {
349+ if (!isStopped () ) {
340350 fail ("getRemainingResponses should only be called after the subscription was stopped." );
341351 }
342- final ArrayList <GraphQLResponse > graphQLResponses = new ArrayList <>(responses );
343- responses .clear ();
352+ final ArrayList <GraphQLResponse > graphQLResponses = new ArrayList <>(state . getResponses () );
353+ state . getResponses () .clear ();
344354 return graphQLResponses ;
345355 }
346356
347357 private void initClient () throws Exception {
348- final WebSocketContainer webSocketContainer = ContainerProvider .getWebSocketContainer ();
349358 final String port = environment .getProperty ("local.server.port" );
350359 final URI uri = URI_BUILDER_FACTORY .builder ().scheme ("ws" ).host ("localhost" ).port (port ).path (subscriptionPath )
351360 .build ();
@@ -355,8 +364,8 @@ private void initClient() throws Exception {
355364 .build ();
356365 clientEndpointConfig .getUserProperties ().put ("org.apache.tomcat.websocket.IO_TIMEOUT_MS" ,
357366 String .valueOf (ACKNOWLEDGEMENT_AND_CONNECTION_TIMEOUT ));
358- session = webSocketContainer .connectToServer (TestWebSocketClient . class , clientEndpointConfig , uri );
359- session .addMessageHandler (new TestMessageHandler ());
367+ session = WEB_SOCKET_CONTAINER .connectToServer (new TestWebSocketClient ( state ) , clientEndpointConfig , uri );
368+ session .addMessageHandler (new TestMessageHandler (objectMapper , state ));
360369 }
361370
362371 private JsonNode getFinalPayload (final Object variables ) {
@@ -384,8 +393,16 @@ private void sendMessage(final Object message) {
384393 }
385394
386395 private void awaitAcknowledgement () {
396+ await (GraphQLTestSubscription ::isAcknowledged , "Connection was not acknowledged by the GraphQL server." );
397+ }
398+
399+ private void awaitStop () {
400+ await (GraphQLTestSubscription ::isStopped , "Connection was not stopped in time." );
401+ }
402+
403+ private void await (final Predicate <GraphQLTestSubscription > condition , final String timeoutDescription ) {
387404 int elapsedTime = 0 ;
388- while (!acknowledged && elapsedTime < ACKNOWLEDGEMENT_AND_CONNECTION_TIMEOUT ) {
405+ while (!condition . test ( this ) && elapsedTime < ACKNOWLEDGEMENT_AND_CONNECTION_TIMEOUT ) {
389406 try {
390407 Thread .sleep (SLEEP_INTERVAL_MS );
391408 elapsedTime += SLEEP_INTERVAL_MS ;
@@ -394,31 +411,45 @@ private void awaitAcknowledgement() {
394411 }
395412 }
396413
397- if (!acknowledged ) {
398- fail ("Timeout: Connection was not acknowledged by the GraphQL server." );
414+ if (!condition . test ( this ) ) {
415+ fail (String . format ( "Timeout: " + timeoutDescription ) );
399416 }
400417 }
401418
402- class TestMessageHandler implements MessageHandler .Whole <String > {
419+ @ RequiredArgsConstructor
420+ static class TestMessageHandler implements MessageHandler .Whole <String > {
421+
422+ private final ObjectMapper objectMapper ;
423+ private final SubscriptionState state ;
424+
403425 @ Override
404426 public void onMessage (final String message ) {
405427 try {
406428 log .debug ("Received message from web socket: {}" , message );
407429 final JsonNode jsonNode = objectMapper .readTree (message );
408430 final JsonNode typeNode = jsonNode .get ("type" );
409- assertThat (typeNode .isNull ()).as ("GraphQL messages should have a type field." ).isFalse ();
431+ assertThat (typeNode ).as ("GraphQL messages should have a type field." ).isNotNull ();
432+ assertThat (typeNode .isNull ()).as ("GraphQL messages type should not be null." ).isFalse ();
410433 final String type = typeNode .asText ();
411- if (type .equals ("connection_ack" )) {
412- acknowledged = true ;
434+ if (type .equals ("complete" )) {
435+ state .setCompleted (true );
436+ log .debug ("Subscription completed." );
437+ } else if (type .equals ("connection_ack" )) {
438+ state .setAcknowledged (true );
413439 log .debug ("WebSocket connection acknowledged by the GraphQL Server." );
414440 } else if (type .equals ("data" ) || type .equals ("error" )) {
415441 final JsonNode payload = jsonNode .get ("payload" );
416442 assertThat (payload ).as ("Data/error messages must have a payload." ).isNotNull ();
417443 final String payloadString = objectMapper .writeValueAsString (payload );
418444 final GraphQLResponse graphQLResponse = new GraphQLResponse (ResponseEntity .ok (payloadString ),
419445 objectMapper );
420- synchronized (responses ) {
421- responses .add (graphQLResponse );
446+ if (state .isStopped () || state .isCompleted ()) {
447+ log .debug ("Response discarded because subscription was stopped or completed in the meanwhile." );
448+ } else {
449+ synchronized (STATE_LOCK ) {
450+ state .getResponses ().add (graphQLResponse );
451+ }
452+ log .debug ("New response recorded." );
422453 }
423454 }
424455 } catch (JsonProcessingException e ) {
@@ -427,11 +458,21 @@ public void onMessage(final String message) {
427458 }
428459 }
429460
430- public static class TestWebSocketClient extends Endpoint {
461+ @ RequiredArgsConstructor
462+ private static class TestWebSocketClient extends Endpoint {
463+
464+ private final SubscriptionState state ;
465+
431466 @ Override
432467 public void onOpen (final Session session , final EndpointConfig config ) {
433468 log .debug ("Connection established." );
434469 }
470+
471+ @ Override
472+ public void onClose (Session session , CloseReason closeReason ) {
473+ super .onClose (session , closeReason );
474+ state .setStopped (true );
475+ }
435476 }
436477
437478 static class TestWebSocketClientConfigurator extends ClientEndpointConfig .Configurator {
0 commit comments