diff --git a/core/src/main/java/fr/sncf/osrd/api/InfraManager.kt b/core/src/main/java/fr/sncf/osrd/api/InfraManager.kt index cccfd95a2ad..8e40464eae1 100644 --- a/core/src/main/java/fr/sncf/osrd/api/InfraManager.kt +++ b/core/src/main/java/fr/sncf/osrd/api/InfraManager.kt @@ -4,6 +4,7 @@ import fr.sncf.osrd.parseRJSInfra import fr.sncf.osrd.railjson.schema.infra.RJSInfra import fr.sncf.osrd.reporting.exceptions.ErrorType import fr.sncf.osrd.reporting.exceptions.OSRDError +import fr.sncf.osrd.utils.compressToZip import fr.sncf.osrd.utils.jacoco.ExcludeFromGeneratedCodeCoverage import io.opentelemetry.api.trace.SpanKind import io.opentelemetry.instrumentation.annotations.WithSpan @@ -14,8 +15,12 @@ import okhttp3.OkHttpClient import org.slf4j.Logger import org.slf4j.LoggerFactory -class InfraManager(baseUrl: String, authorizationToken: String?, httpClient: OkHttpClient) : - APIClient(baseUrl, authorizationToken, httpClient), InfraProvider { +class InfraManager( + baseUrl: String, + authorizationToken: String?, + httpClient: OkHttpClient, + val s3Context: S3Context? = null, +) : APIClient(baseUrl, authorizationToken, httpClient), InfraProvider { private val infraCache = ConcurrentHashMap() private val signalingSimulator = makeSignalingSimulator() @@ -118,6 +123,16 @@ class InfraManager(baseUrl: String, authorizationToken: String?, httpClient: OkH cacheEntry.version = version checkNotNull(response.body) { "missing body in railjson response" } rjsInfra = RJSInfra.adapter.fromJson(response.body.source())!! + + // Save railjson to s3 if available, for better reproducibility. + // This is done on a different thread while the infra is parsed (on a single + // thread), it should not take any extra time. + s3Context?.writeFileIfMissing("stdcm/infras/$infraId-$version.railjson.zip") { + RJSInfra.adapter + .toJson(rjsInfra) + .encodeToByteArray() + .compressToZip("$infraId-$version.railjson") + } } // Parse railjson into a proper infra diff --git a/core/src/main/java/fr/sncf/osrd/api/S3Context.kt b/core/src/main/java/fr/sncf/osrd/api/S3Context.kt index a3cb63129a3..81f10706bb5 100644 --- a/core/src/main/java/fr/sncf/osrd/api/S3Context.kt +++ b/core/src/main/java/fr/sncf/osrd/api/S3Context.kt @@ -24,6 +24,9 @@ val s3Logger = KotlinLogging.logger {} * * Note: as this S3 is only used to generate data that helps with viewing and debugging, errors are * never critical. All operations are wrapped into try/catch blocks with error logging. + * + * Some functions take either [ByteArray] or [String] as input, the versions that aren't used have + * been skipped but may be added later on. */ data class S3Context( val s3Client: S3Client, @@ -32,34 +35,49 @@ data class S3Context( val asyncDispatcher: CoroutineDispatcher = Dispatchers.IO, ) { - /** Write a new file for a given stdcm request. */ + /** Write a new file. */ @WithSpan(value = "Writing S3 file", kind = SpanKind.SERVER) - fun writeSTDCMFile(fileName: String, content: String) { + private fun writeFile(fileName: String, requestBody: RequestBody) { runAsync { try { - val traceId = Span.current().spanContext.traceId - s3Logger.info { "Request $traceId: writing $fileName" } + s3Logger.info { "Writing $fileName" } val putObjectRequest = - PutObjectRequest.builder() - .bucket(bucketName) - .key("stdcm/requests/$traceId/$fileName") - .build() - - s3Client.putObject(putObjectRequest, RequestBody.fromString(content)) + PutObjectRequest.builder().bucket(bucketName).key(fileName).build() + s3Client.putObject(putObjectRequest, requestBody) } catch (e: Exception) { s3Logger.error { e } } } } + /** Write a new file for a given stdcm request. */ + private fun writeSTDCMFile(fileName: String, requestBody: RequestBody) { + val traceId = Span.current().spanContext.traceId + val path = "stdcm/requests/$traceId/$fileName" + writeFile(path, requestBody) + } + /** * Write a new file for a given stdcm request, with a dedicated function to generate the * content. Used for safe call syntax (?.) that doesn't generate the data if the S3Context is * null. The generating method is also delegated to a distinct thread, this entire method call * is non-blocking. */ - fun writeSTDCMFile(fileName: String, generateContent: () -> String) { - runAsync { writeSTDCMFile(fileName, generateContent()) } + fun writeSTDCMFile(fileName: String, generateContent: () -> String?) { + runAsync { generateContent()?.let { writeSTDCMFile(fileName, RequestBody.fromString(it)) } } + } + + /** + * Write a new file with a dedicated function to generate the content. Check first that the file + * isn't already on the s3, stop otherwise. The generating method may return null (e.g. when + * some error happened), in which case nothing is uploaded. + */ + fun writeFileIfMissing(fileName: String, generateContent: () -> ByteArray?) { + runAsync { + if (!fileExists(fileName)) { + generateContent()?.let { writeFile(fileName, RequestBody.fromBytes(it)) } + } + } } /** Returns true if the file exists. */ diff --git a/core/src/main/java/fr/sncf/osrd/api/TimetableCacheManager.kt b/core/src/main/java/fr/sncf/osrd/api/TimetableCacheManager.kt index bdd4ba80621..ffe16c01140 100644 --- a/core/src/main/java/fr/sncf/osrd/api/TimetableCacheManager.kt +++ b/core/src/main/java/fr/sncf/osrd/api/TimetableCacheManager.kt @@ -5,15 +5,15 @@ import com.google.common.collect.Range import com.google.common.collect.RangeSet import com.google.common.collect.TreeRangeSet import fr.sncf.osrd.sim_infra.api.ZoneId +import fr.sncf.osrd.utils.compress +import fr.sncf.osrd.utils.compressToZip +import fr.sncf.osrd.utils.decompress import io.lettuce.core.api.StatefulRedisConnection import io.opentelemetry.api.trace.Span import io.opentelemetry.api.trace.SpanKind import io.opentelemetry.instrumentation.annotations.WithSpan -import java.io.ByteArrayOutputStream import java.nio.file.Files import java.util.concurrent.ConcurrentHashMap -import java.util.zip.GZIPInputStream -import java.util.zip.GZIPOutputStream import kotlin.io.path.Path import kotlin.io.path.exists import kotlin.io.path.readBytes @@ -27,8 +27,6 @@ import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.Serializable import kotlinx.serialization.cbor.Cbor import org.slf4j.LoggerFactory -import software.amazon.awssdk.core.sync.RequestBody -import software.amazon.awssdk.services.s3.model.PutObjectRequest typealias TimetableId = Int @@ -292,33 +290,18 @@ class TimetableCacheManager( private fun saveToS3(timetableId: TimetableId, requirements: STDCMTimetableData) { if (s3Context == null) return - s3Context.runAsync { - val objectPath = "stdcm/saved_timetables/$timetableId.cbor" - if (s3Context.fileExists(objectPath)) return@runAsync - + val objectPath = "stdcm/saved_timetables/$timetableId.cbor.zip" + s3Context.writeFileIfMissing(objectPath) { try { - val putObjectRequest = - PutObjectRequest.builder().bucket(s3Context.bucketName).key(objectPath).build() - val serializable = requirements.toSerializable() val cbor = Cbor {} val serializer = STDCMTimetableData.SerializableMap.serializer() val bytes = cbor.encodeToByteArray(serializer, serializable) - - s3Context.s3Client.putObject(putObjectRequest, RequestBody.fromBytes(bytes)) + bytes.compressToZip("$timetableId.cbor") } catch (e: Exception) { logger.error("failed to save timetable to s3", e) + null } } } } - -fun ByteArray.compress(): ByteArray { - val outputStream = ByteArrayOutputStream(this.size) - GZIPOutputStream(outputStream).use { it.write(this) } - return outputStream.toByteArray() -} - -fun ByteArray.decompress(): ByteArray { - return GZIPInputStream(this.inputStream()).use { it.readBytes() } -} diff --git a/core/src/main/java/fr/sncf/osrd/cli/WorkerCommand.kt b/core/src/main/java/fr/sncf/osrd/cli/WorkerCommand.kt index fc7a4275b7c..fc82aa9b9eb 100644 --- a/core/src/main/java/fr/sncf/osrd/cli/WorkerCommand.kt +++ b/core/src/main/java/fr/sncf/osrd/cli/WorkerCommand.kt @@ -122,7 +122,7 @@ class WorkerCommand : CliCommand { val infraId = WORKER_KEY.split("-").first() val timetableId = WORKER_KEY.split("-").getOrNull(1)?.toInt() - val infraManager = InfraManager(editoastUrl!!, editoastAuthorization, httpClient) + val infraManager = InfraManager(editoastUrl!!, editoastAuthorization, httpClient, s3Context) val timetableCache = TimetableCacheManager( TimetableDownloader( diff --git a/core/src/main/java/fr/sncf/osrd/utils/CompressionUtils.kt b/core/src/main/java/fr/sncf/osrd/utils/CompressionUtils.kt new file mode 100644 index 00000000000..cfaee116045 --- /dev/null +++ b/core/src/main/java/fr/sncf/osrd/utils/CompressionUtils.kt @@ -0,0 +1,30 @@ +package fr.sncf.osrd.utils + +import java.io.ByteArrayOutputStream +import java.util.zip.GZIPInputStream +import java.util.zip.GZIPOutputStream +import java.util.zip.ZipEntry +import java.util.zip.ZipOutputStream + +/** GZIP compression */ +fun ByteArray.compress(): ByteArray { + val outputStream = ByteArrayOutputStream(this.size) + GZIPOutputStream(outputStream).use { it.write(this) } + return outputStream.toByteArray() +} + +/** GZIP decompression */ +fun ByteArray.decompress(): ByteArray { + return GZIPInputStream(this.inputStream()).use { it.readBytes() } +} + +/** ZIP compression */ +fun ByteArray.compressToZip(innerFilename: String): ByteArray { + val outputStream = ByteArrayOutputStream(this.size) + ZipOutputStream(outputStream).use { + val entry = ZipEntry(innerFilename) + it.putNextEntry(entry) + it.write(this) + } + return outputStream.toByteArray() +} diff --git a/core/src/test/java/fr/sncf/osrd/api/ApiTest.java b/core/src/test/java/fr/sncf/osrd/api/ApiTest.java index c44219661b3..c0b60001a48 100644 --- a/core/src/test/java/fr/sncf/osrd/api/ApiTest.java +++ b/core/src/test/java/fr/sncf/osrd/api/ApiTest.java @@ -59,7 +59,7 @@ private static String parseMockRequest(Request request, String regex) throws IOE /** Setup infra handler mock */ @BeforeEach public void setUp() throws IOException { - infraManager = new InfraManager("http://test.com/", null, mockHttpClient(".*/infra/(.*)/railjson.*")); + infraManager = new InfraManager("http://test.com/", null, mockHttpClient(".*/infra/(.*)/railjson.*"), null); electricalProfileSetManager = new ElectricalProfileSetManager( "http://test.com/", null, mockHttpClient(".*/electrical_profile_set/(.*)/")); }