Skip to content

Commit 0f7fcbf

Browse files
committed
Added diagnostic implementations for strict mode
1 parent 071df9d commit 0f7fcbf

File tree

5 files changed

+190
-4
lines changed

5 files changed

+190
-4
lines changed

compiler-plugin/compiler-plugin-common/src/main/core/kotlinx/rpc/codegen/common/Names.kt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
package kotlinx.rpc.codegen.common
66

7+
import org.jetbrains.kotlin.name.CallableId
78
import org.jetbrains.kotlin.name.ClassId
89
import org.jetbrains.kotlin.name.FqName
910
import org.jetbrains.kotlin.name.Name
@@ -21,6 +22,16 @@ object RpcClassId {
2122
val stateFlow = ClassId(FqName("kotlinx.coroutines.flow"), Name.identifier("StateFlow"))
2223
}
2324

25+
object RpcCallableId {
26+
val streamScoped = CallableId(FqName("kotlinx.rpc.krpc"), Name.identifier("streamScoped"))
27+
val withStreamScope = CallableId(FqName("kotlinx.rpc.krpc"), Name.identifier("withStreamScope"))
28+
val StreamScope = CallableId(FqName("kotlinx.rpc.krpc"), Name.identifier("StreamScope"))
29+
val invokeOnStreamScopeCompletion = CallableId(
30+
FqName("kotlinx.rpc.krpc"),
31+
Name.identifier("invokeOnStreamScopeCompletion"),
32+
)
33+
}
34+
2435
object RpcNames {
2536
val SERVICE_STUB_NAME: Name = Name.identifier("\$rpcServiceStub")
2637

compiler-plugin/compiler-plugin-k2/src/main/core/kotlinx/rpc/codegen/checkers/FirRpcCheckers.kt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@ import org.jetbrains.kotlin.fir.analysis.checkers.expression.ExpressionCheckers
1414
import org.jetbrains.kotlin.fir.analysis.checkers.expression.FirFunctionCallChecker
1515

1616
class FirRpcDeclarationCheckers(ctx: FirCheckersContext) : DeclarationCheckers() {
17-
override val regularClassCheckers: Set<FirRegularClassChecker> = setOf(
17+
override val regularClassCheckers: Set<FirRegularClassChecker> = setOfNotNull(
1818
FirRpcAnnotationChecker(ctx),
19+
if (ctx.serializationIsPresent) FirRpcStrictModeClassChecker(ctx) else null,
1920
)
2021

2122
override val classCheckers: Set<FirClassChecker> = setOf(
@@ -34,5 +35,6 @@ class FirRpcDeclarationCheckers(ctx: FirCheckersContext) : DeclarationCheckers()
3435
class FirRpcExpressionCheckers(ctx: FirCheckersContext) : ExpressionCheckers() {
3536
override val functionCallCheckers: Set<FirFunctionCallChecker> = setOf(
3637
FirCheckedAnnotationFunctionCallChecker(ctx),
38+
FirRpcStrictModeExpressionChecker(ctx),
3739
)
3840
}
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
/*
2+
* Copyright 2023-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package kotlinx.rpc.codegen.checkers
6+
7+
import kotlinx.rpc.codegen.FirCheckersContext
8+
import kotlinx.rpc.codegen.FirRpcPredicates
9+
import kotlinx.rpc.codegen.checkers.diagnostics.FirRpcStrictModeDiagnostics
10+
import kotlinx.rpc.codegen.common.RpcCallableId
11+
import kotlinx.rpc.codegen.common.RpcClassId
12+
import kotlinx.rpc.codegen.vsApi
13+
import org.jetbrains.kotlin.KtSourceElement
14+
import org.jetbrains.kotlin.diagnostics.DiagnosticReporter
15+
import org.jetbrains.kotlin.diagnostics.KtDiagnosticFactory0
16+
import org.jetbrains.kotlin.diagnostics.reportOn
17+
import org.jetbrains.kotlin.fir.analysis.checkers.MppCheckerKind
18+
import org.jetbrains.kotlin.fir.analysis.checkers.context.CheckerContext
19+
import org.jetbrains.kotlin.fir.analysis.checkers.declaration.FirClassChecker
20+
import org.jetbrains.kotlin.fir.analysis.checkers.expression.FirFunctionCallChecker
21+
import org.jetbrains.kotlin.fir.declarations.FirClass
22+
import org.jetbrains.kotlin.fir.declarations.FirProperty
23+
import org.jetbrains.kotlin.fir.declarations.FirSimpleFunction
24+
import org.jetbrains.kotlin.fir.declarations.utils.isSuspend
25+
import org.jetbrains.kotlin.fir.expressions.FirFunctionCall
26+
import org.jetbrains.kotlin.fir.extensions.predicateBasedProvider
27+
import org.jetbrains.kotlin.fir.references.toResolvedCallableSymbol
28+
import org.jetbrains.kotlin.fir.symbols.impl.FirClassSymbol
29+
import org.jetbrains.kotlin.fir.types.coneType
30+
import org.jetbrains.kotlin.utils.memoryOptimizedMap
31+
import org.jetbrains.kotlin.utils.memoryOptimizedPlus
32+
import org.jetbrains.kotlinx.serialization.compiler.fir.services.FirSerializablePropertiesProvider
33+
import org.jetbrains.kotlinx.serialization.compiler.fir.services.serializablePropertiesProvider
34+
35+
class FirRpcStrictModeExpressionChecker(
36+
private val ctx: FirCheckersContext,
37+
) : FirFunctionCallChecker(MppCheckerKind.Common) {
38+
private val streamScopeFunctions = setOf(
39+
RpcCallableId.StreamScope,
40+
RpcCallableId.streamScoped,
41+
RpcCallableId.withStreamScope,
42+
RpcCallableId.invokeOnStreamScopeCompletion,
43+
)
44+
45+
override fun check(
46+
expression: FirFunctionCall,
47+
context: CheckerContext,
48+
reporter: DiagnosticReporter,
49+
) {
50+
expression.calleeReference.toResolvedCallableSymbol()?.let { symbol ->
51+
if (symbol.callableId in streamScopeFunctions) {
52+
ctx.strictModeDiagnostics.STREAM_SCOPE_FUNCTION_IN_RPC?.let {
53+
reporter.reportOn(expression.calleeReference.source, it, context)
54+
}
55+
}
56+
}
57+
}
58+
}
59+
60+
class FirRpcStrictModeClassChecker(private val ctx: FirCheckersContext) : FirClassChecker(MppCheckerKind.Common) {
61+
override fun check(
62+
declaration: FirClass,
63+
context: CheckerContext,
64+
reporter: DiagnosticReporter,
65+
) {
66+
if (!context.session.predicateBasedProvider.matches(FirRpcPredicates.rpc, declaration)) {
67+
return
68+
}
69+
70+
val serializablePropertiesProvider = context.session.serializablePropertiesProvider
71+
declaration.declarations.forEach { declaration ->
72+
when (declaration) {
73+
is FirProperty -> {
74+
ctx.strictModeDiagnostics.FIELD_IN_RPC_SERVICE?.let {
75+
reporter.reportOn(declaration.source, it, context)
76+
}
77+
}
78+
79+
is FirSimpleFunction -> {
80+
checkFunction(declaration, context, reporter, serializablePropertiesProvider)
81+
}
82+
83+
else -> {}
84+
}
85+
}
86+
}
87+
88+
private fun checkFunction(
89+
function: FirSimpleFunction,
90+
context: CheckerContext,
91+
reporter: DiagnosticReporter,
92+
serializablePropertiesProvider: FirSerializablePropertiesProvider,
93+
) {
94+
fun reportOn(element: KtSourceElement?, checker: FirRpcStrictModeDiagnostics.() -> KtDiagnosticFactory0?) {
95+
reporter.reportOn(element, ctx.strictModeDiagnostics.checker() ?: return, context)
96+
}
97+
98+
val returnClassSymbol = vsApi {
99+
function.returnTypeRef.coneType.toClassSymbolVS(context.session)
100+
}
101+
102+
val types = function.valueParameters.memoryOptimizedMap { parameter ->
103+
parameter.source to vsApi {
104+
parameter.returnTypeRef.coneType.toClassSymbolVS(context.session)
105+
}
106+
} memoryOptimizedPlus (function.returnTypeRef.source to returnClassSymbol)
107+
108+
types.filter { (_, symbol) ->
109+
symbol != null
110+
}.forEach { (source, symbol) ->
111+
checkSerializableTypes<FirClassSymbol<*>>(
112+
context = context,
113+
clazz = symbol!!,
114+
serializablePropertiesProvider = serializablePropertiesProvider,
115+
) { symbol, parents ->
116+
when (symbol.classId) {
117+
RpcClassId.stateFlow -> {
118+
reportOn(source) { STATE_FLOW_IN_RPC_SERVICE }
119+
}
120+
121+
RpcClassId.sharedFlow -> {
122+
reportOn(source) { SHARED_FLOW_IN_RPC_SERVICE }
123+
}
124+
125+
RpcClassId.flow -> {
126+
if (parents.any { it.classId == RpcClassId.flow }) {
127+
reportOn(source) { NESTED_STREAMING_IN_RPC_SERVICE }
128+
} else if (parents.isNotEmpty() && parents[0] == returnClassSymbol) {
129+
reportOn(source) { NON_TOP_LEVEL_SERVER_STREAMING_IN_RPC_SERVICE }
130+
}
131+
}
132+
}
133+
134+
symbol
135+
}
136+
}
137+
138+
if (returnClassSymbol?.classId == RpcClassId.flow && function.isSuspend) {
139+
reportOn(function.source) { SUSPENDING_SERVER_STREAMING_IN_RPC_SERVICE }
140+
}
141+
}
142+
143+
private fun <ContextElement> checkSerializableTypes(
144+
context: CheckerContext,
145+
clazz: FirClassSymbol<*>,
146+
serializablePropertiesProvider: FirSerializablePropertiesProvider,
147+
parentContext: List<ContextElement> = emptyList(),
148+
checker: (FirClassSymbol<*>, List<ContextElement>) -> ContextElement?,
149+
) {
150+
val newElement = checker(clazz, parentContext)
151+
val nextContext = if (newElement != null) {
152+
parentContext memoryOptimizedPlus newElement
153+
} else {
154+
parentContext
155+
}
156+
157+
serializablePropertiesProvider.getSerializablePropertiesForClass(clazz)
158+
.serializableProperties
159+
.mapNotNull { property ->
160+
vsApi {
161+
property.propertySymbol.resolvedReturnType.toClassSymbolVS(context.session)
162+
}
163+
}.forEach { symbol ->
164+
checkSerializableTypes(
165+
context = context,
166+
clazz = symbol,
167+
serializablePropertiesProvider = serializablePropertiesProvider,
168+
parentContext = nextContext,
169+
checker = checker,
170+
)
171+
}
172+
}
173+
}

compiler-plugin/compiler-plugin-k2/src/main/core/kotlinx/rpc/codegen/checkers/diagnostics/FirRpcDiagnostics.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class FirRpcStrictModeDiagnostics(val modes: StrictModeAggregator) {
3737
val STATE_FLOW_IN_RPC_SERVICE by modded0<PsiElement>(modes.stateFlow)
3838
val SHARED_FLOW_IN_RPC_SERVICE by modded0<PsiElement>(modes.sharedFlow)
3939
val NESTED_STREAMING_IN_RPC_SERVICE by modded0<PsiElement>(modes.nestedFlow)
40-
val STREAM_SCOPE_ENTITY_IN_RPC by modded0<PsiElement>(modes.streamScopedFunctions)
40+
val STREAM_SCOPE_FUNCTION_IN_RPC by modded0<PsiElement>(modes.streamScopedFunctions)
4141
val SUSPENDING_SERVER_STREAMING_IN_RPC_SERVICE by modded0<PsiElement>(modes.suspendingServerStreaming)
4242
val NON_TOP_LEVEL_SERVER_STREAMING_IN_RPC_SERVICE by modded0<PsiElement>(modes.notTopLevelServerFlow)
4343
val FIELD_IN_RPC_SERVICE by modded0<PsiElement>(modes.fields)

compiler-plugin/compiler-plugin-k2/src/main/core/kotlinx/rpc/codegen/checkers/diagnostics/RpcDiagnosticRendererFactory.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,10 @@ class RpcStrictModeDiagnosticRendererFactory(
6565
)
6666
}
6767

68-
diagnostics.STREAM_SCOPE_ENTITY_IN_RPC?.let {
68+
diagnostics.STREAM_SCOPE_FUNCTION_IN_RPC?.let {
6969
put(
7070
factory = it,
71-
message = message("Stream Scope usage") { streamScopedFunctions },
71+
message = message("Stream scope usage") { streamScopedFunctions },
7272
)
7373
}
7474

0 commit comments

Comments
 (0)