@@ -7,6 +7,7 @@ package kotlinx.rpc.codegen.extension
77import kotlinx.rpc.codegen.VersionSpecificApi
88import kotlinx.rpc.codegen.common.RpcClassId
99import org.jetbrains.kotlin.backend.common.lower.DeclarationIrBuilder
10+ import org.jetbrains.kotlin.backend.common.lower.irThrow
1011import org.jetbrains.kotlin.backend.jvm.functionByName
1112import org.jetbrains.kotlin.cli.common.messages.MessageCollector
1213import org.jetbrains.kotlin.descriptors.DescriptorVisibilities
@@ -1070,7 +1071,7 @@ internal class RpcStubGenerator(
10701071 * Where:
10711072 * - `<callable-name-k>` - the name of the k-th callable in the service
10721073 */
1073- private fun irMethodDescriptorMap (resolver : IrValueParameter ): IrCallImpl {
1074+ private fun IrBlockBodyBuilder. irMethodDescriptorMap (resolver : IrValueParameter ): IrCallImpl {
10741075 return irMapOf(
10751076 keyType = ctx.irBuiltIns.stringType,
10761077 valueType = ctx.grpcPlatformMethodDescriptor.starProjectedType,
@@ -1149,7 +1150,10 @@ internal class RpcStubGenerator(
11491150 * MethodType.CLIENT_STREAMING, MethodType.BIDI_STREAMING
11501151 * - <request-codec>/<response-codec> - a MessageCodec getter, see [irCodec]
11511152 */
1152- private fun irMethodDescriptor (callable : ServiceDeclaration .Callable , resolver : IrValueParameter ): IrCall {
1153+ private fun IrBlockBodyBuilder.irMethodDescriptor (
1154+ callable : ServiceDeclaration .Callable ,
1155+ resolver : IrValueParameter ,
1156+ ): IrCall {
11531157 check(callable is ServiceDeclaration .Method ) {
11541158 " Only methods are allowed here"
11551159 }
@@ -1167,13 +1171,10 @@ internal class RpcStubGenerator(
11671171 val methodDescriptorType = ctx.grpcPlatformMethodDescriptor.typeWith(requestType, responseType)
11681172
11691173 return vsApi {
1170- IrCallImplVS (
1171- startOffset = UNDEFINED_OFFSET ,
1172- endOffset = UNDEFINED_OFFSET ,
1174+ irCall(
11731175 type = methodDescriptorType,
1174- symbol = ctx.functions.methodDescriptor,
1176+ callee = ctx.functions.methodDescriptor,
11751177 typeArgumentsCount = 2 ,
1176- valueArgumentsCount = 8 ,
11771178 )
11781179 }.apply {
11791180 arguments {
@@ -1234,13 +1235,13 @@ internal class RpcStubGenerator(
12341235 }
12351236
12361237 /* *
1237- * If [type ] is annotated with [RpcIrContext.withCodecAnnotation],
1238+ * If [messageType ] is annotated with [RpcIrContext.withCodecAnnotation],
12381239 * we use its codec object
12391240 *
1240- * If not, use [resolver].resolve ()
1241+ * If not, use [resolver].resolveOrNull ()
12411242 */
1242- private fun irCodec (type : IrType , resolver : IrValueParameter ): IrExpression {
1243- val owner = type .classOrFail.owner
1243+ private fun IrBlockBodyBuilder. irCodec (messageType : IrType , resolver : IrValueParameter ): IrExpression {
1244+ val owner = messageType .classOrFail.owner
12441245 val protobufMessage = owner.getAnnotation(ctx.withCodecAnnotation.owner.kotlinFqName)
12451246
12461247 return if (protobufMessage != null ) {
@@ -1256,14 +1257,12 @@ internal class RpcStubGenerator(
12561257 symbol = codec.classOrFail,
12571258 )
12581259 } else {
1259- vsApi {
1260- IrCallImplVS (
1261- startOffset = UNDEFINED_OFFSET ,
1262- endOffset = UNDEFINED_OFFSET ,
1263- type = ctx.grpcMessageCodec.typeWith(type),
1264- symbol = ctx.functions.grpcMessageCodecResolverResolveOrNull.symbol,
1260+ val codecType = ctx.grpcMessageCodec.typeWith(messageType)
1261+ val codecCall = vsApi {
1262+ irCall(
1263+ type = codecType.makeNullable(),
1264+ callee = ctx.functions.grpcMessageCodecResolverResolveOrNull.symbol,
12651265 typeArgumentsCount = 0 ,
1266- valueArgumentsCount = 1 ,
12671266 )
12681267 }.apply {
12691268 arguments {
@@ -1275,10 +1274,21 @@ internal class RpcStubGenerator(
12751274 )
12761275
12771276 values {
1278- + irTypeOfCall(type )
1277+ + irTypeOfCall(messageType )
12791278 }
12801279 }
12811280 }
1281+
1282+ irElvis(
1283+ expression = codecCall,
1284+ ifNull = irCall(ctx.irBuiltIns.illegalArgumentExceptionSymbol).apply {
1285+ arguments {
1286+ values {
1287+ + stringConst(" No codec found for ${messageType.classFqName} " )
1288+ }
1289+ }
1290+ },
1291+ )
12821292 }
12831293 }
12841294
@@ -1795,4 +1805,44 @@ internal class RpcStubGenerator(
17951805
17961806 fun IrBuilderWithScope.irSafeAs (argument : IrExpression , type : IrType ) =
17971807 IrTypeOperatorCallImpl (startOffset, endOffset, type, IrTypeOperator .SAFE_CAST , type, argument)
1808+
1809+ fun IrBlockBodyBuilder.irElvis (expression : IrExpression , ifNull : IrExpression ): IrExpression {
1810+ check(expression.type == ifNull.type || ifNull.type == ctx.irBuiltIns.nothingType) {
1811+ " Type mismatch: ${expression.type.dumpKotlinLike()} != ${ifNull.type.dumpKotlinLike()} "
1812+ }
1813+ // BLOCK type=kotlin.Int origin=ELVIS
1814+ // VAR IR_TEMPORARY_VARIABLE name:tmp_0 type:kotlin.Int? [val]
1815+ // GET_VAR 'val some: kotlin.Int? declared in <root>.test' type=kotlin.Int? origin=null
1816+ // WHEN type=kotlin.Int origin=null
1817+ // BRANCH
1818+ // if: CALL 'public final fun EQEQ (arg0: kotlin.Any?, arg1: kotlin.Any?): kotlin.Boolean declared in kotlin.internal.ir' type=kotlin.Boolean origin=EQEQ
1819+ // ARG arg0: GET_VAR 'val tmp_0: kotlin.Int? declared in <root>.test' type=kotlin.Int? origin=null
1820+ // ARG arg1: CONST Null type=kotlin.Nothing? value=null
1821+ // then: THROW type=kotlin.Nothing
1822+ // CONSTRUCTOR_CALL 'public constructor <init> (p0: @[FlexibleNullability] kotlin.String?) declared in java.lang.IllegalStateException' type=java.lang.IllegalStateException origin=null
1823+ // ARG p0: CONST String type=kotlin.String value="some is null"
1824+ // BRANCH
1825+ // if: CONST Boolean type=kotlin.Boolean value=true
1826+ // then: GET_VAR 'val tmp_0: kotlin.Int? declared in <root>.test' type=kotlin.Int? origin=null
1827+ return irBlock(origin = IrStatementOrigin .ELVIS , resultType = expression.type.makeNotNull()) {
1828+ val temp = irTemporary(
1829+ value = expression,
1830+ nameHint = " elvis_left_hand_side" ,
1831+ isMutable = false ,
1832+ )
1833+ + irWhen(
1834+ type = expression.type,
1835+ branches = listOf (
1836+ irBranch(
1837+ condition = irEqualsNull(irGet(temp)),
1838+ result = ifNull,
1839+ ),
1840+ irBranch(
1841+ condition = irTrue(),
1842+ result = irGet(temp),
1843+ ),
1844+ ),
1845+ )
1846+ }
1847+ }
17981848}
0 commit comments