diff --git a/pom.xml b/pom.xml index 6baf242..fc21d1d 100644 --- a/pom.xml +++ b/pom.xml @@ -43,11 +43,15 @@ 2.23.0 + 1.28.1 + 1.6.0 4.4 2.2 + 3.2.0 + gridsuite org.gridsuite:computation @@ -129,6 +133,11 @@ org.springframework.cloud spring-cloud-stream + + io.awspring.cloud + spring-cloud-aws-starter-s3 + ${spring-cloud-aws.version} + @@ -145,6 +154,12 @@ + + com.powsybl + powsybl-ws-commons + ${powsybl-ws-commons.version} + true + com.powsybl powsybl-network-store-client @@ -224,5 +239,12 @@ spring-boot-test-autoconfigure test + + + + com.google.jimfs + jimfs + test + diff --git a/src/main/java/org/gridsuite/computation/s3/ComputationS3Service.java b/src/main/java/org/gridsuite/computation/s3/ComputationS3Service.java new file mode 100644 index 0000000..92882e4 --- /dev/null +++ b/src/main/java/org/gridsuite/computation/s3/ComputationS3Service.java @@ -0,0 +1,71 @@ +/** + * Copyright (c) 2025, RTE (http://www.rte-france.com) + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + +package org.gridsuite.computation.s3; + +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.exception.SdkException; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.Map; + +/** + * @author Thang PHAM + */ +public class ComputationS3Service { + + public static final String S3_DELIMITER = "/"; + public static final String S3_SERVICE_NOT_AVAILABLE_MESSAGE = "S3 service not available"; + + public static final String METADATA_FILE_NAME = "file-name"; + + private final S3Client s3Client; + + private final String bucketName; + + public ComputationS3Service(S3Client s3Client, String bucketName) { + this.s3Client = s3Client; + this.bucketName = bucketName; + } + + public void uploadFile(Path filePath, String s3Key, String fileName, Integer expireAfterMinutes) throws IOException { + try { + PutObjectRequest putRequest = PutObjectRequest.builder() + .bucket(bucketName) + .key(s3Key) + .metadata(Map.of(METADATA_FILE_NAME, fileName)) + .tagging(expireAfterMinutes != null ? "expire-after-minutes=" + expireAfterMinutes : null) + .build(); + s3Client.putObject(putRequest, RequestBody.fromFile(filePath)); + } catch (SdkException e) { + throw new IOException("Error occurred while uploading file to S3: " + e.getMessage()); + } + } + + public S3InputStreamInfos downloadFile(String s3Key) throws IOException { + try { + GetObjectRequest getRequest = GetObjectRequest.builder() + .bucket(bucketName) + .key(s3Key) + .build(); + ResponseInputStream inputStream = s3Client.getObject(getRequest); + return S3InputStreamInfos.builder() + .inputStream(inputStream) + .fileName(inputStream.response().metadata().get(METADATA_FILE_NAME)) + .fileLength(inputStream.response().contentLength()) + .build(); + } catch (SdkException e) { + throw new IOException("Error occurred while downloading file from S3: " + e.getMessage()); + } + } +} diff --git a/src/main/java/org/gridsuite/computation/s3/S3AutoConfiguration.java b/src/main/java/org/gridsuite/computation/s3/S3AutoConfiguration.java new file mode 100644 index 0000000..832f819 --- /dev/null +++ b/src/main/java/org/gridsuite/computation/s3/S3AutoConfiguration.java @@ -0,0 +1,33 @@ +/** + * Copyright (c) 2025, RTE (http://www.rte-france.com) + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + +package org.gridsuite.computation.s3; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.context.annotation.Bean; +import software.amazon.awssdk.services.s3.S3Client; + +/** + * @author Thang PHAM + */ +@AutoConfiguration +@ConditionalOnProperty(name = "computation.s3.enabled", havingValue = "true") +public class S3AutoConfiguration { + private static final Logger LOGGER = LoggerFactory.getLogger(S3AutoConfiguration.class); + @Value("${spring.cloud.aws.bucket:ws-bucket}") + private String bucketName; + + @Bean + public ComputationS3Service s3Service(S3Client s3Client) { + LOGGER.info("Configuring ComputationS3Service with bucket: {}", bucketName); + return new ComputationS3Service(s3Client, bucketName); + } +} diff --git a/src/main/java/org/gridsuite/computation/s3/S3InputStreamInfos.java b/src/main/java/org/gridsuite/computation/s3/S3InputStreamInfos.java new file mode 100644 index 0000000..6af4142 --- /dev/null +++ b/src/main/java/org/gridsuite/computation/s3/S3InputStreamInfos.java @@ -0,0 +1,24 @@ +/** + * Copyright (c) 2025, RTE (http://www.rte-france.com) + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + +package org.gridsuite.computation.s3; + +import lombok.Builder; +import lombok.Getter; + +import java.io.InputStream; + +/** + * @author Thang PHAM + */ +@Builder +@Getter +public class S3InputStreamInfos { + InputStream inputStream; + String fileName; + Long fileLength; +} diff --git a/src/main/java/org/gridsuite/computation/service/AbstractComputationResultService.java b/src/main/java/org/gridsuite/computation/service/AbstractComputationResultService.java index c5424d6..2dc37d0 100644 --- a/src/main/java/org/gridsuite/computation/service/AbstractComputationResultService.java +++ b/src/main/java/org/gridsuite/computation/service/AbstractComputationResultService.java @@ -22,4 +22,14 @@ public abstract class AbstractComputationResultService { public abstract void deleteAll(); public abstract S findStatus(UUID resultUuid); + + // --- Must implement these following methods if a computation server supports s3 storage --- // + public void saveDebugFileLocation(UUID resultUuid, String debugFilePath) { + // to override by subclasses + } + + public String findDebugFileLocation(UUID resultUuid) { + // to override by subclasses + return null; + } } diff --git a/src/main/java/org/gridsuite/computation/service/AbstractComputationRunContext.java b/src/main/java/org/gridsuite/computation/service/AbstractComputationRunContext.java index 8cb4ca9..53e7afc 100644 --- a/src/main/java/org/gridsuite/computation/service/AbstractComputationRunContext.java +++ b/src/main/java/org/gridsuite/computation/service/AbstractComputationRunContext.java @@ -12,6 +12,7 @@ import lombok.Setter; import org.gridsuite.computation.dto.ReportInfos; +import java.nio.file.Path; import java.util.UUID; /** @@ -30,9 +31,16 @@ public abstract class AbstractComputationRunContext

{ private P parameters; private ReportNode reportNode; private Network network; + private Boolean debug; + private Path debugDir; protected AbstractComputationRunContext(UUID networkUuid, String variantId, String receiver, ReportInfos reportInfos, String userId, String provider, P parameters) { + this(networkUuid, variantId, receiver, reportInfos, userId, provider, parameters, null); + } + + protected AbstractComputationRunContext(UUID networkUuid, String variantId, String receiver, ReportInfos reportInfos, + String userId, String provider, P parameters, Boolean debug) { this.networkUuid = networkUuid; this.variantId = variantId; this.receiver = receiver; @@ -42,5 +50,6 @@ protected AbstractComputationRunContext(UUID networkUuid, String variantId, Stri this.parameters = parameters; this.reportNode = ReportNode.NO_OP; this.network = null; + this.debug = debug; } } diff --git a/src/main/java/org/gridsuite/computation/service/AbstractComputationService.java b/src/main/java/org/gridsuite/computation/service/AbstractComputationService.java index 824f1ee..4684339 100644 --- a/src/main/java/org/gridsuite/computation/service/AbstractComputationService.java +++ b/src/main/java/org/gridsuite/computation/service/AbstractComputationService.java @@ -7,13 +7,26 @@ package org.gridsuite.computation.service; import com.fasterxml.jackson.databind.ObjectMapper; +import com.powsybl.commons.PowsyblException; import lombok.Getter; +import org.gridsuite.computation.s3.ComputationS3Service; +import org.gridsuite.computation.s3.S3InputStreamInfos; +import org.springframework.core.io.InputStreamResource; +import org.springframework.core.io.Resource; +import org.springframework.http.ContentDisposition; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; import org.springframework.util.CollectionUtils; +import java.io.IOException; +import java.io.InputStream; import java.util.List; import java.util.Objects; import java.util.UUID; +import static org.gridsuite.computation.s3.ComputationS3Service.S3_SERVICE_NOT_AVAILABLE_MESSAGE; + /** * @author Mathieu Deharbe * @param run context specific to a computation, including parameters @@ -26,6 +39,7 @@ public abstract class AbstractComputationService resultUuids, S status) { public S getStatus(UUID resultUuid) { return resultService.findStatus(resultUuid); } + + public ResponseEntity downloadDebugFile(UUID resultUuid) { + if (computationS3Service == null) { + throw new PowsyblException(S3_SERVICE_NOT_AVAILABLE_MESSAGE); + } + + String s3Key = resultService.findDebugFileLocation(resultUuid); + if (s3Key == null) { + return ResponseEntity.notFound().build(); + } + + try { + S3InputStreamInfos s3InputStreamInfos = computationS3Service.downloadFile(s3Key); + InputStream inputStream = s3InputStreamInfos.getInputStream(); + String fileName = s3InputStreamInfos.getFileName(); + Long fileLength = s3InputStreamInfos.getFileLength(); + + // build header + HttpHeaders headers = new HttpHeaders(); + headers.setContentDisposition(ContentDisposition.builder("attachment").filename(fileName).build()); + headers.setContentLength(fileLength); + + // wrap s3 input stream + InputStreamResource resource = new InputStreamResource(inputStream); + return ResponseEntity.ok() + .headers(headers) + .contentType(MediaType.APPLICATION_OCTET_STREAM) + .body(resource); + } catch (IOException e) { + return ResponseEntity.notFound().build(); + } + } + } diff --git a/src/main/java/org/gridsuite/computation/service/AbstractResultContext.java b/src/main/java/org/gridsuite/computation/service/AbstractResultContext.java index b80728d..3201a59 100644 --- a/src/main/java/org/gridsuite/computation/service/AbstractResultContext.java +++ b/src/main/java/org/gridsuite/computation/service/AbstractResultContext.java @@ -67,6 +67,7 @@ public Message toMessage(ObjectMapper objectMapper) { .setHeader(REPORT_UUID_HEADER, runContext.getReportInfos().reportUuid() != null ? runContext.getReportInfos().reportUuid().toString() : null) .setHeader(REPORTER_ID_HEADER, runContext.getReportInfos().reporterId()) .setHeader(REPORT_TYPE_HEADER, runContext.getReportInfos().computationType()) + .setHeader(HEADER_DEBUG, runContext.getDebug()) .copyHeaders(getSpecificMsgHeaders(objectMapper)) .build(); } diff --git a/src/main/java/org/gridsuite/computation/service/AbstractWorkerService.java b/src/main/java/org/gridsuite/computation/service/AbstractWorkerService.java index 45d5c99..2becdc9 100644 --- a/src/main/java/org/gridsuite/computation/service/AbstractWorkerService.java +++ b/src/main/java/org/gridsuite/computation/service/AbstractWorkerService.java @@ -8,27 +8,44 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.powsybl.commons.PowsyblException; +import com.powsybl.commons.io.FileUtil; import com.powsybl.commons.report.ReportNode; import com.powsybl.iidm.network.Network; import com.powsybl.iidm.network.VariantManagerConstants; import com.powsybl.network.store.client.NetworkStoreService; import com.powsybl.network.store.client.PreloadingStrategy; +import com.powsybl.ws.commons.ZipUtils; import org.apache.commons.lang3.StringUtils; import org.gridsuite.computation.ComputationException; +import org.gridsuite.computation.s3.ComputationS3Service; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Value; import org.springframework.http.HttpStatus; +import org.springframework.lang.Nullable; import org.springframework.messaging.Message; import org.springframework.web.server.ResponseStatusException; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.HashMap; import java.util.Map; import java.util.UUID; -import java.util.concurrent.*; +import java.util.concurrent.CancellationException; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; import java.util.function.Consumer; +import static org.gridsuite.computation.s3.ComputationS3Service.S3_DELIMITER; +import static org.gridsuite.computation.s3.ComputationS3Service.S3_SERVICE_NOT_AVAILABLE_MESSAGE; +import static org.gridsuite.computation.service.NotificationService.HEADER_ERROR_MESSAGE; + /** * @author Mathieu Deharbe * @param powsybl Result class specific to the computation @@ -39,6 +56,9 @@ public abstract class AbstractWorkerService, P, S extends AbstractComputationResultService> { private static final Logger LOGGER = LoggerFactory.getLogger(AbstractWorkerService.class); + @Value("${powsybl-ws.s3.subpath.prefix:}${debug-subpath:debug}") + private String debugRootPath; + protected final Lock lockRunAndCancel = new ReentrantLock(); protected final ObjectMapper objectMapper; protected final NetworkStoreService networkStoreService; @@ -50,10 +70,23 @@ public abstract class AbstractWorkerService cancelComputationRequests = new ConcurrentHashMap<>(); protected final S resultService; + protected final ComputationS3Service computationS3Service; + + protected AbstractWorkerService(NetworkStoreService networkStoreService, + NotificationService notificationService, + ReportService reportService, + S resultService, + ExecutionService executionService, + AbstractComputationObserver observer, + ObjectMapper objectMapper) { + this(networkStoreService, notificationService, reportService, resultService, null, executionService, observer, objectMapper); + } + protected AbstractWorkerService(NetworkStoreService networkStoreService, NotificationService notificationService, ReportService reportService, S resultService, + ComputationS3Service computationS3Service, ExecutionService executionService, AbstractComputationObserver observer, ObjectMapper objectMapper) { @@ -61,6 +94,7 @@ protected AbstractWorkerService(NetworkStoreService networkStoreService, this.notificationService = notificationService; this.reportService = reportService; this.resultService = resultService; + this.computationS3Service = computationS3Service; this.executionService = executionService; this.observer = observer; this.objectMapper = objectMapper; @@ -148,6 +182,9 @@ public Consumer> consumeRun() { this.handleNonCancellationException(resultContext, e, rootReporter); throw new ComputationException(String.format("%s: %s", NotificationService.getFailedMessage(getComputationType()), e.getMessage()), e.getCause()); } finally { + if (Boolean.TRUE.equals(resultContext.getRunContext().getDebug())) { + processDebug(resultContext); + } clean(resultContext); } }; @@ -160,6 +197,54 @@ public Consumer> consumeRun() { protected void clean(AbstractResultContext resultContext) { futures.remove(resultContext.getResultUuid()); cancelComputationRequests.remove(resultContext.getResultUuid()); + + // run in debug mode, clean debug dir + C runContext = resultContext.getRunContext(); + if (Boolean.TRUE.equals(runContext.getDebug()) && computationS3Service != null) { + removeDirectory(runContext.getDebugDir()); + } + } + + /** + * Process debug option + * @param resultContext The context of the computation + */ + protected void processDebug(AbstractResultContext resultContext) { + if (computationS3Service == null) { + sendDebugMessage(resultContext, S3_SERVICE_NOT_AVAILABLE_MESSAGE); + return; + } + + C runContext = resultContext.getRunContext(); + Path debugDir = runContext.getDebugDir(); + Path parentDir = debugDir.getParent(); + Path debugFilePath = parentDir.resolve(debugDir.getFileName().toString() + ".zip"); + String fileName = debugFilePath.getFileName().toString(); + + try { + // zip the working directory + ZipUtils.zip(debugDir, debugFilePath); + String s3Key = debugRootPath + S3_DELIMITER + fileName; + + // insert debug file path into db + resultService.saveDebugFileLocation(resultContext.getResultUuid(), s3Key); + + // upload zip file to s3 storage + computationS3Service.uploadFile(debugFilePath, s3Key, fileName, 30); + + // notify to study-server + sendDebugMessage(resultContext, null); + } catch (IOException | UncheckedIOException e) { + LOGGER.info("Error occurred while uploading debug file {}: {}", fileName, e.getMessage()); + sendDebugMessage(resultContext, e.getMessage()); + } finally { + // delete debug file + try { + Files.delete(debugFilePath); + } catch (IOException e) { + LOGGER.info("Error occurred while deleting debug file {}: {}", fileName, e.getMessage()); + } + } } /** @@ -187,12 +272,27 @@ protected void sendResultMessage(AbstractResultContext resultContext, R ignor resultContext.getRunContext().getUserId(), null); } + private void sendDebugMessage(AbstractResultContext resultContext, @Nullable String messageError) { + Map resultHeaders = new HashMap<>(); + + // --- attach debug to result headers --- // + resultHeaders.put(HEADER_ERROR_MESSAGE, messageError); + + notificationService.sendDebugMessage(resultContext.getResultUuid(), resultContext.getRunContext().getReceiver(), + resultContext.getRunContext().getUserId(), resultHeaders); + } + /** * Do some extra task before running the computation, e.g. print log or init extra data for the run context - * @param ignoredRunContext This context may be used for further computation in overriding classes + * @param runContext The run context of the computation */ - protected void preRun(C ignoredRunContext) { + protected void preRun(C runContext) { LOGGER.info("Run {} computation...", getComputationType()); + + // run in debug mode, create debug dir + if (Boolean.TRUE.equals(runContext.getDebug()) && computationS3Service != null) { + runContext.setDebugDir(createDebugDir()); + } } protected R run(C runContext, UUID resultUuid, AtomicReference rootReporter) { @@ -257,4 +357,31 @@ protected CompletableFuture runAsync( protected abstract String getComputationType(); protected abstract CompletableFuture getCompletableFuture(C runContext, String provider, UUID resultUuid); + + private Path createDebugDir() { + Path localDir = executionService.getComputationManager().getLocalDir(); + try { + String debugDirPrefix = buildComputationDirPrefix() + "debug_"; + return Files.createTempDirectory(localDir, debugDirPrefix); + } catch (IOException e) { + throw new UncheckedIOException(String.format("Error occurred while creating a debug directory inside the local directory %s", + localDir.toAbsolutePath()), e); + } + } + + protected String buildComputationDirPrefix() { + return getComputationType().replaceAll("\\s+", "_").toLowerCase() + "_"; + } + + protected void removeDirectory(Path dir) { + if (dir != null) { + try { + FileUtil.removeDir(dir); + } catch (IOException e) { + LOGGER.error(String.format("%s: Error occurred while removing directory %s", getComputationType(), dir.toAbsolutePath()), e); + } + } else { + LOGGER.info("{}: No directory to clean", getComputationType()); + } + } } diff --git a/src/main/java/org/gridsuite/computation/service/NotificationService.java b/src/main/java/org/gridsuite/computation/service/NotificationService.java index d3aa71d..fa07fe4 100644 --- a/src/main/java/org/gridsuite/computation/service/NotificationService.java +++ b/src/main/java/org/gridsuite/computation/service/NotificationService.java @@ -46,6 +46,8 @@ public class NotificationService { public static final String HEADER_PROVIDER = "provider"; public static final String HEADER_MESSAGE = "message"; public static final String HEADER_USER_ID = "userId"; + public static final String HEADER_DEBUG = "debug"; + public static final String HEADER_ERROR_MESSAGE = "errorMessage"; public static final String SENDING_MESSAGE = "Sending message : {}"; @@ -67,6 +69,19 @@ public void sendCancelMessage(Message message) { publisher.send(publishPrefix + "Cancel-out-0", message); } + @PostCompletion + public void sendDebugMessage(UUID resultUuid, String receiver, String userId, @Nullable Map additionalHeaders) { + MessageBuilder builder = MessageBuilder + .withPayload("") + .setHeader(HEADER_RESULT_UUID, resultUuid.toString()) + .setHeader(HEADER_RECEIVER, receiver) + .setHeader(HEADER_USER_ID, userId) + .copyHeaders(additionalHeaders); + Message message = builder.build(); + RESULT_MESSAGE_LOGGER.debug(SENDING_MESSAGE, message); + publisher.send(publishPrefix + "Debug-out-0", message); + } + @PostCompletion public void sendResultMessage(UUID resultUuid, String receiver, String userId, @Nullable Map additionalHeaders) { MessageBuilder builder = MessageBuilder diff --git a/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports new file mode 100644 index 0000000..00f723d --- /dev/null +++ b/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -0,0 +1,2 @@ +# AutoConfigureCache auto-configuration imports +org.gridsuite.computation.s3.S3AutoConfiguration \ No newline at end of file diff --git a/src/test/java/org/gridsuite/computation/ComputationTest.java b/src/test/java/org/gridsuite/computation/ComputationTest.java index 5572102..8eff721 100644 --- a/src/test/java/org/gridsuite/computation/ComputationTest.java +++ b/src/test/java/org/gridsuite/computation/ComputationTest.java @@ -7,10 +7,16 @@ package org.gridsuite.computation; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.jimfs.Configuration; +import com.google.common.jimfs.Jimfs; +import com.powsybl.commons.PowsyblException; +import com.powsybl.computation.local.LocalComputationConfig; +import com.powsybl.computation.local.LocalComputationManager; import com.powsybl.iidm.network.Network; import com.powsybl.iidm.network.VariantManager; import com.powsybl.network.store.client.NetworkStoreService; import com.powsybl.network.store.client.PreloadingStrategy; +import com.powsybl.ws.commons.ZipUtils; import io.micrometer.core.instrument.MeterRegistry; import io.micrometer.core.instrument.simple.SimpleMeterRegistry; import io.micrometer.observation.ObservationRegistry; @@ -20,25 +26,42 @@ import lombok.extern.slf4j.Slf4j; import org.assertj.core.api.WithAssertions; import org.gridsuite.computation.dto.ReportInfos; +import org.gridsuite.computation.s3.ComputationS3Service; +import org.gridsuite.computation.s3.S3InputStreamInfos; import org.gridsuite.computation.service.*; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.Mockito; +import org.mockito.Spy; import org.mockito.junit.jupiter.MockitoExtension; import org.springframework.cloud.stream.function.StreamBridge; +import org.springframework.core.io.InputStreamResource; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; import org.springframework.messaging.Message; import org.springframework.messaging.support.MessageBuilder; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.file.FileSystem; +import java.nio.file.Files; +import java.nio.file.Path; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.UUID; import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; -import java.util.function.Consumer; +import java.util.concurrent.ForkJoinPool; +import static org.gridsuite.computation.s3.ComputationS3Service.S3_DELIMITER; +import static org.gridsuite.computation.s3.ComputationS3Service.S3_SERVICE_NOT_AVAILABLE_MESSAGE; import static org.gridsuite.computation.service.NotificationService.*; import static org.junit.jupiter.api.Assertions.*; import static org.mockito.ArgumentMatchers.*; @@ -48,13 +71,23 @@ @Slf4j class ComputationTest implements WithAssertions { private static final String COMPUTATION_TYPE = "mockComputation"; + public static final String S3_DEBUG_SUBPATH = "debug"; + + public static final String WORKING_DIR = "test"; + public static final String S3_DEBUG_FILE_ZIP = WORKING_DIR + ".zip"; + public static final String S3_KEY = S3_DEBUG_SUBPATH + S3_DELIMITER + S3_DEBUG_FILE_ZIP; + + protected FileSystem fileSystem; + protected Path tmpDir; + @Mock private VariantManager variantManager; @Mock private NetworkStoreService networkStoreService; @Mock private ReportService reportService; - private final ExecutionService executionService = new ExecutionService(); + @Mock + private ExecutionService executionService; private final UuidGeneratorService uuidGeneratorService = new UuidGeneratorService(); @Mock private StreamBridge publisher; @@ -63,6 +96,8 @@ class ComputationTest implements WithAssertions { private ObjectMapper objectMapper; @Mock private Network network; + @Mock + private ComputationS3Service computationS3Service; private enum MockComputationStatus { NOT_DONE, @@ -129,8 +164,8 @@ protected MockComputationResultContext(UUID resultUuid, MockComputationRunContex } private static class MockComputationService extends AbstractComputationService { - protected MockComputationService(NotificationService notificationService, MockComputationResultService resultService, ObjectMapper objectMapper, UuidGeneratorService uuidGeneratorService, String defaultProvider) { - super(notificationService, resultService, objectMapper, uuidGeneratorService, defaultProvider); + protected MockComputationService(NotificationService notificationService, MockComputationResultService resultService, ComputationS3Service computationS3Service, ObjectMapper objectMapper, UuidGeneratorService uuidGeneratorService, String defaultProvider) { + super(notificationService, resultService, computationS3Service, objectMapper, uuidGeneratorService, defaultProvider); } @Override @@ -152,8 +187,8 @@ private enum ComputationResultWanted { } private static class MockComputationWorkerService extends AbstractWorkerService { - protected MockComputationWorkerService(NetworkStoreService networkStoreService, NotificationService notificationService, ReportService reportService, MockComputationResultService resultService, ExecutionService executionService, AbstractComputationObserver observer, ObjectMapper objectMapper) { - super(networkStoreService, notificationService, reportService, resultService, executionService, observer, objectMapper); + protected MockComputationWorkerService(NetworkStoreService networkStoreService, NotificationService notificationService, ReportService reportService, MockComputationResultService resultService, ComputationS3Service computationS3Service, ExecutionService executionService, AbstractComputationObserver observer, ObjectMapper objectMapper) { + super(networkStoreService, notificationService, reportService, resultService, computationS3Service, executionService, observer, objectMapper); } @Override @@ -207,22 +242,27 @@ public void addFuture(UUID id, CompletableFuture future) { final String provider = "MockComputation_Provider"; Message message; MockComputationRunContext runContext; + @Spy MockComputationResultService resultService; @BeforeEach - void init() { - resultService = new MockComputationResultService(); + void init() throws IOException { + // used to initialize the computation manager + fileSystem = Jimfs.newFileSystem(Configuration.unix()); + tmpDir = Files.createDirectory(fileSystem.getPath("tmp")); + notificationService = new NotificationService(publisher); workerService = new MockComputationWorkerService( networkStoreService, notificationService, reportService, resultService, + computationS3Service, executionService, new MockComputationObserver(ObservationRegistry.create(), new SimpleMeterRegistry()), objectMapper ); - computationService = new MockComputationService(notificationService, resultService, objectMapper, uuidGeneratorService, provider); + computationService = new MockComputationService(notificationService, resultService, computationS3Service, objectMapper, uuidGeneratorService, provider); MessageBuilder builder = MessageBuilder .withPayload("") @@ -236,6 +276,11 @@ void init() { resultContext = new MockComputationResultContext(RESULT_UUID, runContext); } + @AfterEach + void tearDown() throws IOException { + fileSystem.close(); + } + private void initComputationExecution() { when(networkStoreService.getNetwork(eq(networkUuid), any(PreloadingStrategy.class))) .thenReturn(network); @@ -262,8 +307,7 @@ void testComputationFailed() { runContext.setComputationResWanted(ComputationResultWanted.FAIL); // execution / cleaning - Consumer> consumer = workerService.consumeRun(); - assertThrows(ComputationException.class, () -> consumer.accept(message)); + assertThrows(ComputationException.class, () -> workerService.consumeRun().accept(message)); assertNull(resultService.findStatus(RESULT_UUID)); } @@ -332,4 +376,157 @@ void testComputationCancelledBeforeRunReturnsNoResult() { workerService.consumeRun().accept(message); verify(notificationService.getPublisher(), times(0)).send(eq("publishResult-out-0"), isA(Message.class)); } + + @Test + void testProcessDebugWithS3Service() throws IOException { + // Setup + initComputationExecution(); + when(executionService.getComputationManager()).thenReturn(new LocalComputationManager(new LocalComputationConfig(tmpDir, 1), ForkJoinPool.commonPool())); + runContext.setComputationResWanted(ComputationResultWanted.SUCCESS); + runContext.setDebug(true); + + // Mock ZipUtils + try (var mockedStatic = mockStatic(ZipUtils.class)) { + mockedStatic.when(() -> ZipUtils.zip(any(Path.class), any(Path.class))).thenAnswer(invocation -> null); + workerService.consumeRun().accept(message); + + // Verify interactions + verify(resultService).saveDebugFileLocation(eq(RESULT_UUID), anyString()); + verify(computationS3Service).uploadFile(any(Path.class), anyString(), anyString(), eq(30)); + verify(notificationService.getPublisher(), times(1 /* for result message */)) + .send(eq("publishResult-out-0"), isA(Message.class)); + verify(notificationService.getPublisher(), times(1 /* for debug message */)) + .send(eq("publishDebug-out-0"), isA(Message.class)); + } + } + + @Test + void testConsumeRunWithoutDebug() { + // Setup + initComputationExecution(); + runContext.setComputationResWanted(ComputationResultWanted.SUCCESS); + runContext.setDebug(null); + + // Execute + workerService.consumeRun().accept(message); + + // Verify interactions + verifyNoInteractions(computationS3Service, resultService); + verify(notificationService.getPublisher(), times(1 /* only result */)) + .send(eq("publishResult-out-0"), isA(Message.class)); + verify(notificationService.getPublisher(), times(0 /* no debug */)) + .send(eq("publishDebug-out-0"), isA(Message.class)); + } + + @Test + void testProcessDebugWithoutS3Service() { + // Setup worker service without ComputationS3Service + workerService = new MockComputationWorkerService( + networkStoreService, + notificationService, + reportService, + resultService, + null, + executionService, + new MockComputationObserver(ObservationRegistry.create(), new SimpleMeterRegistry()), + objectMapper + ); + initComputationExecution(); + runContext.setComputationResWanted(ComputationResultWanted.SUCCESS); + runContext.setDebug(true); + + // Execute + workerService.consumeRun().accept(message); + + // Verify + verify(notificationService.getPublisher()).send(eq("publishDebug-out-0"), argThat((Message msg) -> + msg.getHeaders().get(HEADER_ERROR_MESSAGE).equals(S3_SERVICE_NOT_AVAILABLE_MESSAGE))); + verifyNoInteractions(computationS3Service, resultService); + } + + @Test + void testProcessDebugWithIOException() throws IOException { + // Setup + initComputationExecution(); + when(executionService.getComputationManager()).thenReturn(new LocalComputationManager(new LocalComputationConfig(tmpDir, 1), ForkJoinPool.commonPool())); + runContext.setComputationResWanted(ComputationResultWanted.SUCCESS); + runContext.setDebug(true); + + // Mock ZipUtils to throw IOException + try (var mockedStatic = mockStatic(ZipUtils.class)) { + mockedStatic.when(() -> ZipUtils.zip(any(Path.class), any(Path.class))) + .thenThrow(new UncheckedIOException("Zip error", new IOException())); + workerService.consumeRun().accept(message); + + // Verify interactions + verify(computationS3Service, never()).uploadFile(any(), any(), any(), anyInt()); + verify(resultService, never()).saveDebugFileLocation(any(), any()); + verify(notificationService.getPublisher()).send(eq("publishDebug-out-0"), argThat((Message msg) -> + msg.getHeaders().get(HEADER_ERROR_MESSAGE).equals("Zip error"))); + } + } + + @Test + void testDownloadDebugFileSuccess() throws IOException { + // Setup + String fileName = S3_DEBUG_FILE_ZIP; + long fileLength = 1024L; + ByteArrayInputStream inputStream = new ByteArrayInputStream(new byte[1024]); + S3InputStreamInfos s3InputStreamInfos = S3InputStreamInfos.builder() + .inputStream(inputStream) + .fileName(fileName) + .fileLength(fileLength) + .build(); + when(resultService.findDebugFileLocation(RESULT_UUID)).thenReturn(S3_KEY); + when(computationS3Service.downloadFile(S3_KEY)).thenReturn(s3InputStreamInfos); + + // Execute + ResponseEntity response = computationService.downloadDebugFile(RESULT_UUID); + + // Assert + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + assertThat(response.getBody()).isInstanceOf(InputStreamResource.class); + assertThat(response.getHeaders().getContentType()).isEqualTo(MediaType.APPLICATION_OCTET_STREAM); + assertThat(response.getHeaders().getContentLength()).isEqualTo(fileLength); + assertThat(response.getHeaders().get(HttpHeaders.CONTENT_DISPOSITION)).contains("attachment; filename=\"" + fileName + "\""); + verify(computationS3Service).downloadFile(S3_KEY); + } + + @Test + void testDownloadDebugFileS3NotAvailable() throws IOException { + // Setup + computationService = new MockComputationService(notificationService, resultService, null, objectMapper, uuidGeneratorService, "defaultProvider"); + + // Execute & Check + assertThrows(PowsyblException.class, () -> computationService.downloadDebugFile(RESULT_UUID), "S3 service not available"); + verify(computationS3Service, never()).downloadFile(any()); + } + + @Test + void testDownloadDebugFileNotFound() throws IOException { + // Setup + when(resultService.findDebugFileLocation(RESULT_UUID)).thenReturn(null); + + // Execute + ResponseEntity response = computationService.downloadDebugFile(RESULT_UUID); + + // Check + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.NOT_FOUND); + verify(computationS3Service, never()).downloadFile(any()); + } + + @Test + void testDownloadDebugFileIOException() throws IOException { + // Setup + when(resultService.findDebugFileLocation(RESULT_UUID)).thenReturn(S3_KEY); + when(computationS3Service.downloadFile(S3_KEY)).thenThrow(new IOException("S3 error")); + + // Act + ResponseEntity response = computationService.downloadDebugFile(RESULT_UUID); + + // Assert + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.NOT_FOUND); + verify(computationS3Service).downloadFile(S3_KEY); + } + } diff --git a/src/test/java/org/gridsuite/computation/s3/ComputationS3ServiceTest.java b/src/test/java/org/gridsuite/computation/s3/ComputationS3ServiceTest.java new file mode 100644 index 0000000..a88fe56 --- /dev/null +++ b/src/test/java/org/gridsuite/computation/s3/ComputationS3ServiceTest.java @@ -0,0 +1,120 @@ +/** + * Copyright (c) 2025, RTE (http://www.rte-france.com) + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + +package org.gridsuite.computation.s3; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.S3Exception; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.*; + +/** + * @author Thang PHAM + */ +class ComputationS3ServiceTest { + + public static final String PATH_IN_S3 = "path/in/s3"; + public static final String UPLOAD_FAILED_MESSAGE = "Upload failed"; + public static final String DOWNLOAD_FAILED_MESSAGE = "Download failed"; + + private S3Client s3Client; + private ComputationS3Service computationS3Service; + + @BeforeEach + void setup() { + s3Client = mock(S3Client.class); + computationS3Service = new ComputationS3Service(s3Client, "ws-bucket"); + } + + @Test + void uploadFileShouldSendSuccessful() throws IOException { + // setup + Path tempFile = Files.createTempFile("test", ".txt"); + Files.writeString(tempFile, "Normal case"); + + // perform test + computationS3Service.uploadFile(tempFile, PATH_IN_S3, "test.txt", 30); + + // check result + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(PutObjectRequest.class); + verify(s3Client).putObject(requestCaptor.capture(), any(RequestBody.class)); + PutObjectRequest actualRequest = requestCaptor.getValue(); + + assertThat(actualRequest.bucket()).isEqualTo("ws-bucket"); + assertThat(actualRequest.key()).isEqualTo(PATH_IN_S3); + assertThat(actualRequest.metadata()).containsEntry(ComputationS3Service.METADATA_FILE_NAME, "test.txt"); + assertThat(actualRequest.tagging()).isEqualTo("expire-after-minutes=30"); + } + + @Test + void uploadFileShouldThrowException() throws IOException { + // setup + Path tempFile = Files.createTempFile("test", ".txt"); + Files.writeString(tempFile, "Error case"); + + // mock exception + when(s3Client.putObject(any(PutObjectRequest.class), any(RequestBody.class))) + .thenThrow(S3Exception.builder().message(UPLOAD_FAILED_MESSAGE).build()); + + // perform test and check + assertThatThrownBy(() -> computationS3Service.uploadFile(tempFile, "key", "name.txt", null)) + .isInstanceOf(IOException.class) + .hasMessageContaining(UPLOAD_FAILED_MESSAGE); + } + + @Test + void downloadFileShouldReturnInfos() throws IOException { + // setup + GetObjectResponse response = GetObjectResponse.builder() + .metadata(Map.of(ComputationS3Service.METADATA_FILE_NAME, "download.txt")) + .contentLength(4086L) + .build(); + + ResponseInputStream mockedStream = + new ResponseInputStream<>(response, new ByteArrayInputStream("data".getBytes())); + + // mock return + when(s3Client.getObject(any(GetObjectRequest.class))) + .thenReturn(mockedStream); + + // perform test + S3InputStreamInfos result = computationS3Service.downloadFile(PATH_IN_S3); + + // check result + assertThat(result.getFileName()).isEqualTo("download.txt"); + assertThat(result.getFileLength()).isEqualTo(4086L); + assertThat(result.getInputStream()).isNotNull(); + } + + @Test + void downloadFileShouldThrowException() { + // setup + when(s3Client.getObject(any(GetObjectRequest.class))) + .thenThrow(S3Exception.builder().message(DOWNLOAD_FAILED_MESSAGE).build()); + + // perform test and check + assertThatThrownBy(() -> computationS3Service.downloadFile("bad-key")) + .isInstanceOf(IOException.class) + .hasMessageContaining(DOWNLOAD_FAILED_MESSAGE); + } +} diff --git a/src/test/java/org/gridsuite/computation/s3/S3AutoConfigurationTest.java b/src/test/java/org/gridsuite/computation/s3/S3AutoConfigurationTest.java new file mode 100644 index 0000000..cc847ab --- /dev/null +++ b/src/test/java/org/gridsuite/computation/s3/S3AutoConfigurationTest.java @@ -0,0 +1,51 @@ +/** + * Copyright (c) 2025, RTE (http://www.rte-france.com) + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + +package org.gridsuite.computation.s3; + +import org.junit.jupiter.api.Test; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import software.amazon.awssdk.services.s3.S3Client; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; + +/** + * @author Thang PHAM + */ +class S3AutoConfigurationTest { + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(S3AutoConfiguration.class)) + .withBean(S3Client.class, () -> mock(S3Client.class)); + + @Test + void s3ServiceBeanShouldBeCreatedWhenS3Enabled() { + contextRunner + .withPropertyValues( + "computation.s3.enabled=true", + "spring.cloud.aws.bucket=test-bucket" + ) + .run(context -> { + assertThat(context).hasSingleBean(ComputationS3Service.class); + ComputationS3Service service = context.getBean(ComputationS3Service.class); + assertThat(service).isNotNull(); + }); + } + + @Test + void s3ServiceBeanShouldNotBeCreatedWhenS3EnabledMissingOrFalse() { + contextRunner + .run(context -> assertThat(context).doesNotHaveBean(ComputationS3Service.class)); + + contextRunner + .withPropertyValues("computation.s3.enabled=false") + .run(context -> assertThat(context).doesNotHaveBean(ComputationS3Service.class)); + } +} +