Skip to content

Commit 51862c2

Browse files
committed
grpc: Generate model from Descriptor instead of DescriptorProto
Signed-off-by: Johannes Zottele <[email protected]>
1 parent 56372ad commit 51862c2

File tree

9 files changed

+456
-29
lines changed

9 files changed

+456
-29
lines changed

grpc/grpc-core/src/commonTest/proto/all_primitives.proto

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@ message AllPrimitivesCommon {
99
int64 int64 = 4;
1010
uint32 uint32 = 5;
1111
uint64 uint64 = 6;
12-
sint32 sint32 = 7;
13-
sint64 sint64 = 8;
14-
fixed32 fixed32 = 9;
15-
fixed64 fixed64 = 10;
16-
sfixed32 sfixed32 = 11;
17-
sfixed64 sfixed64 = 12;
18-
bool bool = 13;
19-
string string = 14;
20-
bytes bytes = 15;
12+
optional sint32 sint32 = 7;
13+
optional sint64 sint64 = 8;
14+
optional fixed32 fixed32 = 9;
15+
optional fixed64 fixed64 = 10;
16+
optional sfixed32 sfixed32 = 11;
17+
optional sfixed64 sfixed64 = 12;
18+
optional bool bool = 13;
19+
optional string string = 14;
20+
optional bytes bytes = 15;
2121
}

grpc/grpc-core/src/commonTest/proto/repeated.proto

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,8 @@ message RepeatedCommon {
66
repeated fixed32 listFixed32 = 1 [packed = true];
77
repeated int32 listInt32 = 2 [packed = false];
88
repeated string listString = 3;
9+
10+
message InnerClass {
11+
12+
}
913
}

