Skip to content

Commit 04a704c

Browse files
[Java] Observe accessTokenProvider on error (#24344)
1 parent 9b5999f commit 04a704c

File tree

4 files changed

+190
-31
lines changed

4 files changed

+190
-31
lines changed

src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HubConnection.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,8 @@ public Completable start() {
358358
this.localHeaders.put("Authorization", "Bearer " + token);
359359
}
360360
tokenCompletable.onComplete();
361+
}, error -> {
362+
tokenCompletable.onError(error);
361363
});
362364

363365
stopError = null;

src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/LongPollingTransport.java

Lines changed: 40 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@ class LongPollingTransport implements Transport {
2525
private final HttpClient pollingClient;
2626
private final Map<String, String> headers;
2727
private static final int POLL_TIMEOUT = 100*1000;
28+
private final Single<String> accessTokenProvider;
2829
private volatile Boolean active = false;
2930
private String pollUrl;
3031
private String closeError;
31-
private Single<String> accessTokenProvider;
3232
private CompletableSubject receiveLoop = CompletableSubject.create();
3333
private ExecutorService threadPool;
3434
private ExecutorService onReceiveThread;
@@ -41,21 +41,19 @@ public LongPollingTransport(Map<String, String> headers, HttpClient client, Sing
4141
this.client = client;
4242
this.pollingClient = client.cloneWithTimeOut(POLL_TIMEOUT);
4343
this.accessTokenProvider = accessTokenProvider;
44-
this.onReceiveThread = Executors.newSingleThreadExecutor();
4544
}
4645

4746
//Package private active accessor for testing.
4847
boolean isActive() {
4948
return this.active;
5049
}
5150

52-
private Single updateHeaderToken() {
53-
return this.accessTokenProvider.flatMap((token) -> {
51+
private Completable updateHeaderToken() {
52+
return this.accessTokenProvider.doOnSuccess((token) -> {
5453
if (!token.isEmpty()) {
5554
this.headers.put("Authorization", "Bearer " + token);
5655
}
57-
return Single.just("");
58-
});
56+
}).ignoreElement();
5957
}
6058

6159
@Override
@@ -65,7 +63,7 @@ public Completable start(String url) {
6563
this.url = url;
6664
pollUrl = url + "&_=" + System.currentTimeMillis();
6765
logger.debug("Polling {}.", pollUrl);
68-
return this.updateHeaderToken().flatMapCompletable((r) -> {
66+
return this.updateHeaderToken().andThen(Completable.defer(() -> {
6967
HttpRequest request = new HttpRequest();
7068
request.addHeaders(headers);
7169
return this.pollingClient.get(pollUrl, request).flatMapCompletable(response -> {
@@ -77,18 +75,26 @@ public Completable start(String url) {
7775
this.active = true;
7876
}
7977
this.threadPool = Executors.newCachedThreadPool();
80-
threadPool.execute(() -> poll(url).subscribeWith(receiveLoop));
78+
threadPool.execute(() -> {
79+
this.onReceiveThread = Executors.newSingleThreadExecutor();
80+
receiveLoop.subscribe(() -> {
81+
this.stop().onErrorComplete().subscribe();
82+
}, e -> {
83+
this.stop().onErrorComplete().subscribe();
84+
});
85+
poll(url).subscribeWith(receiveLoop);
86+
});
8187

8288
return Completable.complete();
8389
});
84-
});
90+
}));
8591
}
8692

8793
private Completable poll(String url) {
8894
if (this.active) {
8995
pollUrl = url + "&_=" + System.currentTimeMillis();
9096
logger.debug("Polling {}.", pollUrl);
91-
return this.updateHeaderToken().flatMapCompletable((x) -> {
97+
return this.updateHeaderToken().andThen(Completable.defer(() -> {
9298
HttpRequest request = new HttpRequest();
9399
request.addHeaders(headers);
94100
Completable pollingCompletable = this.pollingClient.get(pollUrl, request).flatMapCompletable(response -> {
@@ -111,13 +117,10 @@ private Completable poll(String url) {
111117
});
112118

113119
return pollingCompletable;
114-
});
120+
}));
115121
} else {
116122
logger.debug("Long Polling transport polling complete.");
117123
receiveLoop.onComplete();
118-
if (!stopCalled.get()) {
119-
return this.stop();
120-
}
121124
return Completable.complete();
122125
}
123126
}
@@ -127,11 +130,11 @@ public Completable send(ByteBuffer message) {
127130
if (!this.active) {
128131
return Completable.error(new Exception("Cannot send unless the transport is active."));
129132
}
130-
return this.updateHeaderToken().flatMapCompletable((x) -> {
133+
return this.updateHeaderToken().andThen(Completable.defer(() -> {
131134
HttpRequest request = new HttpRequest();
132135
request.addHeaders(headers);
133-
return Completable.fromSingle(this.client.post(url, message, request));
134-
});
136+
return this.client.post(url, message, request).ignoreElement();
137+
}));
135138
}
136139

137140
@Override
@@ -152,23 +155,31 @@ public void setOnClose(TransportOnClosedCallback onCloseCallback) {
152155

153156
@Override
154157
public Completable stop() {
155-
if (!stopCalled.get()) {
156-
this.stopCalled.set(true);
158+
if (stopCalled.compareAndSet(false, true)) {
157159
this.active = false;
158-
return this.updateHeaderToken().flatMapCompletable((x) -> {
160+
return this.updateHeaderToken().andThen(Completable.defer(() -> {
159161
HttpRequest request = new HttpRequest();
160162
request.addHeaders(headers);
161-
this.pollingClient.delete(this.url, request);
162-
CompletableSubject stopCompletableSubject = CompletableSubject.create();
163-
return this.receiveLoop.andThen(Completable.defer(() -> {
164-
logger.info("LongPolling transport stopped.");
165-
this.onReceiveThread.shutdown();
166-
this.threadPool.shutdown();
167-
this.onClose.invoke(this.closeError);
168-
return Completable.complete();
169-
})).subscribeWith(stopCompletableSubject);
163+
return this.pollingClient.delete(this.url, request).ignoreElement()
164+
.andThen(receiveLoop)
165+
.doOnComplete(() -> {
166+
cleanup(this.closeError);
167+
});
168+
})).doOnError(e -> {
169+
cleanup(e.getMessage());
170170
});
171171
}
172172
return Completable.complete();
173173
}
174+
175+
private void cleanup(String error) {
176+
logger.info("LongPolling transport stopped.");
177+
if (this.onReceiveThread != null) {
178+
this.onReceiveThread.shutdown();
179+
}
180+
if (this.threadPool != null) {
181+
this.threadPool.shutdown();
182+
}
183+
this.onClose.invoke(error);
184+
}
174185
}

src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/HubConnectionTest.java

Lines changed: 146 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2913,7 +2913,8 @@ public void TransportAllUsesLongPollingWhenServerOnlySupportLongPolling() {
29132913
}
29142914
assertTrue(close.blockingAwait(5, TimeUnit.SECONDS));
29152915
return Single.just(new HttpResponse(204, "", TestUtils.emptyByteBuffer));
2916-
});
2916+
})
2917+
.on("DELETE", (req) -> Single.just(new HttpResponse(200, "", TestUtils.stringToByteBuffer(""))));
29172918

29182919
HubConnection hubConnection = HubConnectionBuilder
29192920
.create("http://example.com")
@@ -2969,6 +2970,135 @@ public void ClientThatSelectsLongPollingThrowsWhenLongPollingIsNotAvailable() {
29692970
assertEquals(exception.getMessage(), "There were no compatible transports on the server.");
29702971
}
29712972

2973+
@Test
2974+
public void LongPollingTransportAccessTokenProviderThrowsOnInitialPoll() {
2975+
TestHttpClient client = new TestHttpClient()
2976+
.on("POST", (req) -> {
2977+
return Single.just(new HttpResponse(200, "", TestUtils.stringToByteBuffer("")));
2978+
})
2979+
.on("POST", "http://example.com/negotiate?negotiateVersion=1",
2980+
(req) -> Single.just(new HttpResponse(200, "",
2981+
TestUtils.stringToByteBuffer("{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\""
2982+
+ "availableTransports\":[{\"transport\":\"LongPolling\",\"transferFormats\":[\"Text\",\"Binary\"]}]}"))))
2983+
.on("GET", (req) -> {
2984+
return Single.just(new HttpResponse(200, "", TestUtils.stringToByteBuffer("{}" + RECORD_SEPARATOR)));
2985+
});
2986+
2987+
AtomicInteger accessTokenCount = new AtomicInteger(0);
2988+
HubConnection hubConnection = HubConnectionBuilder
2989+
.create("http://example.com")
2990+
.withTransport(TransportEnum.LONG_POLLING)
2991+
.withHttpClient(client)
2992+
.withAccessTokenProvider(Single.defer(() -> {
2993+
if (accessTokenCount.getAndIncrement() < 1) {
2994+
return Single.just("");
2995+
}
2996+
return Single.error(new RuntimeException("Error from accessTokenProvider"));
2997+
}))
2998+
.build();
2999+
3000+
try {
3001+
hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait();
3002+
assertTrue(false);
3003+
} catch (RuntimeException ex) {
3004+
assertEquals("Error from accessTokenProvider", ex.getMessage());
3005+
}
3006+
}
3007+
3008+
@Test
3009+
public void LongPollingTransportAccessTokenProviderThrowsAfterHandshakeClosesConnection() {
3010+
AtomicInteger requestCount = new AtomicInteger(0);
3011+
CompletableSubject blockGet = CompletableSubject.create();
3012+
TestHttpClient client = new TestHttpClient()
3013+
.on("POST", "http://example.com/negotiate?negotiateVersion=1",
3014+
(req) -> Single.just(new HttpResponse(200, "",
3015+
TestUtils.stringToByteBuffer("{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\""
3016+
+ "availableTransports\":[{\"transport\":\"LongPolling\",\"transferFormats\":[\"Text\",\"Binary\"]}]}"))))
3017+
.on("GET", (req) -> {
3018+
if (requestCount.getAndIncrement() > 1) {
3019+
blockGet.blockingAwait();
3020+
}
3021+
return Single.just(new HttpResponse(200, "", TestUtils.stringToByteBuffer("{}" + RECORD_SEPARATOR)));
3022+
})
3023+
.on("POST", "http://example.com?id=bVOiRPG8-6YiJ6d7ZcTOVQ", (req) -> {
3024+
return Single.just(new HttpResponse(200, "", TestUtils.stringToByteBuffer("")));
3025+
});
3026+
3027+
AtomicInteger accessTokenCount = new AtomicInteger(0);
3028+
HubConnection hubConnection = HubConnectionBuilder
3029+
.create("http://example.com")
3030+
.withTransport(TransportEnum.LONG_POLLING)
3031+
.withHttpClient(client)
3032+
.withAccessTokenProvider(Single.defer(() -> {
3033+
if (accessTokenCount.getAndIncrement() < 5) {
3034+
return Single.just("");
3035+
}
3036+
return Single.error(new RuntimeException("Error from accessTokenProvider"));
3037+
}))
3038+
.build();
3039+
3040+
CompletableSubject closed = CompletableSubject.create();
3041+
hubConnection.onClosed((e) -> {
3042+
closed.onComplete();
3043+
});
3044+
3045+
hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait();
3046+
blockGet.onComplete();
3047+
3048+
closed.timeout(1, TimeUnit.SECONDS).blockingAwait();
3049+
assertEquals(HubConnectionState.DISCONNECTED, hubConnection.getConnectionState());
3050+
}
3051+
3052+
@Test
3053+
public void LongPollingTransportAccessTokenProviderThrowsDuringStop() {
3054+
AtomicInteger requestCount = new AtomicInteger(0);
3055+
CompletableSubject blockGet = CompletableSubject.create();
3056+
TestHttpClient client = new TestHttpClient()
3057+
.on("POST", "http://example.com/negotiate?negotiateVersion=1",
3058+
(req) -> Single.just(new HttpResponse(200, "",
3059+
TestUtils.stringToByteBuffer("{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\""
3060+
+ "availableTransports\":[{\"transport\":\"LongPolling\",\"transferFormats\":[\"Text\",\"Binary\"]}]}"))))
3061+
.on("GET", (req) -> {
3062+
if (requestCount.getAndIncrement() > 1) {
3063+
blockGet.blockingAwait();
3064+
}
3065+
return Single.just(new HttpResponse(200, "", TestUtils.stringToByteBuffer("{}" + RECORD_SEPARATOR)));
3066+
})
3067+
.on("POST", "http://example.com?id=bVOiRPG8-6YiJ6d7ZcTOVQ", (req) -> {
3068+
return Single.just(new HttpResponse(200, "", TestUtils.stringToByteBuffer("")));
3069+
});
3070+
3071+
AtomicInteger accessTokenCount = new AtomicInteger(0);
3072+
HubConnection hubConnection = HubConnectionBuilder
3073+
.create("http://example.com")
3074+
.withTransport(TransportEnum.LONG_POLLING)
3075+
.withHttpClient(client)
3076+
.withAccessTokenProvider(Single.defer(() -> {
3077+
if (accessTokenCount.getAndIncrement() < 5) {
3078+
return Single.just("");
3079+
}
3080+
return Single.error(new RuntimeException("Error from accessTokenProvider"));
3081+
}))
3082+
.build();
3083+
3084+
CompletableSubject closed = CompletableSubject.create();
3085+
hubConnection.onClosed((e) -> {
3086+
closed.onComplete();
3087+
});
3088+
3089+
hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait();
3090+
3091+
try {
3092+
hubConnection.stop().timeout(1, TimeUnit.SECONDS).blockingAwait();
3093+
assertTrue(false);
3094+
} catch (Exception ex) {
3095+
assertEquals("Error from accessTokenProvider", ex.getMessage());
3096+
}
3097+
blockGet.onComplete();
3098+
closed.timeout(1, TimeUnit.SECONDS).blockingAwait();
3099+
assertEquals(HubConnectionState.DISCONNECTED, hubConnection.getConnectionState());
3100+
}
3101+
29723102
@Test
29733103
public void receivingServerSentEventsTransportFromNegotiateFails() {
29743104
TestHttpClient client = new TestHttpClient().on("POST", "http://example.com/negotiate?negotiateVersion=1",
@@ -3265,6 +3395,21 @@ public void authorizationHeaderFromNegotiateGetsSetToNewValue() {
32653395
assertEquals("Bearer secondRedirectToken", token.get());
32663396
}
32673397

3398+
@Test
3399+
public void ErrorInAccessTokenProviderThrowsFromStart() {
3400+
HubConnection hubConnection = HubConnectionBuilder
3401+
.create("http://example.com")
3402+
.withAccessTokenProvider(Single.defer(() -> Single.error(new RuntimeException("Error from accessTokenProvider"))))
3403+
.build();
3404+
3405+
try {
3406+
hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait();
3407+
assertTrue(false);
3408+
} catch (RuntimeException ex) {
3409+
assertEquals("Error from accessTokenProvider", ex.getMessage());
3410+
}
3411+
}
3412+
32683413
@Test
32693414
public void connectionTimesOutIfServerDoesNotSendMessage() {
32703415
HubConnection hubConnection = TestUtils.createHubConnection("http://example.com");

src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/LongPollingTransportTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ public void LongPollingFailsWhenReceivingUnexpectedErrorCode() {
8686
return Single.just(new HttpResponse(200, "", TestUtils.emptyByteBuffer));
8787
}
8888
return Single.just(new HttpResponse(999, "", TestUtils.emptyByteBuffer));
89-
});
89+
})
90+
.on("DELETE", (req) -> Single.just(new HttpResponse(200, "", TestUtils.stringToByteBuffer(""))));
9091

9192
Map<String, String> headers = new HashMap<>();
9293
LongPollingTransport transport = new LongPollingTransport(headers, client, Single.just(""));

0 commit comments

Comments
 (0)