Skip to content

Commit eaddb9b

Browse files
core, NPE protect on Context in PollingStrategy (Azure#28507)
* core, NPE protect on Context in PollingStrategy * add test case for Context passed to PollingStragety
1 parent cef79dc commit eaddb9b

File tree

3 files changed

+67
-2
lines changed

3 files changed

+67
-2
lines changed

sdk/core/azure-core/src/main/java/com/azure/core/util/polling/LocationPollingStrategy.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ public LocationPollingStrategy(HttpPipeline httpPipeline, ObjectSerializer seria
7575
public LocationPollingStrategy(HttpPipeline httpPipeline, ObjectSerializer serializer, Context context) {
7676
this.httpPipeline = Objects.requireNonNull(httpPipeline, "'httpPipeline' cannot be null");
7777
this.serializer = (serializer == null) ? DEFAULT_SERIALIZER : serializer;
78-
this.context = context;
78+
this.context = context == null ? Context.NONE : context;
7979
}
8080

8181
@Override

sdk/core/azure-core/src/main/java/com/azure/core/util/polling/OperationResourcePollingStrategy.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ public OperationResourcePollingStrategy(HttpPipeline httpPipeline, ObjectSeriali
7474
this.serializer = serializer != null ? serializer : new DefaultJsonSerializer();
7575
this.operationLocationHeaderName = operationLocationHeaderName != null ? operationLocationHeaderName
7676
: DEFAULT_OPERATION_LOCATION_HEADER;
77-
this.context = context;
77+
this.context = context == null ? Context.NONE : context;
7878
}
7979

8080

sdk/core/azure-core/src/test/java/com/azure/core/util/polling/PollingStrategyTests.java

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,13 @@
1111
import com.azure.core.http.MockHttpResponse;
1212
import com.azure.core.http.rest.Response;
1313
import com.azure.core.http.rest.SimpleResponse;
14+
import com.azure.core.util.Context;
1415
import com.azure.core.util.serializer.TypeReference;
1516
import org.junit.jupiter.api.AfterEach;
17+
import org.junit.jupiter.api.Assertions;
1618
import org.junit.jupiter.api.BeforeEach;
1719
import org.junit.jupiter.api.Test;
20+
import org.mockito.ArgumentCaptor;
1821
import org.mockito.Mock;
1922
import org.mockito.Mockito;
2023
import org.mockito.MockitoAnnotations;
@@ -302,6 +305,68 @@ public void locationPollingStrategySucceedsOnPollWithPostLocationHeader() {
302305
assertEquals(1, activationCallCount[0]);
303306
}
304307

308+
@Test
309+
public void pollingStrategyPassContextToHttpClient() {
310+
int[] activationCallCount = new int[1];
311+
activationCallCount[0] = 0;
312+
String mockPollUrl = "http://localhost/poll";
313+
String finalResultUrl = "http://localhost/final";
314+
when(activationOperation.get()).thenReturn(Mono.defer(() -> {
315+
activationCallCount[0]++;
316+
SimpleResponse<PollResult> response = new SimpleResponse<>(
317+
new HttpRequest(HttpMethod.POST, "http://localhost"),
318+
200,
319+
new HttpHeaders().set("Location", mockPollUrl),
320+
new PollResult("InProgress"));
321+
return Mono.just(response);
322+
}));
323+
HttpRequest pollRequest = new HttpRequest(HttpMethod.GET, mockPollUrl);
324+
ArgumentCaptor<Context> contextArgument = ArgumentCaptor.forClass(Context.class);
325+
when(httpClient.send(any(), contextArgument.capture()))
326+
.thenAnswer(iom -> {
327+
HttpRequest req = iom.getArgument(0);
328+
if (mockPollUrl.equals(req.getUrl().toString())) {
329+
return Mono.just(new MockHttpResponse(pollRequest, 200,
330+
new HttpHeaders().set("Location", finalResultUrl),
331+
new PollResult("Succeeded")));
332+
} else if (finalResultUrl.equals(req.getUrl().toString())) {
333+
return Mono.just(new MockHttpResponse(pollRequest, 200, new HttpHeaders(),
334+
new PollResult("final-state")));
335+
} else {
336+
return Mono.error(new IllegalArgumentException("Unknown request URL " + req.getUrl()));
337+
}
338+
});
339+
340+
// PollingStrategy with context = Context.NONE
341+
PollerFlux<PollResult, PollResult> pollerFlux = PollerFlux.create(
342+
Duration.ofSeconds(1),
343+
() -> activationOperation.get(),
344+
new DefaultPollingStrategy<>(new HttpPipelineBuilder().httpClient(httpClient).build(), null, null),
345+
new TypeReference<PollResult>() { }, new TypeReference<PollResult>() { });
346+
347+
StepVerifier.create(pollerFlux.map(AsyncPollResponse::getStatus))
348+
.expectSubscription()
349+
.expectNext(LongRunningOperationStatus.SUCCESSFULLY_COMPLETED)
350+
.verifyComplete();
351+
Assertions.assertEquals(Context.NONE, contextArgument.getValue());
352+
353+
// PollingStrategy with context
354+
final Context context = new Context("key", "value");
355+
pollerFlux = PollerFlux.create(
356+
Duration.ofSeconds(1),
357+
() -> activationOperation.get(),
358+
new DefaultPollingStrategy<>(new HttpPipelineBuilder().httpClient(httpClient).build(), null, context),
359+
new TypeReference<PollResult>() { }, new TypeReference<PollResult>() { });
360+
361+
StepVerifier.create(pollerFlux.map(AsyncPollResponse::getStatus))
362+
.expectSubscription()
363+
.expectNext(LongRunningOperationStatus.SUCCESSFULLY_COMPLETED)
364+
.verifyComplete();
365+
Assertions.assertEquals(context, contextArgument.getValue());
366+
367+
assertEquals(2, activationCallCount[0]);
368+
}
369+
305370
public static class PollResult {
306371
private String status;
307372
private String resourceLocation;

0 commit comments

Comments
 (0)