grpc/grpc-core/src/jvmTest/proto/repeated.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ message Repeated {
1010
repeated string listString = 3;
1111
repeated References listReference = 4;
1212
}
13+
Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
/*
2+
* Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package kotlinx.rpc.protobuf
6+
7+
import com.google.protobuf.DescriptorProtos
8+
import com.google.protobuf.Descriptors
9+
import com.google.protobuf.compiler.PluginProtos.CodeGeneratorRequest
10+
import kotlinx.rpc.protobuf.model.*
11+
12+
private val modelCache = mutableMapOf<Descriptors.GenericDescriptor, Any>()
13+
14+
fun CodeGeneratorRequest.toCommonModel(): Model {
15+
val protoFileMap = protoFileList.associateBy { it.name }
16+
val fileDescriptors = mutableMapOf<String, Descriptors.FileDescriptor>()
17+
18+
val files = protoFileList.map { protoFile -> protoFile.toDescriptor(protoFileMap, fileDescriptors) }
19+
20+
return Model(
21+
files = files.map { it.toCommonModel() }
22+
)
23+
}
24+
25+
private inline fun <D, reified T> D.cached(block: (D) -> T): T
26+
where D : Descriptors.GenericDescriptor, T : Any {
27+
if (modelCache.containsKey(this)) {
28+
return modelCache[this] as T
29+
}
30+
val declaration = block(this)
31+
modelCache[this] = declaration
32+
return declaration
33+
}
34+
35+
private fun Descriptors.FileDescriptor.toCommonModel(): FileDeclaration = cached {
36+
return FileDeclaration(
37+
name = kotlinFileName(),
38+
packageName = FqName.Package.fromString(`package`),
39+
dependencies = dependencies.map { it.toCommonModel() },
40+
messageDeclarations = messageTypes.map { it.toCommonModel() },
41+
enumDeclarations = enumTypes.map { it.toCommonModel() },
42+
serviceDeclarations = services.map { it.toCommonModel() },
43+
deprecated = options.deprecated,
44+
doc = null,
45+
)
46+
}
47+
48+
private fun Descriptors.Descriptor.toCommonModel(): MessageDeclaration = cached {
49+
val regularFields = fields.filter { field -> field.realContainingOneof == null }.map { it.toCommonModel() }
50+
51+
return MessageDeclaration(
52+
outerClassName = fqName(),
53+
name = fqName(),
54+
actualFields = regularFields,
55+
// get all oneof declarations that are not created from an optional in proto3 https://github.com/googleapis/api-linter/issues/1323
56+
oneOfDeclarations = oneofs.filter { it.fields[0].realContainingOneof != null }.map { it.toCommonModel() },
57+
enumDeclarations = enumTypes.map { it.toCommonModel() },
58+
nestedDeclarations = nestedTypes.map { it.toCommonModel() },
59+
deprecated = options.deprecated,
60+
doc = null,
61+
)
62+
}
63+
64+
private fun Descriptors.FieldDescriptor.toCommonModel(): FieldDeclaration = cached {
65+
toProto().hasProto3Optional()
66+
return FieldDeclaration(
67+
name = fqName().simpleName,
68+
number = number,
69+
type = modelType(),
70+
nullable = isNullable(),
71+
deprecated = options.deprecated,
72+
doc = null,
73+
proto = toProto(),
74+
)
75+
}
76+
77+
private fun Descriptors.FieldDescriptor.isNullable(): Boolean {
78+
// aligns with edition settings and backward compatibility with proto2 and proto3
79+
return hasPresence() && !isRequired && !hasDefaultValue()
80+
}
81+
82+
private fun Descriptors.OneofDescriptor.toCommonModel(): OneOfDeclaration = cached {
83+
return OneOfDeclaration(
84+
name = fqName(),
85+
variants = fields.map { it.toCommonModel() },
86+
descriptor = this
87+
)
88+
}
89+
90+
private fun Descriptors.EnumDescriptor.toCommonModel(): EnumDeclaration = cached {
91+
val entriesMap = mutableMapOf<Int, EnumDeclaration.Entry>()
92+
val aliases = mutableListOf<EnumDeclaration.Alias>()
93+
94+
values.forEach { value ->
95+
if (entriesMap.containsKey(value.number)) {
96+
val original = entriesMap.getValue(value.number)
97+
aliases.add(value.toAliasModel(original))
98+
} else {
99+
entriesMap[value.number] = value.toCommonModel()
100+
}
101+
}
102+
103+
if (!options.allowAlias && aliases.isNotEmpty()) {
104+
error("Enum ${fullName} has aliases: ${aliases.joinToString { it.name.simpleName }}")
105+
}
106+
107+
return EnumDeclaration(
108+
outerClassName = fqName(),
109+
name = fqName(),
110+
originalEntries = entriesMap.values.toList(),
111+
aliases = aliases,
112+
deprecated = options.deprecated,
113+
doc = null
114+
)
115+
}
116+
117+
private fun Descriptors.EnumValueDescriptor.toCommonModel(): EnumDeclaration.Entry = cached {
118+
return EnumDeclaration.Entry(
119+
name = fqName(),
120+
deprecated = options.deprecated,
121+
doc = null,
122+
)
123+
}
124+
125+
// no caching, as it would conflict with .toModel
126+
private fun Descriptors.EnumValueDescriptor.toAliasModel(original: EnumDeclaration.Entry): EnumDeclaration.Alias {
127+
return EnumDeclaration.Alias(
128+
name = fqName(),
129+
original = original,
130+
deprecated = options.deprecated,
131+
doc = null,
132+
)
133+
}
134+
135+
private fun Descriptors.ServiceDescriptor.toCommonModel(): ServiceDeclaration = cached {
136+
return ServiceDeclaration(
137+
name = fqName(),
138+
methods = methods.map { it.toCommonModel() }
139+
)
140+
}
141+
142+
private fun Descriptors.MethodDescriptor.toCommonModel(): MethodDeclaration = cached {
143+
return MethodDeclaration(
144+
name = name,
145+
clientStreaming = isClientStreaming,
146+
serverStreaming = isServerStreaming,
147+
inputType = lazy { inputType.toCommonModel() },
148+
outputType = lazy { outputType.toCommonModel() }
149+
)
150+
}
151+
152+
153+
private fun DescriptorProtos.FileDescriptorProto.toDescriptor(
154+
protoFileMap: Map<String, DescriptorProtos.FileDescriptorProto>,
155+
cache: MutableMap<String, Descriptors.FileDescriptor>
156+
): Descriptors.FileDescriptor {
157+
if (cache.containsKey(name)) return cache[name]!!
158+
159+
val dependencies = dependencyList.map { depName ->
160+
val depProto = protoFileMap[depName] ?: error("Missing dependency: $depName")
161+
depProto.toDescriptor(protoFileMap, cache)
162+
}.toTypedArray()
163+
164+
val fileDescriptor = Descriptors.FileDescriptor.buildFrom(this, dependencies)
165+
cache[name] = fileDescriptor
166+
return fileDescriptor
167+
}
168+
169+
//// Type Conversion Extension ////
170+
171+
private fun Descriptors.FieldDescriptor.modelType(): FieldType {
172+
val baseType = when (type) {
173+
Descriptors.FieldDescriptor.Type.DOUBLE -> FieldType.IntegralType.DOUBLE
174+
Descriptors.FieldDescriptor.Type.FLOAT -> FieldType.IntegralType.FLOAT
175+
Descriptors.FieldDescriptor.Type.INT64 -> FieldType.IntegralType.INT64
176+
Descriptors.FieldDescriptor.Type.UINT64 -> FieldType.IntegralType.UINT64
177+
Descriptors.FieldDescriptor.Type.INT32 -> FieldType.IntegralType.INT32
178+
Descriptors.FieldDescriptor.Type.FIXED64 -> FieldType.IntegralType.FIXED64
179+
Descriptors.FieldDescriptor.Type.FIXED32 -> FieldType.IntegralType.FIXED32
180+
Descriptors.FieldDescriptor.Type.BOOL -> FieldType.IntegralType.BOOL
181+
Descriptors.FieldDescriptor.Type.STRING -> FieldType.IntegralType.STRING
182+
Descriptors.FieldDescriptor.Type.BYTES -> FieldType.IntegralType.BYTES
183+
Descriptors.FieldDescriptor.Type.UINT32 -> FieldType.IntegralType.UINT32
184+
Descriptors.FieldDescriptor.Type.SFIXED32 -> FieldType.IntegralType.SFIXED32
185+
Descriptors.FieldDescriptor.Type.SFIXED64 -> FieldType.IntegralType.SFIXED64
186+
Descriptors.FieldDescriptor.Type.SINT32 -> FieldType.IntegralType.SINT32
187+
Descriptors.FieldDescriptor.Type.SINT64 -> FieldType.IntegralType.SINT64
188+
Descriptors.FieldDescriptor.Type.ENUM -> FieldType.Reference(lazy { enumType!!.toCommonModel().name })
189+
Descriptors.FieldDescriptor.Type.MESSAGE -> FieldType.Reference(lazy { messageType!!.toCommonModel().name })
190+
Descriptors.FieldDescriptor.Type.GROUP -> error("GROUP type is unsupported")
191+
}
192+
193+
if (isRepeated) {
194+
return FieldType.List(baseType)
195+
}
196+
197+
// TODO: Handle map type
198+
199+
return baseType
200+
}
201+
202+
//// Utility Extensions ////
203+
204+
private fun Descriptors.FileDescriptor.kotlinFileName(): String {
205+
return "${protoFileNameToKotlinName()}.kt"
206+
}
207+
208+
private fun Descriptors.FileDescriptor.protoFileNameToKotlinName(): String {
209+
return name.removeSuffix(".proto").fullProtoNameToKotlin(firstLetterUpper = true)
210+
}
211+
212+
213+
private fun String.fullProtoNameToKotlin(firstLetterUpper: Boolean = false): String {
214+
val lastDelimiterIndex = indexOfLast { it == '.' || it == '/' }
215+
return if (lastDelimiterIndex != -1) {
216+
val packageName = substring(0, lastDelimiterIndex)
217+
val name = substring(lastDelimiterIndex + 1)
218+
val delimiter = this[lastDelimiterIndex]
219+
return "$packageName$delimiter${name.simpleProtoNameToKotlin(firstLetterUpper = true)}"
220+
} else {
221+
simpleProtoNameToKotlin(firstLetterUpper)
222+
}
223+
}
224+
225+
226+
private val snakeRegExp = "(_[a-z]|-[a-z])".toRegex()
227+
228+
private fun String.snakeToCamelCase(): String {
229+
return replace(snakeRegExp) { it.value.last().uppercase() }
230+
}
231+
232+
private fun String.simpleProtoNameToKotlin(firstLetterUpper: Boolean = false): String {
233+
return snakeToCamelCase().run {
234+
if (firstLetterUpper) {
235+
replaceFirstChar { it.uppercase() }
236+
} else {
237+
this
238+
}
239+
}
240+
}
241+

protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -239,30 +239,34 @@ class ModelToKotlinCommonGenerator(
239239
val fieldName = field.name
240240
if (field.nullable) {
241241
scope("$fieldName?.also") {
242-
code(field.writeValue())
242+
code(field.writeValue("it"))
243243
}
244244
} else if (!field.hasPresence) {
245245
ifBranch(condition = field.defaultCheck(), ifBlock = {
246-
code(field.writeValue())
246+
code(field.writeValue(field.name))
247247
})
248248
} else {
249-
code(field.writeValue())
249+
code(field.writeValue(field.name))
250250
}
251251
}
252252
}
253253

254-
private fun FieldDeclaration.writeValue(): String {
254+
private fun FieldDeclaration.writeValue(variable: String): String {
255255
return when (val fieldType = type) {
256-
is FieldType.IntegralType -> "encoder.write${type.decodeEncodeFuncName()}($number, $name)"
256+
is FieldType.IntegralType -> "encoder.write${type.decodeEncodeFuncName()}($number, $variable)"
257257
is FieldType.List -> when {
258258
packed && packedFixedSize ->
259-
"encoder.writePacked${fieldType.value.decodeEncodeFuncName()}($number, $name)"
259+
"encoder.writePacked${fieldType.value.decodeEncodeFuncName()}($number, $variable)"
260260

261261
packed && !packedFixedSize ->
262-
"encoder.writePacked${fieldType.value.decodeEncodeFuncName()}($number, $name, ${wireSizeCall(name)})"
262+
"encoder.writePacked${fieldType.value.decodeEncodeFuncName()}($number, $variable, ${
263+
wireSizeCall(
264+
variable
265+
)
266+
})"
263267

264268
else ->
265-
"$name.forEach { encoder.write${fieldType.value.decodeEncodeFuncName()}($number, it) }"
269+
"$variable.forEach { encoder.write${fieldType.value.decodeEncodeFuncName()}($number, it) }"
266270
}
267271

268272
is FieldType.Map -> TODO()

0 commit comments

Comments
 (0)