Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ open class CsmApiConfiguration {

@Bean(name = ["csm-in-process-event-executor"])
open fun inProcessEventHandlerExecutor(): Executor =
// TODO A better strategy could be with a limited core pool size off an unbounded queue ?
Executors.newCachedThreadPool(
BasicThreadFactory.builder().namingPattern("csm-event-handler-%d").build()
)
Expand Down Expand Up @@ -92,8 +91,8 @@ class YamlMessageConverter(objectMapper: ObjectMapper) :
constructor() : this(yamlObjectMapper())

override fun setObjectMapper(objectMapper: ObjectMapper) {
if (objectMapper.factory !is YAMLFactory) {
throw IllegalArgumentException("ObjectMapper must be configured with YAMLFactory")
require(objectMapper.factory is YAMLFactory) {
"ObjectMapper must be configured with YAMLFactory"
}
super.setObjectMapper(objectMapper)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ open class CsmOpenAPIConfiguration(val csmPlatformProperties: CsmPlatformPropert
val openApiYamlContent =
openApiYamlInputStream.use { it.bufferedReader().use(BufferedReader::readText) }
val openApiYamlParseResult = OpenAPIV3Parser().readContents(openApiYamlContent)
if (!openApiYamlParseResult.messages.isNullOrEmpty()) {
throw IllegalStateException(
"Unable to parse OpenAPI definition from 'classpath:/static/openapi.yaml' : " +
openApiYamlParseResult.messages
)

check(openApiYamlParseResult.messages.isNullOrEmpty()) {
"Unable to parse OpenAPI definition from 'classpath:/static/openapi.yaml' : " +
openApiYamlParseResult.messages
}

val openAPI =
openApiYamlParseResult.openAPI
?: throw IllegalStateException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@ private const val MIN_HASH_LENGTH = 0
private const val ALPHABET = "abcdefghijklmnopqrstuvwxyz0123456789"

fun generateId(scope: String, prependPrefix: String? = null): String {
if (scope.isBlank()) {
throw IllegalArgumentException("scope must not be blank")
}

require(scope.isNotBlank()) { "scope must not be blank" }

// We do not intend to decode generated IDs afterwards => we can safely generate a unique salt.
// This will give us different ids even with equal numbers to encode
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@ class MonitorServiceAspect(
@Pointcut(
"within(@org.springframework.web.bind.annotation.RestController *) && within(com.cosmotech..*Controller)"
)
@Suppress("EmptyFunctionBlock")
fun cosmotechPointcut() {}
fun cosmotechPointcut() {
// Empty function block to define a pointcut
}

@Before("cosmotechPointcut()")
fun monitorBefore(joinPoint: JoinPoint) {
Expand Down
6 changes: 2 additions & 4 deletions common/src/main/kotlin/com/cosmotech/common/rbac/CsmRbac.kt
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,8 @@ open class CsmRbac(
// Check for duplicate identities
val accessControls = mutableListOf<String>()
objectSecurity.accessControlList.forEach {
if (accessControls.contains(it.id)) {
throw IllegalArgumentException(
"Entity ${it.id} is referenced multiple times in the security"
)
require(!(accessControls.contains(it.id))) {
"Entity ${it.id} is referenced multiple times in the security"
}
accessControls.add(it.id)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,13 @@ inline fun <reified T> T.compareToAndMutateIfNeeded(
if (changed) {
membersChanged.add(member.name)
if (mutateIfChanged) {
if (member !is KMutableProperty) {
throw IllegalArgumentException(
"Detected change but cannot mutate " +
"this object because property ${member.name} " +
"(on class ${T::class}) is not mutable. " +
"Either exclude this field or call this function with " +
"mutateIfChanged=false to view the changes detected"
)
require(member is KMutableProperty) {
"""Detected change but cannot mutate this object
because property ${member.name} (on class ${T::class}) is not mutable.
Either exclude this field or call this function with mutateIfChanged=false
to view the changes detected"""
}

member.setter.call(this, newValue)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,35 +56,34 @@ fun getCurrentAccountIdentifier(configuration: CsmPlatformProperties): String {
}

fun getCurrentAccountGroups(configuration: CsmPlatformProperties): List<String> {
return (getValueFromAuthenticatedToken(configuration) {
return getValueFromAuthenticatedToken(configuration) {
try {
val jwt = JWTParser.parse(it)
jwt.jwtClaimsSet.getStringListClaim(configuration.authorization.groupJwtClaim)
} catch (e: ParseException) {
JSONObjectUtils.parse(it)[configuration.authorization.groupJwtClaim] as List<String>
}
} ?: emptyList())
}
}

fun getCurrentAuthenticatedRoles(configuration: CsmPlatformProperties): List<String> {
return (getValueFromAuthenticatedToken(configuration) {
return getValueFromAuthenticatedToken(configuration) {
try {
val jwt = JWTParser.parse(it)
jwt.jwtClaimsSet.getStringListClaim(configuration.authorization.rolesJwtClaim)
} catch (e: ParseException) {
JSONObjectUtils.parse(it)[configuration.authorization.rolesJwtClaim] as List<String>
}
} ?: emptyList())
}
}

fun <T> getValueFromAuthenticatedToken(
configuration: CsmPlatformProperties,
actionLambda: (String) -> T,
): T {
if (getCurrentAuthentication() == null) {
throw IllegalStateException("User Authentication not found in Security Context")
}
val authentication = getCurrentAuthentication()
checkNotNull(authentication) { "User Authentication not found in Security Context" }

if (authentication is JwtAuthenticationToken) {
return authentication.token.tokenValue.let { actionLambda(it) }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import java.io.BufferedReader
import java.io.ByteArrayOutputStream
import java.io.InputStream
import java.sql.SQLException
import kotlin.use
import org.apache.commons.lang3.StringUtils
import org.postgresql.copy.CopyManager
import org.postgresql.core.BaseConnection
Expand Down Expand Up @@ -50,10 +49,8 @@ class RelationalDatasetPartManagementService(

val tableExists = writerJdbcTemplate.existTable(datasetPart.id)

if (tableExists && !overwrite) {
throw IllegalArgumentException(
"Table ${datasetPart.id} already exists and overwrite is set to false."
)
require(!tableExists || overwrite) {
"Table ${datasetPart.id} already exists and overwrite is set to false."
}

inputStream.bufferedReader().use { reader ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@ import com.cosmotech.workspace.WorkspaceApiServiceInterface
import com.cosmotech.workspace.service.toGenericSecurity
import java.io.ByteArrayOutputStream
import java.time.Instant
import kotlin.String
import kotlin.use
import org.apache.commons.lang3.StringUtils
import org.postgresql.copy.CopyManager
import org.postgresql.core.BaseConnection
Expand Down Expand Up @@ -671,16 +669,18 @@ class DatasetServiceImpl(
val query =
constructQuery(
datasetPartId,
selects,
sums,
avgs,
counts,
mins,
maxs,
offset,
limit,
groupBys,
orderBys,
QueryParams(
selects,
sums,
avgs,
counts,
mins,
maxs,
offset,
limit,
groupBys,
orderBys,
),
)

val outputStream = ByteArrayOutputStream()
Expand Down Expand Up @@ -715,28 +715,15 @@ class DatasetServiceImpl(
}
}

@Suppress("LongParameterList")
private fun constructQuery(
datasetPartId: String,
selects: List<String>?,
sums: List<String>?,
avgs: List<String>?,
counts: List<String>?,
mins: List<String>?,
maxs: List<String>?,
offset: Int?,
limit: Int?,
groupBys: List<String>?,
orderBys: List<String>?,
): String {
private fun constructQuery(datasetPartId: String, queryParams: QueryParams): String {
val tableName = "${DATASET_INPUTS_SCHEMA}.${datasetPartId.sanitizeDatasetPartId()}"

val selectClauses = constructClause(selects, null, null)
val sumClauses = constructClause(sums, AggregationType.Sum, "float8")
val avgClauses = constructClause(avgs, AggregationType.Avg, "float8")
val countClauses = constructClause(counts, AggregationType.Count, null)
val minClauses = constructClause(mins, AggregationType.Min, "float8")
val maxClauses = constructClause(maxs, AggregationType.Max, "float8")
val selectClauses = constructClause(queryParams.selects, null, null)
val sumClauses = constructClause(queryParams.sums, AggregationType.Sum, "float8")
val avgClauses = constructClause(queryParams.avgs, AggregationType.Avg, "float8")
val countClauses = constructClause(queryParams.counts, AggregationType.Count, null)
val minClauses = constructClause(queryParams.mins, AggregationType.Min, "float8")
val maxClauses = constructClause(queryParams.maxs, AggregationType.Max, "float8")

val allSelectClauses =
mutableListOf(selectClauses, sumClauses, avgClauses, countClauses, minClauses, maxClauses)
Expand All @@ -746,17 +733,17 @@ class DatasetServiceImpl(
val query =
StringBuilder("SELECT %s FROM %s ".format(allSelectClauses.ifBlank { "*" }, tableName))

val groupByClauses = constructClause(groupBys, null, null)
val groupByClauses = constructClause(queryParams.groupBys, null, null)
if (groupByClauses.isNotBlank()) {
query.append("GROUP BY $groupByClauses ")
}

val orderByClauses = constructClause(orderBys, null, null, true)
val orderByClauses = constructClause(queryParams.orderBys, null, null, true)
if (orderByClauses.isNotBlank()) {
query.append("ORDER BY $orderByClauses ")
}

addLimitOffset(limit, offset, query)
addLimitOffset(queryParams.limit, queryParams.offset, query)

logger.debug(query.toString())
return query.toString()
Expand Down Expand Up @@ -1139,6 +1126,19 @@ class DatasetServiceImpl(
}
}

class QueryParams(
val selects: List<String>?,
val sums: List<String>?,
val avgs: List<String>?,
val counts: List<String>?,
val mins: List<String>?,
val maxs: List<String>?,
val offset: Int?,
val limit: Int?,
val groupBys: List<String>?,
val orderBys: List<String>?,
)

enum class AggregationType {
Sum,
Avg,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ package com.cosmotech.metrics

import com.cosmotech.common.metrics.PersistentMetric

interface MetricsService {
fun interface MetricsService {
/**
* Store a metric in the persistent database.
*
Expand Down
70 changes: 0 additions & 70 deletions run/src/main/kotlin/com/cosmotech/run/RunContainerFactory.kt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import com.cosmotech.organization.domain.Organization
import com.cosmotech.run.container.BASIC_SIZING
import com.cosmotech.run.container.HIGH_CPU_SIZING
import com.cosmotech.run.container.HIGH_MEMORY_SIZING
import com.cosmotech.run.container.Sizing
import com.cosmotech.run.container.StartInfo
import com.cosmotech.run.container.toContainerResourceSizing
import com.cosmotech.run.container.toSizing
Expand Down Expand Up @@ -211,75 +210,6 @@ class RunContainerFactory(
)
}

internal fun buildSingleContainerStart(
containerName: String,
imageName: String,
jobId: String,
imageRegistry: String = "",
imageVersion: String = "latest",
containerEnvVars: Map<String, String>,
workflowType: String,
nodeLabel: String = NODE_LABEL_DEFAULT,
): RunStartContainers {

var defaultSizing = BASIC_SIZING

if (nodeLabel.isNotBlank()) {
defaultSizing = LABEL_SIZING[nodeLabel] ?: BASIC_SIZING
}

val container =
buildSimpleContainer(
imageRegistry,
imageName,
imageVersion,
defaultSizing,
containerName,
containerEnvVars,
nodeLabel,
)

val generateName = "${jobId}$GENERATE_NAME_SUFFIX".sanitizeForKubernetes()

return RunStartContainers(
generateName = generateName,
nodeLabel = nodeLabel.plus(NODE_LABEL_SUFFIX),
containers = listOf(container),
csmSimulationId = jobId,
labels =
mapOf(
CSM_JOB_ID_LABEL_KEY to jobId,
WORKFLOW_TYPE_LABEL to workflowType,
ORGANIZATION_ID_LABEL to "none",
WORKSPACE_ID_LABEL to "none",
RUNNER_ID_LABEL to "none",
),
)
}

internal fun buildSimpleContainer(
imageRegistry: String,
imageName: String,
imageVersion: String,
nodeSizing: Sizing,
containerName: String,
containerEnvVars: Map<String, String>,
nodeLabel: String,
): RunContainer {

val envVars = getMinimalCommonEnvVars(csmPlatformProperties).toMutableMap()
envVars.putAll(containerEnvVars)

return RunContainer(
name = containerName,
image = getImageName(imageRegistry, imageName, imageVersion),
dependencies = listOf(CSM_DAG_ROOT),
envVars = envVars,
nodeLabel = nodeLabel,
runSizing = nodeSizing.toContainerResourceSizing(),
)
}

private fun getRunTemplate(solution: Solution, runTemplateId: String): RunTemplate {
return solution.runTemplates.find { runTemplate -> runTemplate.id == runTemplateId }
?: throw IllegalStateException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ package com.cosmotech.runner
import com.cosmotech.common.events.TriggerRunnerEvent
import org.springframework.context.event.EventListener

interface RunnerEventServiceInterface {
fun interface RunnerEventServiceInterface {

@EventListener(TriggerRunnerEvent::class) fun startNewRun(triggerEvent: TriggerRunnerEvent)
}
Original file line number Diff line number Diff line change
Expand Up @@ -213,13 +213,12 @@ class RunnerService(
"Runner $runnerId not found in workspace ${workspace!!.id} and organization ${organization!!.id}"
)
}
if (runner.lastRunInfo.lastRunId != null) {
if (
runner.lastRunInfo.lastRunStatus != LastRunStatus.Failed ||
runner.lastRunInfo.lastRunStatus != LastRunStatus.Successful
) {
runner = updateRunnerStatus(runner)
}
if (
runner.lastRunInfo.lastRunId != null &&
(runner.lastRunInfo.lastRunStatus != LastRunStatus.Failed ||
runner.lastRunInfo.lastRunStatus != LastRunStatus.Successful)
) {
runner = updateRunnerStatus(runner)
}
updateSecurityVisibility(runner)
return RunnerInstance().initializeFrom(runner).userHasPermission(PERMISSION_READ)
Expand Down
Loading