Skip to content

Commit 1d071ef

Browse files
zhelenskiySpace Team
authored andcommitted
[IR] Move JvmUpgradeCallableReferences before JvmTailrecLowering
#KT-74383
1 parent 3d3d9cd commit 1d071ef

File tree

4 files changed

+32
-3
lines changed

4 files changed

+32
-3
lines changed

compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/TailRecursionCallsCollector.kt

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,11 @@ data class TailCalls(val ir: Set<IrCall>, val fromManyFunctions: Boolean)
3939
* It is also not guaranteed that each returned call is detected as tail recursion by the frontend.
4040
* However any returned call can be correctly optimized as tail recursion.
4141
*/
42-
fun collectTailRecursionCalls(irFunction: IrFunction, followFunctionReference: (IrFunctionReference) -> Boolean): TailCalls {
42+
fun collectTailRecursionCalls(
43+
irFunction: IrFunction,
44+
followFunctionReference: (IrFunctionReference) -> Boolean,
45+
followRichFunctionReference: (IrRichFunctionReference) -> Boolean
46+
): TailCalls {
4347
if ((irFunction as? IrSimpleFunction)?.isTailrec != true) {
4448
return TailCalls(emptySet(), false)
4549
}
@@ -155,6 +159,15 @@ fun collectTailRecursionCalls(irFunction: IrFunction, followFunctionReference: (
155159
expression.symbol.owner.body?.accept(this, VisitorState(isTailExpression = false, inOtherFunction = true))
156160
}
157161
}
162+
163+
override fun visitRichFunctionReference(expression: IrRichFunctionReference, data: VisitorState) {
164+
expression.acceptChildren(this, VisitorState(isTailExpression = false, data.inOtherFunction))
165+
if (followRichFunctionReference(expression)) {
166+
// If control reaches end of lambda, it will *not* end the current function by default,
167+
// so the lambda's body itself is not a tail statement.
168+
expression.invokeFunction.body?.accept(this, VisitorState(isTailExpression = false, inOtherFunction = true))
169+
}
170+
}
158171
}
159172

160173
irFunction.body?.accept(visitor, VisitorState(isTailExpression = true, inOtherFunction = false))

compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/lower/TailrecLowering.kt

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import org.jetbrains.kotlin.ir.util.deepCopyWithSymbols
3232
import org.jetbrains.kotlin.ir.util.defaultValueForType
3333
import org.jetbrains.kotlin.ir.util.getArgumentsWithIr
3434
import org.jetbrains.kotlin.ir.util.patchDeclarationParents
35+
import org.jetbrains.kotlin.ir.util.transformInPlace
3536
import org.jetbrains.kotlin.ir.visitors.IrVisitorVoid
3637
import org.jetbrains.kotlin.ir.visitors.acceptChildrenVoid
3738
import org.jetbrains.kotlin.ir.visitors.transformChildrenVoid
@@ -69,12 +70,14 @@ open class TailrecLowering(val context: LoweringContext) : BodyLoweringPass {
6970

7071
open fun followFunctionReference(reference: IrFunctionReference): Boolean = false
7172

73+
open fun followRichFunctionReference(reference: IrRichFunctionReference): Boolean = false
74+
7275
open fun nullConst(startOffset: Int, endOffset: Int, type: IrType): IrExpression =
7376
IrConstImpl.defaultValueForType(startOffset, endOffset, type)
7477
}
7578

7679
private fun TailrecLowering.lowerTailRecursionCalls(irFunction: IrFunction) {
77-
val (tailRecursionCalls, someCallsAreFromOtherFunctions) = collectTailRecursionCalls(irFunction, ::followFunctionReference)
80+
val (tailRecursionCalls, someCallsAreFromOtherFunctions) = collectTailRecursionCalls(irFunction, ::followFunctionReference, ::followRichFunctionReference)
7881
if (tailRecursionCalls.isEmpty()) {
7982
return
8083
}
@@ -152,6 +155,15 @@ private class BodyTransformer(
152155
return super.visitFunctionReference(expression)
153156
}
154157

158+
override fun visitRichFunctionReference(expression: IrRichFunctionReference): IrExpression {
159+
return if (lowering.followRichFunctionReference(expression)) {
160+
super.visitRichFunctionReference(expression)
161+
} else {
162+
expression.boundValues.transformInPlace(this, null)
163+
expression
164+
}
165+
}
166+
155167
private fun IrBuilderWithScope.genTailCall(expression: IrCall) = this.irBlock(expression) {
156168
// Get all specified arguments:
157169
val parameterToArgument = expression.getArgumentsWithIr().associateTo(mutableMapOf()) { (parameter, argument) ->

compiler/ir/backend.jvm/lower/src/org/jetbrains/kotlin/backend/jvm/JvmLoweringPhases.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,14 @@ private val jvmFilePhases = createFilePhases(
6767
::JvmSingleAbstractMethodLowering,
6868
::JvmMultiFieldValueClassLowering,
6969
::JvmInlineClassLowering,
70+
::JvmUpgradeCallableReferences,
7071
::JvmTailrecLowering,
7172

7273
::MappedEnumWhenLowering,
7374

7475
::AssertionLowering,
7576
::JvmReturnableBlockLowering,
7677
::SingletonReferencesLowering,
77-
::JvmUpgradeCallableReferences,
7878
::JvmSharedVariablesLowering,
7979

8080
::JvmInventNamesForLocalFunctions,

compiler/ir/backend.jvm/lower/src/org/jetbrains/kotlin/backend/jvm/lower/JvmTailrecLowering.kt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import org.jetbrains.kotlin.backend.jvm.ir.defaultValue
1111
import org.jetbrains.kotlin.config.LanguageFeature
1212
import org.jetbrains.kotlin.ir.expressions.IrExpression
1313
import org.jetbrains.kotlin.ir.expressions.IrFunctionReference
14+
import org.jetbrains.kotlin.ir.expressions.IrRichFunctionReference
1415
import org.jetbrains.kotlin.ir.expressions.IrStatementOrigin
1516
import org.jetbrains.kotlin.ir.types.IrType
1617

@@ -21,6 +22,9 @@ internal class JvmTailrecLowering(context: JvmBackendContext) : TailrecLowering(
2122
override fun followFunctionReference(reference: IrFunctionReference): Boolean =
2223
reference.origin == IrStatementOrigin.INLINE_LAMBDA
2324

25+
override fun followRichFunctionReference(reference: IrRichFunctionReference): Boolean =
26+
reference.origin == IrStatementOrigin.INLINE_LAMBDA
27+
2428
override fun nullConst(startOffset: Int, endOffset: Int, type: IrType): IrExpression =
2529
type.defaultValue(startOffset, endOffset, context as JvmBackendContext)
2630
}

0 commit comments

Comments
 (0)