Skip to content

Commit 8b1605a

Browse files
authored
Merge pull request github#9405 from smowton/smowton/fix/restore-wildcard-types
Kotlin: Introduce / restore implied wildcard types
2 parents d7b06aa + efc534a commit 8b1605a

File tree

8 files changed

+457
-138
lines changed

8 files changed

+457
-138
lines changed

java/kotlin-extractor/src/main/kotlin/KotlinFileExtractor.kt

Lines changed: 63 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@ import org.jetbrains.kotlin.ir.expressions.*
1919
import org.jetbrains.kotlin.ir.symbols.*
2020
import org.jetbrains.kotlin.ir.types.*
2121
import org.jetbrains.kotlin.ir.util.*
22+
import org.jetbrains.kotlin.load.java.sources.JavaSourceElement
23+
import org.jetbrains.kotlin.load.java.structure.JavaClass
24+
import org.jetbrains.kotlin.load.java.structure.JavaTypeParameter
2225
import org.jetbrains.kotlin.name.FqName
26+
import org.jetbrains.kotlin.types.Variance
2327
import org.jetbrains.kotlin.util.OperatorNameConventions
2428
import java.io.Closeable
2529
import java.util.*
@@ -171,7 +175,7 @@ open class KotlinFileExtractor(
171175
}
172176
}
173177

