diff --git a/src/main/java/org/gridsuite/geodata/server/GeoDataController.java b/src/main/java/org/gridsuite/geodata/server/GeoDataController.java index 36373ae4..3e5bea55 100644 --- a/src/main/java/org/gridsuite/geodata/server/GeoDataController.java +++ b/src/main/java/org/gridsuite/geodata/server/GeoDataController.java @@ -27,6 +27,7 @@ import java.util.List; import java.util.Set; import java.util.UUID; +import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; /** @@ -53,7 +54,7 @@ private static Set toCountrySet(@RequestParam(required = false) List> getSubstations(@Parameter(description = "Network UUID") @RequestParam UUID networkUuid, + public CompletableFuture>> getSubstations(@Parameter(description = "Network UUID") @RequestParam UUID networkUuid, @Parameter(description = "Variant Id") @RequestParam(name = "variantId", required = false) String variantId, @Parameter(description = "Countries") @RequestParam(name = "country", required = false) List countries, @RequestBody(required = false) List substationIds) { @@ -62,14 +63,14 @@ public ResponseEntity> getSubstations(@Parameter(descrip if (variantId != null) { network.getVariantManager().setWorkingVariant(variantId); } - List substations = geoDataService.getSubstationsData(network, countrySet, substationIds); - return ResponseEntity.ok().body(substations); + return geoDataService.getSubstationsData(network, countrySet, substationIds).thenApply( + substations -> ResponseEntity.ok().body(substations)); } @PostMapping(value = "/lines/infos", consumes = MediaType.APPLICATION_JSON_VALUE, produces = MediaType.APPLICATION_JSON_VALUE) @Operation(summary = "Get lines geographical data") @ApiResponses(value = {@ApiResponse(responseCode = "200", description = "Lines geographical data")}) - public ResponseEntity> getLines(@Parameter(description = "Network UUID")@RequestParam UUID networkUuid, + public CompletableFuture>> getLines(@Parameter(description = "Network UUID")@RequestParam UUID networkUuid, @Parameter(description = "Variant Id") @RequestParam(name = "variantId", required = false) String variantId, @Parameter(description = "Countries") @RequestParam(name = "country", required = false) List countries, @RequestBody(required = false) List lineIds) { @@ -78,8 +79,8 @@ public ResponseEntity> getLines(@Parameter(description = "Netw if (variantId != null) { network.getVariantManager().setWorkingVariant(variantId); } - List lines = geoDataService.getLinesData(network, countrySet, lineIds); - return ResponseEntity.ok().body(lines); + return geoDataService.getLinesData(network, countrySet, lineIds).thenApply( + lines -> ResponseEntity.ok().body(lines)); } @PostMapping(value = "/substations") diff --git a/src/main/java/org/gridsuite/geodata/server/GeoDataService.java b/src/main/java/org/gridsuite/geodata/server/GeoDataService.java index b7ca4460..ebd8c998 100644 --- a/src/main/java/org/gridsuite/geodata/server/GeoDataService.java +++ b/src/main/java/org/gridsuite/geodata/server/GeoDataService.java @@ -10,7 +10,6 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.Streams; -import com.powsybl.commons.exceptions.UncheckedInterruptedException; import com.powsybl.iidm.network.*; import com.powsybl.iidm.network.extensions.Coordinate; import com.powsybl.iidm.network.extensions.SubstationPosition; @@ -29,7 +28,6 @@ import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; -import org.springframework.transaction.annotation.Transactional; import java.util.*; import java.util.Map.Entry; @@ -546,48 +544,38 @@ private Pair getSubstations(Identifiable identifiable }; } - @Transactional(readOnly = true) - public List getSubstationsData(Network network, Set countrySet, List substationIds) { - CompletableFuture> substationGeoDataFuture = geoDataExecutionService.supplyAsync(() -> { - if (substationIds != null) { - if (!countrySet.isEmpty()) { - LOGGER.warn("Countries will not be taken into account to filter substation position."); + public CompletableFuture> getSubstationsData(Network network, Set countrySet, List substationIds) { + return geoDataExecutionService.supplyAsync(() -> { + try { + if (substationIds != null) { + if (!countrySet.isEmpty()) { + LOGGER.warn("Countries will not be taken into account to filter substation position."); + } + return getSubstationsByIds(network, new HashSet<>(substationIds)); + } else { + return getSubstationsByCountries(network, countrySet); } - return getSubstationsByIds(network, new HashSet<>(substationIds)); - } else { - return getSubstationsByCountries(network, countrySet); + } catch (Exception e) { + throw new GeoDataException(FAILED_SUBSTATIONS_LOADING, e); } }); - try { - return substationGeoDataFuture.get(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new UncheckedInterruptedException(e); - } catch (Exception e) { - throw new GeoDataException(FAILED_SUBSTATIONS_LOADING, e); - } } - @Transactional(readOnly = true) - public List getLinesData(Network network, Set countrySet, List lineIds) { - CompletableFuture> lineGeoDataFuture = geoDataExecutionService.supplyAsync(() -> { - if (lineIds != null) { - if (!countrySet.isEmpty()) { - LOGGER.warn("Countries will not be taken into account to filter line position."); + public CompletableFuture> getLinesData(Network network, Set countrySet, List lineIds) { + return geoDataExecutionService.supplyAsync(() -> { + try { + if (lineIds != null) { + if (!countrySet.isEmpty()) { + LOGGER.warn("Countries will not be taken into account to filter line position."); + } + return getLinesByIds(network, new HashSet<>(lineIds)); + } else { + return getLinesByCountries(network, countrySet); } - return getLinesByIds(network, new HashSet<>(lineIds)); - } else { - return getLinesByCountries(network, countrySet); + } catch (Exception e) { + throw new GeoDataException(FAILED_LINES_LOADING, e); } }); - try { - return lineGeoDataFuture.get(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new UncheckedInterruptedException(e); - } catch (Exception e) { - throw new GeoDataException(FAILED_LINES_LOADING, e); - } } List getLinesByIds(Network network, Set linesIds) { diff --git a/src/test/java/org/gridsuite/geodata/server/GeoDataControllerTest.java b/src/test/java/org/gridsuite/geodata/server/GeoDataControllerTest.java index e8754f3c..323adfd5 100644 --- a/src/test/java/org/gridsuite/geodata/server/GeoDataControllerTest.java +++ b/src/test/java/org/gridsuite/geodata/server/GeoDataControllerTest.java @@ -24,6 +24,7 @@ import org.springframework.boot.test.autoconfigure.web.servlet.WebMvcTest; import org.springframework.boot.test.mock.mockito.MockBean; import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.MvcResult; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -37,7 +38,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; import static org.springframework.http.MediaType.APPLICATION_JSON; -import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; /** @@ -84,14 +85,28 @@ void test() throws Exception { given(service.getNetwork(networkUuid, PreloadingStrategy.NONE)).willReturn(testNetwork); given(service.getNetwork(networkUuid, PreloadingStrategy.COLLECTION)).willReturn(testNetwork); - mvc.perform(post("/" + VERSION + "/substations/infos?networkUuid=" + networkUuid) + // Just to hold the async state that we need to pass immediately back to mockmvc + // as per https://docs.spring.io/spring-framework/reference/testing/mockmvc/hamcrest/async-requests.html + // to mimic servlet 3.0+ AsyncContext not done automatically by TestDispatcherServlet (unlike the real servlet + // container in production) + // Note: if the controller throws before returning the completablefuture, the request is not async and we don't + // need this and we must not call asyncStarted() and asyncDispatch() + MvcResult mvcResult; + + mvcResult = mvc.perform(post("/" + VERSION + "/substations/infos?networkUuid=" + networkUuid) .contentType(APPLICATION_JSON)) + .andExpect(request().asyncStarted()) + .andReturn(); + mvc.perform(asyncDispatch(mvcResult)) .andExpect(status().isOk()) .andExpect(content().contentTypeCompatibleWith(APPLICATION_JSON)) .andExpect(jsonPath("$", hasSize(0))); - mvc.perform(post("/" + VERSION + "/substations/infos?networkUuid=" + networkUuid + "&variantId=" + VARIANT_ID) + mvcResult = mvc.perform(post("/" + VERSION + "/substations/infos?networkUuid=" + networkUuid + "&variantId=" + VARIANT_ID) .contentType(APPLICATION_JSON)) + .andExpect(request().asyncStarted()) + .andReturn(); + mvc.perform(asyncDispatch(mvcResult)) .andExpect(status().isOk()) .andExpect(content().contentTypeCompatibleWith(APPLICATION_JSON)) .andExpect(jsonPath("$", hasSize(0))); @@ -101,14 +116,20 @@ void test() throws Exception { .andExpect(content().string("Variant '" + WRONG_VARIANT_ID + "' not found")) .andExpect(status().isInternalServerError()); - mvc.perform(post("/" + VERSION + "/lines/infos?networkUuid=" + networkUuid) + mvcResult = mvc.perform(post("/" + VERSION + "/lines/infos?networkUuid=" + networkUuid) .contentType(APPLICATION_JSON)) + .andExpect(request().asyncStarted()) + .andReturn(); + mvc.perform(asyncDispatch(mvcResult)) .andExpect(status().isOk()) .andExpect(content().contentTypeCompatibleWith(APPLICATION_JSON)) .andExpect(jsonPath("$", hasSize(0))); - mvc.perform(post("/" + VERSION + "/lines/infos?networkUuid=" + networkUuid + "&variantId=" + VARIANT_ID) + mvcResult = mvc.perform(post("/" + VERSION + "/lines/infos?networkUuid=" + networkUuid + "&variantId=" + VARIANT_ID) .contentType(APPLICATION_JSON)) + .andExpect(request().asyncStarted()) + .andReturn(); + mvc.perform(asyncDispatch(mvcResult)) .andExpect(status().isOk()) .andExpect(content().contentTypeCompatibleWith(APPLICATION_JSON)) .andExpect(jsonPath("$", hasSize(0))); @@ -152,30 +173,42 @@ void test() throws Exception { .content(toString(GEO_DATA_LINES))) .andExpect(status().isOk()); - mvc.perform(post("/" + VERSION + "/substations/infos?networkUuid=" + networkUuid + "&variantId=" + VARIANT_ID + "&country=" + Country.FR) + mvcResult = mvc.perform(post("/" + VERSION + "/substations/infos?networkUuid=" + networkUuid + "&variantId=" + VARIANT_ID + "&country=" + Country.FR) .contentType(APPLICATION_JSON) .content("[\"P1\", \"P2\"]")) + .andExpect(request().asyncStarted()) + .andReturn(); + mvc.perform(asyncDispatch(mvcResult)) .andExpect(status().isOk()) .andExpect(content().contentTypeCompatibleWith(APPLICATION_JSON)) .andExpect(jsonPath("$", hasSize(0))); - mvc.perform(post("/" + VERSION + "/substations/infos?networkUuid=" + networkUuid + "&variantId=" + VARIANT_ID) + mvcResult = mvc.perform(post("/" + VERSION + "/substations/infos?networkUuid=" + networkUuid + "&variantId=" + VARIANT_ID) .contentType(APPLICATION_JSON) .content("[\"P1\", \"P2\"]")) + .andExpect(request().asyncStarted()) + .andReturn(); + mvc.perform(asyncDispatch(mvcResult)) .andExpect(status().isOk()) .andExpect(content().contentTypeCompatibleWith(APPLICATION_JSON)) .andExpect(jsonPath("$", hasSize(0))); - mvc.perform(post("/" + VERSION + "/lines/infos?networkUuid=" + networkUuid + "&variantId=" + VARIANT_ID + "&country=" + Country.FR) + mvcResult = mvc.perform(post("/" + VERSION + "/lines/infos?networkUuid=" + networkUuid + "&variantId=" + VARIANT_ID + "&country=" + Country.FR) .contentType(APPLICATION_JSON) .content("[\"NHV1_NHV2_2\", \"NHV1_NHV2_1\"]")) + .andExpect(request().asyncStarted()) + .andReturn(); + mvc.perform(asyncDispatch(mvcResult)) .andExpect(status().isOk()) .andExpect(content().contentTypeCompatibleWith(APPLICATION_JSON)) .andExpect(jsonPath("$", hasSize(0))); - mvc.perform(post("/" + VERSION + "/lines/infos?networkUuid=" + networkUuid + "&variantId=" + VARIANT_ID) + mvcResult = mvc.perform(post("/" + VERSION + "/lines/infos?networkUuid=" + networkUuid + "&variantId=" + VARIANT_ID) .contentType(APPLICATION_JSON) .content("[\"NHV1_NHV2_2\", \"NHV1_NHV2_1\"]")) + .andExpect(request().asyncStarted()) + .andReturn(); + mvc.perform(asyncDispatch(mvcResult)) .andExpect(status().isOk()) .andExpect(content().contentTypeCompatibleWith(APPLICATION_JSON)) .andExpect(jsonPath("$", hasSize(0))); @@ -189,8 +222,11 @@ void testGetLinesError() throws Exception { given(service.getNetwork(networkUuid, PreloadingStrategy.COLLECTION)).willReturn(testNetwork); given(lineRepository.findAllById(any())).willThrow(new GeoDataException(GeoDataException.Type.PARSING_ERROR, new RuntimeException("Error parsing"))); - mvc.perform(post("/" + VERSION + "/lines/infos?networkUuid=" + networkUuid) + MvcResult mvcResult = mvc.perform(post("/" + VERSION + "/lines/infos?networkUuid=" + networkUuid) .contentType(APPLICATION_JSON)) + .andExpect(request().asyncStarted()) + .andReturn(); + mvc.perform(asyncDispatch(mvcResult)) .andExpect(status().isInternalServerError()); } }