Skip to content

Commit 9244646

Browse files
y9malyMr3zee
andauthored
fix wrong unchecked null cast (potential NPE) (#445)
Co-authored-by: Alexander Sysoev <[email protected]>
1 parent 6f0f3a6 commit 9244646

File tree

5 files changed

+58
-7
lines changed

5 files changed

+58
-7
lines changed

krpc/krpc-client/src/commonMain/kotlin/kotlinx/rpc/krpc/client/KrpcClient.kt

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ public abstract class KrpcClient : RpcClient, KrpcEndpoint {
172172

173173
private val serverSupportedPlugins: CompletableDeferred<Set<KrpcPlugin>> = CompletableDeferred()
174174

175-
private val requestChannels = RpcInternalConcurrentHashMap<String, Channel<Any?>>()
175+
private val requestChannels = RpcInternalConcurrentHashMap<String, Channel<Result<Any?>>>()
176176

177177
@InternalRpcApi
178178
final override val supportedPlugins: Set<KrpcPlugin>
@@ -247,11 +247,11 @@ public abstract class KrpcClient : RpcClient, KrpcEndpoint {
247247

248248
val callId = "$connectionId:${callable.name}:$id"
249249

250-
val channel = Channel<T>()
250+
val channel = Channel<Result<T>>()
251251

252252
try {
253253
@Suppress("UNCHECKED_CAST")
254-
requestChannels[callId] = channel as Channel<Any?>
254+
requestChannels[callId] = channel as Channel<Result<Any?>>
255255

256256
val request = serializeRequest(
257257
callId = callId,
@@ -308,7 +308,7 @@ public abstract class KrpcClient : RpcClient, KrpcEndpoint {
308308
}
309309
}
310310

311-
private suspend fun <T> FlowCollector<T>.consumeAndEmitServerMessages(channel: Channel<T>) {
311+
private suspend fun <T> FlowCollector<T>.consumeAndEmitServerMessages(channel: Channel<Result<T>>) {
312312
while (true) {
313313
val element = channel.receiveCatching()
314314
if (element.isClosed) {
@@ -317,14 +317,22 @@ public abstract class KrpcClient : RpcClient, KrpcEndpoint {
317317
}
318318

319319
if (!element.isFailure) {
320-
emit(element.getOrThrow())
320+
val result = element.getOrThrow()
321+
result.fold(
322+
onSuccess = { value ->
323+
emit(value)
324+
},
325+
onFailure = { throwable ->
326+
throw throwable
327+
}
328+
)
321329
}
322330
}
323331
}
324332

325333
private suspend fun <T, @Rpc R : Any> handleServerStreamingMessage(
326334
message: KrpcCallMessage,
327-
channel: Channel<T>,
335+
channel: Channel<Result<T>>,
328336
callable: RpcCallable<R>,
329337
) {
330338
when (message) {
@@ -355,7 +363,7 @@ public abstract class KrpcClient : RpcClient, KrpcEndpoint {
355363
}
356364

357365
@Suppress("UNCHECKED_CAST")
358-
channel.send(value.getOrNull() as T)
366+
channel.send(value as Result<T>)
359367
}
360368

361369
is KrpcCallMessage.StreamFinished -> {

krpc/krpc-test/src/commonMain/kotlin/kotlinx/rpc/krpc/test/KrpcTestService.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ interface KrpcTestService {
6464
suspend fun nullableReturn(returnNull: Boolean): TestClass?
6565
suspend fun variance(arg2: TestList<in TestClass>, arg3: TestList2<TestClass>): TestList<out TestClass>?
6666
suspend fun collectOnce(flow: Flow<String>)
67+
suspend fun returnTestClassThatThrowsWhileDeserialization(value: Int): TestClassThatThrowsWhileDeserialization
6768

6869
suspend fun nonSerializableClass(localDate: LocalDate): LocalDate
6970
suspend fun nonSerializableClassWithSerializer(

krpc/krpc-test/src/commonMain/kotlin/kotlinx/rpc/krpc/test/KrpcTestServiceBackend.kt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,10 @@ class KrpcTestServiceBackend : KrpcTestService {
9898
return arg1
9999
}
100100

101+
override suspend fun returnTestClassThatThrowsWhileDeserialization(value: Int): TestClassThatThrowsWhileDeserialization {
102+
return TestClassThatThrowsWhileDeserialization(value)
103+
}
104+
101105
override suspend fun nullableParam(arg1: String?): String {
102106
return arg1 ?: "null"
103107
}

krpc/krpc-test/src/commonMain/kotlin/kotlinx/rpc/krpc/test/KrpcTransportTestBase.kt

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import kotlinx.rpc.krpc.server.KrpcServer
2020
import kotlinx.rpc.registerService
2121
import kotlinx.rpc.withService
2222
import kotlinx.serialization.KSerializer
23+
import kotlinx.serialization.SerializationException
2324
import kotlinx.serialization.descriptors.PrimitiveKind
2425
import kotlinx.serialization.descriptors.PrimitiveSerialDescriptor
2526
import kotlinx.serialization.descriptors.SerialDescriptor
@@ -460,6 +461,17 @@ abstract class KrpcTransportTestBase {
460461
fun testUnitFlow() = runTest {
461462
assertEquals(Unit, client.unitFlow().toList().single())
462463
}
464+
465+
@Test
466+
fun testPR445() = runTest {
467+
assertFailsWith<SerializationException> {
468+
val result = client.returnTestClassThatThrowsWhileDeserialization(42)
469+
@Suppress("SENSELESS_COMPARISON")
470+
if (result == null) {
471+
assertNotNull(result, "result must not be null")
472+
}
473+
}
474+
}
463475
}
464476

465477
private val JS_EXTENDED_TIMEOUT = if (isJs) 300.seconds else 60.seconds

krpc/krpc-test/src/commonMain/kotlin/kotlinx/rpc/krpc/test/TestClass.kt

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,12 @@
44

55
package kotlinx.rpc.krpc.test
66

7+
import kotlinx.serialization.KSerializer
78
import kotlinx.serialization.Serializable
9+
import kotlinx.serialization.SerializationException
10+
import kotlinx.serialization.builtins.serializer
11+
import kotlinx.serialization.encoding.Decoder
12+
import kotlinx.serialization.encoding.Encoder
813

914
@Suppress("EqualsOrHashCode", "detekt.EqualsWithHashCodeExist")
1015
@Serializable
@@ -15,6 +20,27 @@ open class TestClass(val value: Int = 0) {
1520
}
1621
}
1722

23+
@Suppress("EqualsOrHashCode", "detekt.EqualsWithHashCodeExist")
24+
@Serializable(with = TestClassThatThrowsWhileDeserialization.Serializer::class)
25+
class TestClassThatThrowsWhileDeserialization(val value: Int = 0) {
26+
object Serializer : KSerializer<TestClassThatThrowsWhileDeserialization> {
27+
override val descriptor = Int.serializer().descriptor
28+
29+
override fun serialize(encoder: Encoder, value: TestClassThatThrowsWhileDeserialization) {
30+
encoder.encodeInt(value.value)
31+
}
32+
33+
override fun deserialize(decoder: Decoder): TestClassThatThrowsWhileDeserialization {
34+
throw SerializationException("Its TestClassThatThrowsWhileDeserialization")
35+
}
36+
}
37+
38+
override fun equals(other: Any?): Boolean {
39+
if (other !is TestClassThatThrowsWhileDeserialization) return false
40+
return value == other.value
41+
}
42+
}
43+
1844
@Serializable
1945
data class TestList<@Suppress("unused") T : TestClass>(val value: Int = 42)
2046

0 commit comments

Comments
 (0)