174-
fun extractTypeParameter(tp: IrTypeParameter, apparentIndex: Int): Label<out DbTypevariable>? {
178+
fun extractTypeParameter(tp: IrTypeParameter, apparentIndex: Int, javaTypeParameter: JavaTypeParameter?): Label<out DbTypevariable>? {
175179
with("type parameter", tp) {
176180
val parentId = getTypeParameterParentLabel(tp) ?: return null
177181
val id = tw.getLabelFor<DbTypevariable>(getTypeParameterLabel(tp))
@@ -183,10 +187,21 @@ open class KotlinFileExtractor(
183187
val locId = tw.getLocation(tp)
184188
tw.writeHasLocation(id, locId)
185189

190+
// Annoyingly, we have no obvious way to pair up the bounds of an IrTypeParameter and a JavaTypeParameter
191+
// because JavaTypeParameter provides a Collection not an ordered list, so we can only do our best here:
192+
fun tryGetJavaBound(idx: Int) =
193+
when(tp.superTypes.size) {
194+
1 -> javaTypeParameter?.upperBounds?.singleOrNull()
195+
else -> (javaTypeParameter?.upperBounds as? List)?.getOrNull(idx)
196+
}
197+
186198
tp.superTypes.forEachIndexed { boundIdx, bound ->
187199
if(!(bound.isAny() || bound.isNullableAny())) {
188200
tw.getLabelFor<DbTypebound>("@\"bound;$boundIdx;{$id}\"") {
189-
tw.writeTypeBounds(it, useType(bound).javaResult.id.cast<DbReftype>(), boundIdx, id)
201+
// Note we don't look for @JvmSuppressWildcards here because it doesn't seem to have any impact
202+
// on kotlinc adding wildcards to type parameter bounds.
203+
val boundWithWildcards = addJavaLoweringWildcards(bound, true, tryGetJavaBound(tp.index))
204+
tw.writeTypeBounds(it, useType(boundWithWildcards).javaResult.id.cast<DbReftype>(), boundIdx, id)
190205
}
191206
}
192207
}
@@ -382,7 +397,9 @@ open class KotlinFileExtractor(
382397

383398
extractEnclosingClass(c, id, locId, listOf())
384399

385-
c.typeParameters.mapIndexed { idx, param -> extractTypeParameter(param, idx) }
400+
val javaClass = (c.source as? JavaSourceElement)?.javaElement as? JavaClass
401+
402+
c.typeParameters.mapIndexed { idx, param -> extractTypeParameter(param, idx, javaClass?.typeParameters?.getOrNull(idx)) }
386403
if (extractDeclarations) {
387404
c.declarations.map { extractDeclaration(it, extractPrivateMembers = extractPrivateMembers, extractFunctionBodies = extractFunctionBodies) }
388405
if (extractStaticInitializer)
@@ -497,7 +514,9 @@ open class KotlinFileExtractor(
497514
else
498515
null
499516
} ?: vp.type
500-
val substitutedType = typeSubstitution?.let { it(maybeErasedType, TypeContext.OTHER, pluginContext) } ?: maybeErasedType
517+
val javaType = ((vp.parent as? IrFunction)?.let { getJavaMethod(it) })?.valueParameters?.getOrNull(idx)?.type
518+
val typeWithWildcards = addJavaLoweringWildcards(maybeErasedType, !hasWildcardSuppressionAnnotation(vp), javaType)
519+
val substitutedType = typeSubstitution?.let { it(typeWithWildcards, TypeContext.OTHER, pluginContext) } ?: typeWithWildcards
501520
val id = useValueParameter(vp, parent)
502521
if (extractTypeAccess) {
503522
extractTypeAccessRecursive(substitutedType, location, id, -1)
@@ -531,7 +550,9 @@ open class KotlinFileExtractor(
531550
extensionReceiverParameter = null,
532551
functionTypeParameters = listOf(),
533552
classTypeArgsIncludingOuterClasses = listOf(),
534-
overridesCollectionsMethod = false
553+
overridesCollectionsMethod = false,
554+
javaSignature = null,
555+
addParameterWildcardsByDefault = false
535556
)
536557
val clinitId = tw.getLabelFor<DbMethod>(clinitLabel)
537558
val returnType = useType(pluginContext.irBuiltIns.unitType, TypeContext.RETURN)
@@ -670,7 +691,8 @@ open class KotlinFileExtractor(
670691
with("function", f) {
671692
DeclarationStackAdjuster(f).use {
672693

673-
getFunctionTypeParameters(f).mapIndexed { idx, tp -> extractTypeParameter(tp, idx) }
694+
val javaMethod = getJavaMethod(f)
695+
getFunctionTypeParameters(f).mapIndexed { idx, tp -> extractTypeParameter(tp, idx, javaMethod?.typeParameters?.getOrNull(idx)) }
674696

675697
val id =
676698
idOverride
@@ -704,7 +726,7 @@ open class KotlinFileExtractor(
704726

705727
val paramsSignature = allParamTypes.joinToString(separator = ",", prefix = "(", postfix = ")") { it.javaResult.signature!! }
706728

707-
val adjustedReturnType = getAdjustedReturnType(f)
729+
val adjustedReturnType = addJavaLoweringWildcards(getAdjustedReturnType(f), false, javaMethod?.returnType)
708730
val substReturnType = typeSubstitution?.let { it(adjustedReturnType, TypeContext.RETURN, pluginContext) } ?: adjustedReturnType
709731

710732
val locId = locOverride ?: getLocation(f, classTypeArgsIncludingOuterClasses)
@@ -3744,6 +3766,17 @@ open class KotlinFileExtractor(
37443766
}
37453767
}
37463768

3769+
/**
3770+
* Extracts a single wildcard type access expression with no enclosing callable and statement.
3771+
*/
3772+
private fun extractWildcardTypeAccess(type: TypeResults, location: Label<DbLocation>, parent: Label<out DbExprparent>, idx: Int): Label<out DbExpr> {
3773+
val id = tw.getFreshIdLabel<DbWildcardtypeaccess>()
3774+
tw.writeExprs_wildcardtypeaccess(id, type.javaResult.id, parent, idx)
3775+
tw.writeExprsKotlinType(id, type.kotlinResult.id)
3776+
tw.writeHasLocation(id, location)
3777+
return id
3778+
}
3779+
37473780
/**
37483781
* Extracts a single type access expression with no enclosing callable and statement.
37493782
*/
@@ -3768,15 +3801,36 @@ open class KotlinFileExtractor(
37683801
return id
37693802
}
37703803

3804+
/**
3805+
* Extracts a type argument type access, introducing a wildcard type access if appropriate, or directly calling
3806+
* `extractTypeAccessRecursive` if the argument is invariant.
3807+
* No enclosing callable and statement is extracted, this is useful for type access extraction in field declarations.
3808+
*/
3809+
private fun extractWildcardTypeAccessRecursive(t: IrTypeArgument, location: Label<DbLocation>, parent: Label<out DbExprparent>, idx: Int) {
3810+
val typeLabels by lazy { TypeResults(getTypeArgumentLabel(t), TypeResult(fakeKotlinType(), "TODO", "TODO")) }
3811+
when (t) {
3812+
is IrStarProjection -> extractWildcardTypeAccess(typeLabels, location, parent, idx)
3813+
is IrTypeProjection -> when(t.variance) {
3814+
Variance.INVARIANT -> extractTypeAccessRecursive(t.type, location, parent, idx, TypeContext.GENERIC_ARGUMENT)
3815+
else -> {
3816+
val wildcardLabel = extractWildcardTypeAccess(typeLabels, location, parent, idx)
3817+
// Mimic a Java extractor oddity, that it uses the child index to indicate what kind of wildcard this is
3818+
val boundChildIdx = if (t.variance == Variance.OUT_VARIANCE) 0 else 1
3819+
extractTypeAccessRecursive(t.type, location, wildcardLabel, boundChildIdx, TypeContext.GENERIC_ARGUMENT)
3820+
}
3821+
}
3822+
}
3823+
}
3824+
37713825
/**
37723826
* Extracts a type access expression and its child type access expressions in case of a generic type. Nested generics are also handled.
37733827
* No enclosing callable and statement is extracted, this is useful for type access extraction in field declarations.
37743828
*/
37753829
private fun extractTypeAccessRecursive(t: IrType, location: Label<DbLocation>, parent: Label<out DbExprparent>, idx: Int, typeContext: TypeContext = TypeContext.OTHER): Label<out DbExpr> {
37763830
val typeAccessId = extractTypeAccess(useType(t, typeContext), location, parent, idx)
37773831
if (t is IrSimpleType) {
3778-
t.arguments.filterIsInstance<IrType>().forEachIndexed { argIdx, arg ->
3779-
extractTypeAccessRecursive(arg, location, typeAccessId, argIdx, TypeContext.GENERIC_ARGUMENT)
3832+
t.arguments.forEachIndexed { argIdx, arg ->
3833+
extractWildcardTypeAccessRecursive(arg, location, typeAccessId, argIdx)
37803834
}
37813835
}
37823836
return typeAccessId

java/kotlin-extractor/src/main/kotlin/KotlinUsesExtractor.kt

Lines changed: 84 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import com.github.codeql.utils.versions.isRawType
66
import com.semmle.extractor.java.OdasaOutput
77
import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
88
import org.jetbrains.kotlin.backend.common.ir.allOverridden
9+
import org.jetbrains.kotlin.backend.common.ir.isFinalClass
10+
import org.jetbrains.kotlin.backend.common.lower.parents
911
import org.jetbrains.kotlin.backend.common.lower.parentsWithSelf
1012
import org.jetbrains.kotlin.backend.jvm.ir.getJvmNameFromAnnotation
1113
import org.jetbrains.kotlin.backend.jvm.ir.propertyIfAccessor
@@ -20,6 +22,8 @@ import org.jetbrains.kotlin.ir.types.impl.*
2022
import org.jetbrains.kotlin.ir.util.*
2123
import org.jetbrains.kotlin.load.java.BuiltinMethodsWithSpecialGenericSignature
2224
import org.jetbrains.kotlin.load.java.JvmAbi
25+
import org.jetbrains.kotlin.load.java.sources.JavaSourceElement
26+
import org.jetbrains.kotlin.load.java.structure.*
2327
import org.jetbrains.kotlin.name.FqName
2428
import org.jetbrains.kotlin.name.Name
2529
import org.jetbrains.kotlin.name.SpecialNames
@@ -850,6 +854,64 @@ open class KotlinUsesExtractor(
850854
(f.name.asString() == "addAll" && overridesFunctionDefinedOn(f, "kotlin.collections", "MutableCollection")) ||
851855
(f.name.asString() == "addAll" && overridesFunctionDefinedOn(f, "kotlin.collections", "MutableList"))
852856

857+
858+
private val jvmWildcardAnnotation = FqName("kotlin.jvm.JvmWildcard")
859+
private val jvmWildcardSuppressionAnnotaton = FqName("kotlin.jvm.JvmSuppressWildcards")
860+
861+
private fun wildcardAdditionAllowed(v: Variance, t: IrType, addByDefault: Boolean) =
862+
when {
863+
t.hasAnnotation(jvmWildcardAnnotation) -> true
864+
!addByDefault -> false
865+
t.hasAnnotation(jvmWildcardSuppressionAnnotaton) -> false
866+
v == Variance.IN_VARIANCE -> !(t.isNullableAny() || t.isAny())
867+
v == Variance.OUT_VARIANCE -> ((t as? IrSimpleType)?.classOrNull?.owner?.isFinalClass) != true
868+
else -> false
869+
}
870+
871+
private fun addJavaLoweringArgumentWildcards(p: IrTypeParameter, t: IrTypeArgument, addByDefault: Boolean, javaType: JavaType?): IrTypeArgument =
872+
(t as? IrTypeProjection)?.let {
873+
val newBase = addJavaLoweringWildcards(it.type, addByDefault, javaType)
874+
val newVariance =
875+
if (it.variance == Variance.INVARIANT &&
876+
p.variance != Variance.INVARIANT &&
877+
// The next line forbids inferring a wildcard type when we have a corresponding Java type with conflicting variance.
878+
// For example, Java might declare f(Comparable<CharSequence> cs), in which case we shouldn't add a `? super ...`
879+
// wildcard. Note if javaType is unknown (e.g. this is a Kotlin source element), we assume wildcards should be added.
880+
(javaType?.let { jt -> jt is JavaWildcardType && jt.isExtends == (p.variance == Variance.OUT_VARIANCE) } != false) &&
881+
wildcardAdditionAllowed(p.variance, it.type, addByDefault))
882+
p.variance
883+
else
884+
it.variance
885+
if (newBase !== it.type || newVariance != it.variance)
886+
makeTypeProjection(newBase, newVariance)
887+
else
888+
null
889+
} ?: t
890+
891+
fun getJavaTypeArgument(jt: JavaType, idx: Int) =
892+
when(jt) {
893+
is JavaClassifierType -> jt.typeArguments.getOrNull(idx)
894+
is JavaArrayType -> if (idx == 0) jt.componentType else null
895+
else -> null
896+
}
897+
898+
fun addJavaLoweringWildcards(t: IrType, addByDefault: Boolean, javaType: JavaType?): IrType =
899+
(t as? IrSimpleType)?.let {
900+
val typeParams = it.classOrNull?.owner?.typeParameters ?: return t
901+
val newArgs = typeParams.zip(it.arguments).mapIndexed { idx, pair ->
902+
addJavaLoweringArgumentWildcards(
903+
pair.first,
904+
pair.second,
905+
addByDefault,
906+
javaType?.let { jt -> getJavaTypeArgument(jt, idx) }
907+
)
908+
}
909+
return if (newArgs.zip(it.arguments).all { pair -> pair.first === pair.second })
910+
t
911+
else
912+
it.toBuilder().also { builder -> builder.arguments = newArgs }.buildSimpleType()
913+
} ?: t
914+
853915
/*
854916
* This is the normal getFunctionLabel function to use. If you want
855917
* to refer to the function in its source class then
@@ -883,6 +945,14 @@ open class KotlinUsesExtractor(
883945
return otherKeySet.returnType.codeQlWithHasQuestionMark(false)
884946
}
885947

948+
@OptIn(ObsoleteDescriptorBasedAPI::class)
949+
fun getJavaMethod(f: IrFunction) = (f.descriptor.source as? JavaSourceElement)?.javaElement as? JavaMethod
950+
951+
fun hasWildcardSuppressionAnnotation(d: IrDeclaration) =
952+
d.hasAnnotation(jvmWildcardSuppressionAnnotaton) ||
953+
// Note not using `parentsWithSelf` as that only works if `d` is an IrDeclarationParent
954+
d.parents.any { (it as? IrAnnotationContainer)?.hasAnnotation(jvmWildcardSuppressionAnnotaton) == true }
955+
886956
/*
887957
* There are some pairs of classes (e.g. `kotlin.Throwable` and
888958
* `java.lang.Throwable`) which are really just 2 different names
@@ -903,7 +973,9 @@ open class KotlinUsesExtractor(
903973
f.extensionReceiverParameter,
904974
getFunctionTypeParameters(f),
905975
classTypeArgsIncludingOuterClasses,
906-
overridesCollectionsMethodWithAlteredParameterTypes(f)
976+
overridesCollectionsMethodWithAlteredParameterTypes(f),
977+
getJavaMethod(f),
978+
!hasWildcardSuppressionAnnotation(f)
907979
)
908980

909981
/*
@@ -933,6 +1005,11 @@ open class KotlinUsesExtractor(
9331005
// If true, this method implements a Java Collections interface (Collection, Map or List) and may need
9341006
// parameter erasure to match the way this class will appear to an external consumer of the .class file.
9351007
overridesCollectionsMethod: Boolean,
1008+
// The Java signature of this callable, if known.
1009+
javaSignature: JavaMethod?,
1010+
// If true, Java wildcards implied by Kotlin type parameter variance should be added by default to this function's value parameters' types.
1011+
// (Return-type wildcard addition is always off by default)
1012+
addParameterWildcardsByDefault: Boolean,
9361013
// The prefix used in the label. "callable", unless a property label is created, then it's "property".
9371014
prefix: String = "callable"
9381015
): String {
@@ -956,8 +1033,10 @@ open class KotlinUsesExtractor(
9561033
// Collection.remove(Object) because Collection.remove(Collection::E) in the Kotlin universe.
9571034
// If this has happened, erase the type again to get the correct Java signature.
9581035
val maybeAmendedForCollections = if (overridesCollectionsMethod) eraseCollectionsMethodParameterType(it.value.type, name, it.index) else it.value.type
1036+
// Add any wildcard types that the Kotlin compiler would add in the Java lowering of this function:
1037+
val withAddedWildcards = addJavaLoweringWildcards(maybeAmendedForCollections, addParameterWildcardsByDefault, javaSignature?.let { sig -> sig.valueParameters[it.index].type })
9591038
// Now substitute any class type parameters in:
960-
val maybeSubbed = maybeAmendedForCollections.substituteTypeAndArguments(substitutionMap, TypeContext.OTHER, pluginContext)
1039+
val maybeSubbed = withAddedWildcards.substituteTypeAndArguments(substitutionMap, TypeContext.OTHER, pluginContext)
9611040
// Finally, mimic the Java extractor's behaviour by naming functions with type parameters for their erased types;
9621041
// those without type parameters are named for the generic type.
9631042
val maybeErased = if (functionTypeParameters.isEmpty()) maybeSubbed else erase(maybeSubbed)
@@ -969,6 +1048,8 @@ open class KotlinUsesExtractor(
9691048
pluginContext.irBuiltIns.unitType
9701049
else
9711050
erase(returnType.substituteTypeAndArguments(substitutionMap, TypeContext.RETURN, pluginContext))
1051+
// Note that `addJavaLoweringWildcards` is not required here because the return type used to form the function
1052+
// label is always erased.
9721053
val returnTypeId = useType(labelReturnType, TypeContext.RETURN).javaResult.id
9731054
// This suffix is added to generic methods (and constructors) to match the Java extractor's behaviour.
9741055
// Comments in that extractor indicates it didn't want the label of the callable to clash with the raw
@@ -1425,7 +1506,7 @@ open class KotlinUsesExtractor(
14251506
val returnType = getter?.returnType ?: setter?.valueParameters?.singleOrNull()?.type ?: pluginContext.irBuiltIns.unitType
14261507
val typeParams = getFunctionTypeParameters(func)
14271508

1428-
getFunctionLabel(p.parent, parentId, p.name.asString(), listOf(), returnType, ext, typeParams, classTypeArgsIncludingOuterClasses, overridesCollectionsMethod = false, prefix = "property")
1509+
getFunctionLabel(p.parent, parentId, p.name.asString(), listOf(), returnType, ext, typeParams, classTypeArgsIncludingOuterClasses, overridesCollectionsMethod = false, javaSignature = null, addParameterWildcardsByDefault = false, prefix = "property")
14291510
}
14301511
}
14311512

0 commit comments

Comments
 (0)