Skip to content

Commit 39403d0

Browse files
committed
grpc-pb: Check required fields ins submessages
Signed-off-by: Johannes Zottele <[email protected]>
1 parent d5b466f commit 39403d0

File tree

5 files changed

+74
-8
lines changed

5 files changed

+74
-8
lines changed

grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/InternalMessage.kt

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,29 +17,30 @@ public abstract class InternalMessage(fieldsWithPresence: Int) {
1717
public abstract val _size: Int
1818
}
1919

20+
@InternalRpcApi
2021
public class MsgFieldDelegate<T : Any>(
2122
private val presenceIdx: Int? = null,
2223
private val defaultProvider: (() -> T)? = null
2324
) : ReadWriteProperty<InternalMessage, T> {
2425

2526
private var valueSet = false
26-
private var _value: T? = null
27+
private var value: T? = null
2728

2829
override operator fun getValue(thisRef: InternalMessage, property: KProperty<*>): T {
2930
if (!valueSet) {
3031
if (defaultProvider != null) {
31-
_value = defaultProvider.invoke()
32+
value = defaultProvider.invoke()
3233
valueSet = true
3334
} else {
3435
error("Property ${property.name} not initialized")
3536
}
3637
}
37-
return _value as T
38+
return value as T
3839
}
3940

4041
override operator fun setValue(thisRef: InternalMessage, property: KProperty<*>, value: T) {
4142
presenceIdx?.let { thisRef.presenceMask[it] = true }
42-
_value = value
43+
this@MsgFieldDelegate.value = value
4344
valueSet = true
4445
}
4546
}

grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package kotlinx.rpc.grpc.pb
66

77
import OneOfMsg
88
import OneOfMsgInternal
9+
import OneOfWithRequired
910
import Outer
1011
import OuterInternal
1112
import asInternal
@@ -87,10 +88,20 @@ class ProtosTest {
8788
}
8889

8990
@Test
90-
fun testPresenceCheckProto() {
91+
fun testRepeatedWithRequiredSubField() {
92+
assertFailsWith<IllegalStateException> {
93+
RepeatedWithRequired {
94+
// we construct the message using the internal class,
95+
// so it is not invoking the checkRequired method on construction
96+
msgList = listOf(PresenceCheck { RequiredPresence = 2 }, PresenceCheckInternal())
97+
}
98+
}
99+
}
91100

101+
@Test
102+
fun testPresenceCheckProto() {
92103
// Check a missing required field in a user-constructed message
93-
assertFailsWith<IllegalStateException>("PresenceCheck is missing required field: RequiredPresence") {
104+
assertFailsWith<IllegalStateException> {
94105
PresenceCheck {}
95106
}
96107

@@ -100,7 +111,7 @@ class ProtosTest {
100111
encoder.writeFloat(2, 1f)
101112
encoder.flush()
102113

103-
assertFailsWith<IllegalStateException>("PresenceCheck is missing required field: RequiredPresence") {
114+
assertFailsWith<IllegalStateException> {
104115
PresenceCheckInternal.CODEC.decode(buffer)
105116
}
106117
}
@@ -217,6 +228,17 @@ class ProtosTest {
217228
assertEquals(OneOfMsg.Field.Fixed(123u), decoded.field)
218229
}
219230

231+
@Test
232+
fun testOneOfRequiredSubField() {
233+
assertFailsWith<IllegalStateException> {
234+
OneOfWithRequired {
235+
// we construct the message using the internal class,
236+
// so it is not invoking the checkRequired method on construction
237+
field = OneOfWithRequired.Field.Msg(PresenceCheckInternal())
238+
}
239+
}
240+
}
241+
220242
@Test
221243
fun testOneOfNull() {
222244
// write two values on the oneOf field.
@@ -239,7 +261,7 @@ class ProtosTest {
239261

240262
@Test
241263
fun testRecursiveReqNotSet() {
242-
assertFailsWith<IllegalStateException>("RecursiveReq is missing required field: rec") {
264+
assertFailsWith<IllegalStateException> {
243265
val msg = RecursiveReq {
244266
rec = RecursiveReq {
245267
rec = RecursiveReq {

grpc/grpc-core/src/commonTest/proto/oneof.proto

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import "sub_message.proto";
22
import "enum.proto";
3+
import "presence_check.proto";
34

45
message OneOfMsg {
56
oneof field {
@@ -8,4 +9,11 @@ message OneOfMsg {
89
test.submsg.Other other = 4;
910
kotlinx.rpc.grpc.test.MyEnum enum = 5;
1011
}
12+
}
13+
14+
message OneOfWithRequired {
15+
oneof field {
16+
int32 sint = 1;
17+
kotlinx.rpc.grpc.test.common.PresenceCheck msg = 2;
18+
}
1119
}

grpc/grpc-core/src/commonTest/proto/repeated.proto

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
syntax = "proto3";
22

3+
import "presence_check.proto";
4+
35
package kotlinx.rpc.grpc.test.common;
46

57
message Repeated {
@@ -13,4 +15,8 @@ message Repeated {
1315
message Other {
1416
int32 a = 1;
1517
}
18+
}
19+
20+
message RepeatedWithRequired {
21+
repeated PresenceCheck msgList = 1;
1622
}

protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,35 @@ class ModelToKotlinCommonGenerator(
530530
code("${field.name}.asInternal().checkRequiredFields()")
531531
})
532532
}
533+
534+
// check submessages in oneofs
535+
declaration.fields().filter { it.second.type is FieldType.OneOf }.forEach { (_, field) ->
536+
val oneOfType = field.type as FieldType.OneOf
537+
val messageVariants = oneOfType.dec.variants.filter { it.type is FieldType.Message }
538+
if (messageVariants.isEmpty()) return@forEach
539+
540+
scope("${field.name}?.also") {
541+
whenBlock {
542+
messageVariants.forEach { variant ->
543+
val variantClassName = "${field.type.dec.name.safeFullName()}.${variant.name}"
544+
whenCase("it is $variantClassName") {
545+
code("it.value.asInternal().checkRequiredFields()")
546+
}
547+
}
548+
}
549+
}
550+
}
551+
552+
// check submessages in lists
553+
declaration.fields().filter { it.second.type is FieldType.List }.forEach { (_, field) ->
554+
val listType = field.type as FieldType.List
555+
if (listType.value !is FieldType.Message) return@forEach
556+
557+
scope("${field.name}.forEach") {
558+
code("it.asInternal().checkRequiredFields()")
559+
}
560+
}
561+
533562
}
534563

535564
private fun CodeGenerator.generateInternalComputeSize(declaration: MessageDeclaration) {

0 commit comments

Comments
 (0)