Skip to content

Commit ba377da

Browse files
committed
Updated checker and added tests for strict mode
1 parent 1ed006e commit ba377da

File tree

8 files changed

+916
-25
lines changed

8 files changed

+916
-25
lines changed

compiler-plugin/compiler-plugin-k2/src/main/core/kotlinx/rpc/codegen/StrictMode.kt

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,21 @@ import org.jetbrains.kotlin.compiler.plugin.AbstractCliOption
88
import org.jetbrains.kotlin.compiler.plugin.CliOption
99
import org.jetbrains.kotlin.config.CompilerConfiguration
1010
import org.jetbrains.kotlin.config.CompilerConfigurationKey
11+
import kotlin.text.lowercase
1112

1213
enum class StrictMode {
1314
NONE, WARNING, ERROR;
15+
16+
companion object {
17+
fun fromCli(value: String): StrictMode? {
18+
return when (value.lowercase()) {
19+
"none" -> NONE
20+
"warning" -> WARNING
21+
"error" -> ERROR
22+
else -> null
23+
}
24+
}
25+
}
1426
}
1527

1628
data class StrictModeAggregator(
@@ -123,20 +135,11 @@ object StrictModeCliOptions {
123135

124136
fun AbstractCliOption.processAsStrictModeOption(value: String, configuration: CompilerConfiguration): Boolean {
125137
StrictModeCliOptions.configurationMapper[this]?.let { key ->
126-
value.toStrictMode()?.let { mode ->
138+
StrictMode.fromCli(value)?.let { mode ->
127139
configuration.put(key, mode)
128140
return true
129141
}
130142
}
131143

132144
return false
133145
}
134-
135-
private fun String.toStrictMode(): StrictMode? {
136-
return when (lowercase()) {
137-
"none" -> StrictMode.NONE
138-
"warning" -> StrictMode.WARNING
139-
"error" -> StrictMode.ERROR
140-
else -> null
141-
}
142-
}

compiler-plugin/compiler-plugin-k2/src/main/core/kotlinx/rpc/codegen/checkers/FirRpcStrictModeClassChecker.kt

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,19 @@ import org.jetbrains.kotlin.fir.analysis.checkers.MppCheckerKind
1818
import org.jetbrains.kotlin.fir.analysis.checkers.context.CheckerContext
1919
import org.jetbrains.kotlin.fir.analysis.checkers.declaration.FirClassChecker
2020
import org.jetbrains.kotlin.fir.analysis.checkers.expression.FirFunctionCallChecker
21+
import org.jetbrains.kotlin.fir.analysis.checkers.extractArgumentsTypeRefAndSource
22+
import org.jetbrains.kotlin.fir.analysis.checkers.toClassLikeSymbol
2123
import org.jetbrains.kotlin.fir.declarations.FirClass
2224
import org.jetbrains.kotlin.fir.declarations.FirProperty
2325
import org.jetbrains.kotlin.fir.declarations.FirSimpleFunction
2426
import org.jetbrains.kotlin.fir.declarations.utils.isSuspend
2527
import org.jetbrains.kotlin.fir.expressions.FirFunctionCall
2628
import org.jetbrains.kotlin.fir.extensions.predicateBasedProvider
2729
import org.jetbrains.kotlin.fir.references.toResolvedCallableSymbol
30+
import org.jetbrains.kotlin.fir.scopes.impl.toConeType
31+
import org.jetbrains.kotlin.fir.symbols.impl.FirClassLikeSymbol
2832
import org.jetbrains.kotlin.fir.symbols.impl.FirClassSymbol
33+
import org.jetbrains.kotlin.fir.types.FirTypeRef
2934
import org.jetbrains.kotlin.fir.types.coneType
3035
import org.jetbrains.kotlin.utils.memoryOptimizedMap
3136
import org.jetbrains.kotlin.utils.memoryOptimizedPlus
@@ -101,16 +106,14 @@ class FirRpcStrictModeClassChecker(private val ctx: FirCheckersContext) : FirCla
101106

102107
val types = function.valueParameters.memoryOptimizedMap { parameter ->
103108
parameter.source to vsApi {
104-
parameter.returnTypeRef.coneType.toClassSymbolVS(context.session)
109+
parameter.returnTypeRef
105110
}
106-
} memoryOptimizedPlus (function.returnTypeRef.source to returnClassSymbol)
111+
} memoryOptimizedPlus (function.returnTypeRef.source to function.returnTypeRef)
107112

108-
types.filter { (_, symbol) ->
109-
symbol != null
110-
}.forEach { (source, symbol) ->
111-
checkSerializableTypes<FirClassSymbol<*>>(
113+
types.forEach { (source, symbol) ->
114+
checkSerializableTypes<FirClassLikeSymbol<*>>(
112115
context = context,
113-
clazz = symbol!!,
116+
typeRef = symbol,
114117
serializablePropertiesProvider = serializablePropertiesProvider,
115118
) { symbol, parents ->
116119
when (symbol.classId) {
@@ -142,28 +145,54 @@ class FirRpcStrictModeClassChecker(private val ctx: FirCheckersContext) : FirCla
142145

143146
private fun <ContextElement> checkSerializableTypes(
144147
context: CheckerContext,
145-
clazz: FirClassSymbol<*>,
148+
typeRef: FirTypeRef,
146149
serializablePropertiesProvider: FirSerializablePropertiesProvider,
147150
parentContext: List<ContextElement> = emptyList(),
148-
checker: (FirClassSymbol<*>, List<ContextElement>) -> ContextElement?,
151+
checker: (FirClassLikeSymbol<*>, List<ContextElement>) -> ContextElement?,
149152
) {
150-
val newElement = checker(clazz, parentContext)
153+
val symbol = typeRef.toClassLikeSymbol(context.session) ?: return
154+
val newElement = checker(symbol, parentContext)
151155
val nextContext = if (newElement != null) {
152156
parentContext memoryOptimizedPlus newElement
153157
} else {
154158
parentContext
155159
}
156160

157-
serializablePropertiesProvider.getSerializablePropertiesForClass(clazz)
161+
if (symbol !is FirClassSymbol<*>) {
162+
return
163+
}
164+
165+
val extracted = extractArgumentsTypeRefAndSource(typeRef)
166+
.orEmpty()
167+
.withIndex()
168+
.associate { (i, refSource) ->
169+
symbol.typeParameterSymbols[i].toConeType() to refSource.typeRef
170+
}
171+
172+
val flowProps: List<FirTypeRef> = if (symbol.classId == RpcClassId.flow) {
173+
listOf<FirTypeRef>(extracted.values.toList()[0]!!)
174+
} else {
175+
emptyList()
176+
}
177+
178+
serializablePropertiesProvider.getSerializablePropertiesForClass(symbol)
158179
.serializableProperties
159180
.mapNotNull { property ->
160-
vsApi {
161-
property.propertySymbol.resolvedReturnType.toClassSymbolVS(context.session)
181+
val resolvedTypeRef = property.propertySymbol.resolvedReturnTypeRef
182+
val result = if (resolvedTypeRef.toClassLikeSymbol(context.session) != null) {
183+
resolvedTypeRef
184+
} else {
185+
extracted[property.propertySymbol.resolvedReturnType]
186+
}
187+
if (result == null) {
188+
print(1)
162189
}
163-
}.forEach { symbol ->
190+
result
191+
}.memoryOptimizedPlus(flowProps)
192+
.forEach { symbol ->
164193
checkSerializableTypes(
165194
context = context,
166-
clazz = symbol,
195+
typeRef = symbol,
167196
serializablePropertiesProvider = serializablePropertiesProvider,
168197
parentContext = nextContext,
169198
checker = checker,

tests/compiler-plugin-tests/build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ tasks.test {
109109
dependsOn(tasks.getByName("jar"))
110110
dependsOn(project(":core").tasks.getByName("jvmJar"))
111111
dependsOn(project(":utils").tasks.getByName("jvmJar"))
112+
dependsOn(project(":krpc:krpc-core").tasks.getByName("jvmJar"))
112113

113114
useJUnitPlatform()
114115

tests/compiler-plugin-tests/src/test-gen/kotlinx/rpc/codegen/test/runners/DiagnosticTestGenerated.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,10 @@ public void testCheckedAnnotation() {
3535
public void testRpcChecked() {
3636
runTest("src/testData/diagnostics/rpcChecked.kt");
3737
}
38+
39+
@Test
40+
@TestMetadata("strictMode.kt")
41+
public void testStrictMode() {
42+
runTest("src/testData/diagnostics/strictMode.kt");
43+
}
3844
}

tests/compiler-plugin-tests/src/test/kotlin/kotlinx/rpc/codegen/test/services/ExtensionRegistrarConfigurator.kt

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,46 @@
44

55
package kotlinx.rpc.codegen.test.services
66

7+
import kotlinx.rpc.codegen.StrictMode
8+
import kotlinx.rpc.codegen.StrictModeConfigurationKeys
79
import kotlinx.rpc.codegen.registerRpcExtensions
10+
import kotlinx.rpc.codegen.toStrictMode
811
import org.jetbrains.kotlin.compiler.plugin.CompilerPluginRegistrar
912
import org.jetbrains.kotlin.config.CompilerConfiguration
13+
import org.jetbrains.kotlin.test.directives.model.DirectiveApplicability
14+
import org.jetbrains.kotlin.test.directives.model.DirectivesContainer
15+
import org.jetbrains.kotlin.test.directives.model.SimpleDirectivesContainer
1016
import org.jetbrains.kotlin.test.model.TestModule
1117
import org.jetbrains.kotlin.test.services.EnvironmentConfigurator
1218
import org.jetbrains.kotlin.test.services.TestServices
1319
import org.jetbrains.kotlinx.serialization.compiler.extensions.SerializationComponentRegistrar
1420

1521
class ExtensionRegistrarConfigurator(testServices: TestServices) : EnvironmentConfigurator(testServices) {
22+
override val directiveContainers: List<DirectivesContainer> = listOf(RpcDirectives)
23+
1624
override fun CompilerPluginRegistrar.ExtensionStorage.registerCompilerExtensions(
1725
module: TestModule,
1826
configuration: CompilerConfiguration
1927
) {
28+
val strictMode = module.directives[RpcDirectives.RPC_STRICT_MODE]
29+
if (strictMode.isNotEmpty()) {
30+
val mode = StrictMode.fromCli(strictMode.single()) ?: StrictMode.WARNING
31+
configuration.put(StrictModeConfigurationKeys.STATE_FLOW, mode)
32+
configuration.put(StrictModeConfigurationKeys.SHARED_FLOW, mode)
33+
configuration.put(StrictModeConfigurationKeys.NESTED_FLOW, mode)
34+
configuration.put(StrictModeConfigurationKeys.STREAM_SCOPED_FUNCTIONS, mode)
35+
configuration.put(StrictModeConfigurationKeys.SUSPENDING_SERVER_STREAMING, mode)
36+
configuration.put(StrictModeConfigurationKeys.NOT_TOP_LEVEL_SERVER_FLOW, mode)
37+
configuration.put(StrictModeConfigurationKeys.FIELDS, mode)
38+
}
39+
2040
registerRpcExtensions(configuration)
2141

2242
// libs
2343
SerializationComponentRegistrar.registerExtensions(this)
2444
}
2545
}
46+
47+
object RpcDirectives : SimpleDirectivesContainer() {
48+
val RPC_STRICT_MODE by stringDirective("none, warning or error", DirectiveApplicability.Module)
49+
}

tests/compiler-plugin-tests/src/test/kotlin/kotlinx/rpc/codegen/test/services/RpcClasspathProviders.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ private class RuntimeDependency(
2828

2929
private object RpcClasspathProvider {
3030
private val TEST_RUNTIME = RuntimeDependency("build/libs/", "compiler-plugin-test")
31+
private val KRPC_CORE_JVM = RuntimeDependency("$globalRootDir/krpc/krpc-core/build/libs/", "krpc-core-jvm")
3132
private val CORE_JVM = RuntimeDependency("$globalRootDir/core/build/libs/", "core-jvm")
3233
private val UTILS_JVM = RuntimeDependency("$globalRootDir/utils/build/libs/", "utils-jvm")
3334

@@ -41,6 +42,7 @@ private object RpcClasspathProvider {
4142
val additionalDependencies = listOf(
4243
TEST_RUNTIME,
4344
CORE_JVM,
45+
KRPC_CORE_JVM,
4446
UTILS_JVM,
4547
).map { it.getFile(testServices) }
4648

0 commit comments

Comments
 (0)