Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions core/src/main/java/fr/sncf/osrd/api/InfraManager.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import fr.sncf.osrd.parseRJSInfra
import fr.sncf.osrd.railjson.schema.infra.RJSInfra
import fr.sncf.osrd.reporting.exceptions.ErrorType
import fr.sncf.osrd.reporting.exceptions.OSRDError
import fr.sncf.osrd.utils.compressToZip
import fr.sncf.osrd.utils.jacoco.ExcludeFromGeneratedCodeCoverage
import io.opentelemetry.api.trace.SpanKind
import io.opentelemetry.instrumentation.annotations.WithSpan
Expand All @@ -14,8 +15,12 @@ import okhttp3.OkHttpClient
import org.slf4j.Logger
import org.slf4j.LoggerFactory

class InfraManager(baseUrl: String, authorizationToken: String?, httpClient: OkHttpClient) :
APIClient(baseUrl, authorizationToken, httpClient), InfraProvider {
class InfraManager(
baseUrl: String,
authorizationToken: String?,
httpClient: OkHttpClient,
val s3Context: S3Context? = null,
) : APIClient(baseUrl, authorizationToken, httpClient), InfraProvider {
private val infraCache = ConcurrentHashMap<String, InfraCacheEntry>()
private val signalingSimulator = makeSignalingSimulator()

Expand Down Expand Up @@ -118,6 +123,16 @@ class InfraManager(baseUrl: String, authorizationToken: String?, httpClient: OkH
cacheEntry.version = version
checkNotNull(response.body) { "missing body in railjson response" }
rjsInfra = RJSInfra.adapter.fromJson(response.body.source())!!

// Save railjson to s3 if available, for better reproducibility.
// This is done on a different thread while the infra is parsed (on a single
// thread), so it should not add any extra time.
s3Context?.writeFileIfMissing("stdcm/infras/$infraId-$version.railjson.zip") {
RJSInfra.adapter
.toJson(rjsInfra)
.encodeToByteArray()
.compressToZip("$infraId-$version.railjson")
}
}

// Parse railjson into a proper infra
Expand Down
42 changes: 30 additions & 12 deletions core/src/main/java/fr/sncf/osrd/api/S3Context.kt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ val s3Logger = KotlinLogging.logger {}
*
* Note: as this S3 is only used to generate data that helps with viewing and debugging, errors are
* never critical. All operations are wrapped into try/catch blocks with error logging.
*
* Some functions take either [ByteArray] or [String] as input; the versions that aren't
* currently used have been skipped, but may be added later on.
*/
data class S3Context(
val s3Client: S3Client,
Expand All @@ -32,34 +35,49 @@ data class S3Context(
val asyncDispatcher: CoroutineDispatcher = Dispatchers.IO,
) {

/** Write a new file for a given stdcm request. */
/** Write a new file. */
@WithSpan(value = "Writing S3 file", kind = SpanKind.SERVER)
fun writeSTDCMFile(fileName: String, content: String) {
private fun writeFile(fileName: String, requestBody: RequestBody) {
runAsync {
try {
val traceId = Span.current().spanContext.traceId
s3Logger.info { "Request $traceId: writing $fileName" }
s3Logger.info { "Writing $fileName" }
val putObjectRequest =
PutObjectRequest.builder()
.bucket(bucketName)
.key("stdcm/requests/$traceId/$fileName")
.build()

s3Client.putObject(putObjectRequest, RequestBody.fromString(content))
PutObjectRequest.builder().bucket(bucketName).key(fileName).build()
s3Client.putObject(putObjectRequest, requestBody)
} catch (e: Exception) {
s3Logger.error { e }
}
}
}

/**
 * Write a new file for a given stdcm request.
 *
 * The object key is namespaced under the current OpenTelemetry trace id:
 * `stdcm/requests/<traceId>/<fileName>`, so all files produced for one request are grouped
 * together. Presumably a valid span is always active when this is called — if no span is
 * current, the trace id segment would be the all-zero invalid id (TODO confirm acceptable).
 */
private fun writeSTDCMFile(fileName: String, requestBody: RequestBody) {
    val traceId = Span.current().spanContext.traceId
    val path = "stdcm/requests/$traceId/$fileName"
    writeFile(path, requestBody)
}

/**
* Write a new file for a given stdcm request, with a dedicated function to generate the
* content. Used for safe call syntax (?.) that doesn't generate the data if the S3Context is
* null. The generating method is also delegated to a distinct thread, this entire method call
* is non-blocking.
*/
fun writeSTDCMFile(fileName: String, generateContent: () -> String) {
runAsync { writeSTDCMFile(fileName, generateContent()) }
fun writeSTDCMFile(fileName: String, generateContent: () -> String?) {
runAsync { generateContent()?.let { writeSTDCMFile(fileName, RequestBody.fromString(it)) } }
}

/**
 * Write a new file with a dedicated function to generate the content. Check first that the file
 * isn't already on the s3, stop otherwise. The generating method may return null (e.g. when
 * some error happened), in which case nothing is uploaded.
 *
 * The whole check + generate + upload sequence runs inside [runAsync], so this call itself is
 * non-blocking and the (potentially expensive) content generation is skipped entirely when the
 * object already exists.
 *
 * NOTE(review): the existence check and the upload are not atomic — two concurrent callers may
 * both generate and upload the same key (last write wins). Confirm this is acceptable for the
 * debugging/reproducibility data stored here.
 */
fun writeFileIfMissing(fileName: String, generateContent: () -> ByteArray?) {
    runAsync {
        // Only generate and upload when the object is missing.
        if (!fileExists(fileName)) {
            // A null result signals a generation failure; nothing is uploaded in that case.
            generateContent()?.let { writeFile(fileName, RequestBody.fromBytes(it)) }
        }
    }
}

/** Returns true if the file exists. */
Expand Down
31 changes: 7 additions & 24 deletions core/src/main/java/fr/sncf/osrd/api/TimetableCacheManager.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@ import com.google.common.collect.Range
import com.google.common.collect.RangeSet
import com.google.common.collect.TreeRangeSet
import fr.sncf.osrd.sim_infra.api.ZoneId
import fr.sncf.osrd.utils.compress
import fr.sncf.osrd.utils.compressToZip
import fr.sncf.osrd.utils.decompress
import io.lettuce.core.api.StatefulRedisConnection
import io.opentelemetry.api.trace.Span
import io.opentelemetry.api.trace.SpanKind
import io.opentelemetry.instrumentation.annotations.WithSpan
import java.io.ByteArrayOutputStream
import java.nio.file.Files
import java.util.concurrent.ConcurrentHashMap
import java.util.zip.GZIPInputStream
import java.util.zip.GZIPOutputStream
import kotlin.io.path.Path
import kotlin.io.path.exists
import kotlin.io.path.readBytes
Expand All @@ -27,8 +27,6 @@ import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.Serializable
import kotlinx.serialization.cbor.Cbor
import org.slf4j.LoggerFactory
import software.amazon.awssdk.core.sync.RequestBody
import software.amazon.awssdk.services.s3.model.PutObjectRequest

typealias TimetableId = Int

Expand Down Expand Up @@ -292,33 +290,18 @@ class TimetableCacheManager(
private fun saveToS3(timetableId: TimetableId, requirements: STDCMTimetableData) {
if (s3Context == null) return

s3Context.runAsync {
val objectPath = "stdcm/saved_timetables/$timetableId.cbor"
if (s3Context.fileExists(objectPath)) return@runAsync

val objectPath = "stdcm/saved_timetables/$timetableId.cbor.zip"
s3Context.writeFileIfMissing(objectPath) {
try {
val putObjectRequest =
PutObjectRequest.builder().bucket(s3Context.bucketName).key(objectPath).build()

val serializable = requirements.toSerializable()
val cbor = Cbor {}
val serializer = STDCMTimetableData.SerializableMap.serializer()
val bytes = cbor.encodeToByteArray(serializer, serializable)

s3Context.s3Client.putObject(putObjectRequest, RequestBody.fromBytes(bytes))
bytes.compressToZip("$timetableId.cbor")
} catch (e: Exception) {
logger.error("failed to save timetable to s3", e)
null
}
}
}
}

fun ByteArray.compress(): ByteArray {
val outputStream = ByteArrayOutputStream(this.size)
GZIPOutputStream(outputStream).use { it.write(this) }
return outputStream.toByteArray()
}

fun ByteArray.decompress(): ByteArray {
return GZIPInputStream(this.inputStream()).use { it.readBytes() }
}
2 changes: 1 addition & 1 deletion core/src/main/java/fr/sncf/osrd/cli/WorkerCommand.kt
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class WorkerCommand : CliCommand {

val infraId = WORKER_KEY.split("-").first()
val timetableId = WORKER_KEY.split("-").getOrNull(1)?.toInt()
val infraManager = InfraManager(editoastUrl!!, editoastAuthorization, httpClient)
val infraManager = InfraManager(editoastUrl!!, editoastAuthorization, httpClient, s3Context)
val timetableCache =
TimetableCacheManager(
TimetableDownloader(
Expand Down
30 changes: 30 additions & 0 deletions core/src/main/java/fr/sncf/osrd/utils/CompressionUtils.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package fr.sncf.osrd.utils

import java.io.ByteArrayOutputStream
import java.util.zip.GZIPInputStream
import java.util.zip.GZIPOutputStream
import java.util.zip.ZipEntry
import java.util.zip.ZipOutputStream

/** GZIP compression */
/** GZIP compression: returns this array's contents compressed as a gzip stream. */
fun ByteArray.compress(): ByteArray =
    ByteArrayOutputStream(size)
        .also { sink -> GZIPOutputStream(sink).use { gzip -> gzip.write(this) } }
        .toByteArray()

/** GZIP decompression */
/** GZIP decompression: inflates a gzip-compressed byte array back to the original bytes. */
fun ByteArray.decompress(): ByteArray =
    GZIPInputStream(inputStream()).use { gunzip -> gunzip.readBytes() }

/** ZIP compression */
/**
 * ZIP compression: wraps this array as a zip archive containing a single entry named
 * [innerFilename].
 */
fun ByteArray.compressToZip(innerFilename: String): ByteArray {
    val sink = ByteArrayOutputStream(size)
    ZipOutputStream(sink).use { zip ->
        zip.putNextEntry(ZipEntry(innerFilename))
        zip.write(this)
        zip.closeEntry()
    }
    return sink.toByteArray()
}
2 changes: 1 addition & 1 deletion core/src/test/java/fr/sncf/osrd/api/ApiTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ private static String parseMockRequest(Request request, String regex) throws IOE
/** Setup infra handler mock */
@BeforeEach
public void setUp() throws IOException {
infraManager = new InfraManager("http://test.com/", null, mockHttpClient(".*/infra/(.*)/railjson.*"));
infraManager = new InfraManager("http://test.com/", null, mockHttpClient(".*/infra/(.*)/railjson.*"), null);
electricalProfileSetManager = new ElectricalProfileSetManager(
"http://test.com/", null, mockHttpClient(".*/electrical_profile_set/(.*)/"));
}
Expand Down
Loading