Skip to content

Commit bb012e1

Browse files
authored
Merge pull request quarkusio#47990 from emattheis/userdata-in-connector
Add support for passing user data through client connectors in WebSockets Next
2 parents bb96f24 + 0369fba commit bb012e1

File tree

12 files changed

+159
-9
lines changed

12 files changed

+159
-9
lines changed

docs/src/main/asciidoc/websockets-next-reference.adoc

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1398,6 +1398,49 @@ class MyEndpoint {
13981398
<1> `CoolService#isCool()` returns `Boolean` that is associated with the current connection.
13991399
<2> The `TypedKey.forBoolean("isCool")` is the key used to obtain the data stored when the connection was created.
14001400

1401+
===== Specify in connector
1402+
1403+
In some scenarios you may wish to associate user data with a connection to be created by a <<client-connectors,connector>>. In this case, you can set values on the connector instance prior to obtaining the connection. This is particularly useful if you need to do something when the connection is opened and the necessary context cannot be otherwise inferred.
1404+
1405+
.Connector
1406+
[source, java]
1407+
----
1408+
@Singleton
1409+
public class MyBean {
1410+
1411+
@Inject
1412+
MyService service;
1413+
1414+
@Inject
1415+
Instance<WebSocketConnector<MyEndpoint>> connectorInstance;
1416+
1417+
public void openAndSendMessage(String internalId, String message) {
1418+
var externalId = service.getExternalId(internalId);
1419+
var connection = connectorInstance.get()
1420+
.pathParam("externalId", externalId)
1421+
.userData(TypedKey.forString("internalId"), internalId)
1422+
.connectAndAwait();
1423+
connection.sendTextAndAwait(message);
1424+
}
1425+
}
1426+
----
1427+
1428+
.Endpoint
1429+
[source,java]
1430+
----
1431+
@WebSocketClient(path = "/endpoint/{externalId}")
1432+
class MyEndpoint {
1433+
1434+
@Inject
1435+
MyService service;
1436+
1437+
@OnOpen
1438+
void open(WebSocketClientConnection connection) {
1439+
var internalId = connection.userData().get(TypedKey.forString("internalId"));
1440+
service.doSomething(internalId);
1441+
}
1442+
}
1443+
----
14011444

14021445
[[client-cdi-events]]
14031446
==== CDI events

extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/client/BasicConnectorTest.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import java.util.concurrent.CopyOnWriteArrayList;
1313
import java.util.concurrent.CountDownLatch;
1414
import java.util.concurrent.TimeUnit;
15+
import java.util.concurrent.atomic.AtomicReference;
1516

1617
import jakarta.inject.Inject;
1718

@@ -26,6 +27,8 @@
2627
import io.quarkus.websockets.next.OnOpen;
2728
import io.quarkus.websockets.next.OnTextMessage;
2829
import io.quarkus.websockets.next.PathParam;
30+
import io.quarkus.websockets.next.UserData;
31+
import io.quarkus.websockets.next.UserData.TypedKey;
2932
import io.quarkus.websockets.next.WebSocket;
3033
import io.quarkus.websockets.next.WebSocketClientConnection;
3134
import io.vertx.core.Context;
@@ -63,13 +66,23 @@ void testClient() throws InterruptedException {
6366
assertThrows(NullPointerException.class, () -> connector.onPong(null));
6467
assertThrows(NullPointerException.class, () -> connector.onError(null));
6568

69+
CountDownLatch openLatch = new CountDownLatch(1);
70+
AtomicReference<UserData> userData = new AtomicReference<>();
6671
CountDownLatch messageLatch = new CountDownLatch(2);
6772
List<String> messages = new CopyOnWriteArrayList<>();
6873
CountDownLatch closedLatch = new CountDownLatch(1);
6974
WebSocketClientConnection connection1 = connector
7075
.baseUri(uri)
7176
.path("/{name}")
7277
.pathParam("name", "Lu")
78+
.userData(TypedKey.forBoolean("boolean"), true)
79+
.userData(TypedKey.forInt("int"), Integer.MAX_VALUE)
80+
.userData(TypedKey.forLong("long"), Long.MAX_VALUE)
81+
.userData(TypedKey.forString("string"), "Lu")
82+
.onOpen(c -> {
83+
userData.set(c.userData());
84+
openLatch.countDown();
85+
})
7386
.onTextMessage((c, m) -> {
7487
assertTrue(Context.isOnWorkerThread());
7588
String name = c.pathParam("name");
@@ -79,8 +92,19 @@ void testClient() throws InterruptedException {
7992
.onClose((c, s) -> closedLatch.countDown())
8093
.connectAndAwait();
8194
assertEquals("Lu", connection1.pathParam("name"));
95+
assertTrue(connection1.userData().get(TypedKey.forBoolean("boolean")));
96+
assertEquals(Integer.MAX_VALUE, connection1.userData().get(TypedKey.forInt("int")));
97+
assertEquals(Long.MAX_VALUE, connection1.userData().get(TypedKey.forLong("long")));
98+
assertEquals("Lu", connection1.userData().get(TypedKey.forString("string")));
8299
connection1.sendTextAndAwait("Hi!");
83100

101+
assertTrue(openLatch.await(5, TimeUnit.SECONDS));
102+
assertNotNull(userData.get());
103+
assertTrue(userData.get().get(TypedKey.forBoolean("boolean")));
104+
assertEquals(Integer.MAX_VALUE, userData.get().get(TypedKey.forInt("int")));
105+
assertEquals(Long.MAX_VALUE, userData.get().get(TypedKey.forLong("long")));
106+
assertEquals("Lu", userData.get().get(TypedKey.forString("string")));
107+
84108
assertTrue(messageLatch.await(5, TimeUnit.SECONDS));
85109
// Note that ordering is not guaranteed
86110
assertThat(messages.get(0)).isIn("Lu:Hello Lu!", "Lu:Hi!");

extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/client/programmatic/ClientEndpointProgrammaticTest.java

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
package io.quarkus.websockets.next.test.client.programmatic;
22

3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
import static org.junit.jupiter.api.Assertions.assertFalse;
35
import static org.junit.jupiter.api.Assertions.assertTrue;
46

57
import java.net.URI;
68
import java.util.List;
9+
import java.util.Map;
10+
import java.util.concurrent.ConcurrentHashMap;
711
import java.util.concurrent.CopyOnWriteArrayList;
812
import java.util.concurrent.CountDownLatch;
913
import java.util.concurrent.TimeUnit;
@@ -20,6 +24,8 @@
2024
import io.quarkus.websockets.next.OnClose;
2125
import io.quarkus.websockets.next.OnOpen;
2226
import io.quarkus.websockets.next.OnTextMessage;
27+
import io.quarkus.websockets.next.UserData;
28+
import io.quarkus.websockets.next.UserData.TypedKey;
2329
import io.quarkus.websockets.next.WebSocket;
2430
import io.quarkus.websockets.next.WebSocketClient;
2531
import io.quarkus.websockets.next.WebSocketClientConnection;
@@ -45,16 +51,43 @@ void testClient() throws InterruptedException {
4551
.get()
4652
.baseUri(uri)
4753
.addHeader("Foo", "Lu")
54+
.userData(TypedKey.forBoolean("boolean"), true)
55+
.userData(TypedKey.forInt("int"), Integer.MAX_VALUE)
56+
.userData(TypedKey.forLong("long"), Long.MAX_VALUE)
57+
.userData(TypedKey.forString("string"), "Lu")
4858
.connectAndAwait();
59+
assertTrue(connection1.userData().get(TypedKey.forBoolean("boolean")));
60+
assertEquals(Integer.MAX_VALUE, connection1.userData().get(TypedKey.forInt("int")));
61+
assertEquals(Long.MAX_VALUE, connection1.userData().get(TypedKey.forLong("long")));
62+
assertEquals("Lu", connection1.userData().get(TypedKey.forString("string")));
4963
connection1.sendTextAndAwait("Hi!");
5064

5165
WebSocketClientConnection connection2 = connector
5266
.get()
5367
.baseUri(uri)
5468
.addHeader("Foo", "Ma")
69+
.userData(TypedKey.forBoolean("boolean"), false)
70+
.userData(TypedKey.forInt("int"), Integer.MIN_VALUE)
71+
.userData(TypedKey.forLong("long"), Long.MIN_VALUE)
72+
.userData(TypedKey.forString("string"), "Ma")
5573
.connectAndAwait();
74+
assertFalse(connection2.userData().get(TypedKey.forBoolean("boolean")));
75+
assertEquals(Integer.MIN_VALUE, connection2.userData().get(TypedKey.forInt("int")));
76+
assertEquals(Long.MIN_VALUE, connection2.userData().get(TypedKey.forLong("long")));
77+
assertEquals("Ma", connection2.userData().get(TypedKey.forString("string")));
5678
connection2.sendTextAndAwait("Hi!");
5779

80+
assertTrue(ClientEndpoint.OPEN_LATCH.await(5, TimeUnit.SECONDS));
81+
assertTrue(ClientEndpoint.CONNECTION_USER_DATA.containsKey(connection1.id()));
82+
assertTrue(ClientEndpoint.CONNECTION_USER_DATA.get(connection1.id()).get(TypedKey.forBoolean("boolean")));
83+
assertEquals(Integer.MAX_VALUE, ClientEndpoint.CONNECTION_USER_DATA.get(connection1.id()).get(TypedKey.forInt("int")));
84+
assertEquals(Long.MAX_VALUE, ClientEndpoint.CONNECTION_USER_DATA.get(connection1.id()).get(TypedKey.forLong("long")));
85+
assertEquals("Lu", ClientEndpoint.CONNECTION_USER_DATA.get(connection1.id()).get(TypedKey.forString("string")));
86+
assertFalse(ClientEndpoint.CONNECTION_USER_DATA.get(connection2.id()).get(TypedKey.forBoolean("boolean")));
87+
assertEquals(Integer.MIN_VALUE, ClientEndpoint.CONNECTION_USER_DATA.get(connection2.id()).get(TypedKey.forInt("int")));
88+
assertEquals(Long.MIN_VALUE, ClientEndpoint.CONNECTION_USER_DATA.get(connection2.id()).get(TypedKey.forLong("long")));
89+
assertEquals("Ma", ClientEndpoint.CONNECTION_USER_DATA.get(connection2.id()).get(TypedKey.forString("string")));
90+
5891
assertTrue(ClientEndpoint.MESSAGE_LATCH.await(5, TimeUnit.SECONDS));
5992
assertTrue(ClientEndpoint.MESSAGES.contains("Lu:Hello Lu!"));
6093
assertTrue(ClientEndpoint.MESSAGES.contains("Lu:Hi!"));
@@ -92,12 +125,22 @@ void close() {
92125
@WebSocketClient(path = "/endpoint")
93126
public static class ClientEndpoint {
94127

128+
static final CountDownLatch OPEN_LATCH = new CountDownLatch(2);
129+
130+
static final Map<String, UserData> CONNECTION_USER_DATA = new ConcurrentHashMap<>();
131+
95132
static final CountDownLatch MESSAGE_LATCH = new CountDownLatch(4);
96133

97134
static final List<String> MESSAGES = new CopyOnWriteArrayList<>();
98135

99136
static final CountDownLatch CLOSED_LATCH = new CountDownLatch(2);
100137

138+
@OnOpen
139+
void onOpen(WebSocketClientConnection connection) {
140+
CONNECTION_USER_DATA.put(connection.id(), connection.userData());
141+
OPEN_LATCH.countDown();
142+
}
143+
101144
@OnTextMessage
102145
void onMessage(String message, HandshakeRequest handshakeRequest) {
103146
MESSAGES.add(handshakeRequest.header("Foo") + ":" + message);

extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/BasicWebSocketConnector.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import jakarta.enterprise.inject.Instance;
1010

1111
import io.quarkus.arc.Arc;
12+
import io.quarkus.websockets.next.UserData.TypedKey;
1213
import io.smallrye.common.annotation.CheckReturnValue;
1314
import io.smallrye.mutiny.Uni;
1415
import io.vertx.core.buffer.Buffer;
@@ -114,6 +115,18 @@ default BasicWebSocketConnector baseUri(String baseUri) {
114115
*/
115116
BasicWebSocketConnector addSubprotocol(String value);
116117

118+
/**
119+
* Add a value to the connection user data.
120+
*
121+
* @param key
122+
* @param value
123+
* @param <VALUE>
124+
* @return self
125+
* @see UserData#put(TypedKey, Object)
126+
* @see WebSocketClientConnection#userData()
127+
*/
128+
<VALUE> BasicWebSocketConnector userData(TypedKey<VALUE> key, VALUE value);
129+
117130
/**
118131
* Set the execution model for callback handlers.
119132
* <p>

extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/WebSocketConnector.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import jakarta.enterprise.inject.Default;
77
import jakarta.enterprise.inject.Instance;
88

9+
import io.quarkus.websockets.next.UserData.TypedKey;
910
import io.smallrye.common.annotation.CheckReturnValue;
1011
import io.smallrye.mutiny.Uni;
1112

@@ -94,6 +95,18 @@ default WebSocketConnector<CLIENT> baseUri(String baseUri) {
9495
*/
9596
WebSocketConnector<CLIENT> addSubprotocol(String value);
9697

98+
/**
99+
* Add a value to the connection user data.
100+
*
101+
* @param key
102+
* @param value
103+
* @param <VALUE>
104+
* @return self
105+
* @see UserData#put(TypedKey, Object)
106+
* @see WebSocketClientConnection#userData()
107+
*/
108+
<VALUE> WebSocketConnector<CLIENT> userData(TypedKey<VALUE> key, VALUE value);
109+
97110
/**
98111
*
99112
* @return a new {@link Uni} with a {@link WebSocketClientConnection} item

extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/BasicWebSocketConnectorImpl.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ public void handle(AsyncResult<WebSocket> r) {
186186
codecs,
187187
pathParams,
188188
serverEndpointUri,
189-
headers, trafficLogger, null);
189+
headers, trafficLogger, userData, null);
190190
if (trafficLogger != null) {
191191
trafficLogger.connectionOpened(connection);
192192
}

extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/UserDataImpl.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package io.quarkus.websockets.next.runtime;
22

3+
import java.util.Map;
34
import java.util.concurrent.ConcurrentHashMap;
45
import java.util.concurrent.ConcurrentMap;
56

@@ -13,6 +14,10 @@ final class UserDataImpl implements UserData {
1314
this.data = new ConcurrentHashMap<>();
1415
}
1516

17+
UserDataImpl(Map<String, Object> data) {
18+
this.data = new ConcurrentHashMap<>(data);
19+
}
20+
1621
@SuppressWarnings("unchecked")
1722
@Override
1823
public <VALUE> VALUE get(TypedKey<VALUE> key) {

extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketClientConnectionImpl.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ class WebSocketClientConnectionImpl extends WebSocketConnectionBase implements W
2121

2222
WebSocketClientConnectionImpl(String clientId, WebSocket webSocket, Codecs codecs,
2323
Map<String, String> pathParams, URI serverEndpointUri, Map<String, List<String>> headers,
24-
TrafficLogger trafficLogger, SendingInterceptor sendingInterceptor) {
24+
TrafficLogger trafficLogger, Map<String, Object> userData, SendingInterceptor sendingInterceptor) {
2525
super(Map.copyOf(pathParams), codecs, new ClientHandshakeRequestImpl(serverEndpointUri, headers), trafficLogger,
26-
sendingInterceptor);
26+
new UserDataImpl(userData), sendingInterceptor);
2727
this.clientId = clientId;
2828
this.webSocket = Objects.requireNonNull(webSocket);
2929
}

extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectionBase.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,19 +36,19 @@ public abstract class WebSocketConnectionBase implements Connection {
3636

3737
protected final TrafficLogger trafficLogger;
3838

39-
private final UserData data;
39+
private final UserData userData;
4040

4141
private final SendingInterceptor sendingInterceptor;
4242

4343
WebSocketConnectionBase(Map<String, String> pathParams, Codecs codecs, HandshakeRequest handshakeRequest,
44-
TrafficLogger trafficLogger, SendingInterceptor sendingInterceptor) {
44+
TrafficLogger trafficLogger, UserData userData, SendingInterceptor sendingInterceptor) {
4545
this.identifier = UUID.randomUUID().toString();
4646
this.pathParams = pathParams;
4747
this.codecs = codecs;
4848
this.handshakeRequest = handshakeRequest;
4949
this.creationTime = Instant.now();
5050
this.trafficLogger = trafficLogger;
51-
this.data = new UserDataImpl();
51+
this.userData = userData;
5252
this.sendingInterceptor = sendingInterceptor;
5353
}
5454

@@ -172,7 +172,7 @@ public CloseReason closeReason() {
172172

173173
@Override
174174
public UserData userData() {
175-
return data;
175+
return userData;
176176
}
177177

178178
}

extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectionImpl.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class WebSocketConnectionImpl extends WebSocketConnectionBase implements WebSock
3636
ConnectionManager connectionManager, Codecs codecs, RoutingContext ctx,
3737
TrafficLogger trafficLogger, SendingInterceptor sendingInterceptor) {
3838
super(Map.copyOf(ctx.pathParams()), codecs, new HandshakeRequestImpl(webSocket, ctx), trafficLogger,
39-
sendingInterceptor);
39+
new UserDataImpl(), sendingInterceptor);
4040
this.generatedEndpointClass = generatedEndpointClass;
4141
this.endpointId = endpointClass;
4242
this.webSocket = Objects.requireNonNull(webSocket);

0 commit comments

Comments
 (0)