Skip to content

Commit db78074

Browse files
authored
Better handling of generics in prism creation (#3881)
1 parent d6898f5 commit db78074

File tree

6 files changed

+68
-33
lines changed

6 files changed

+68
-33
lines changed

arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/domain.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,8 @@ data class Focus(
128128
val refinedType: KSType?,
129129
val onlyOneSealedSubclass: Boolean = false,
130130
val subclasses: List<String> = emptyList(),
131-
val classNameWithParameters: String? = className,
131+
val targetClassNameWithParameters: String? = className,
132+
val targetTypeParameters: List<String>? = emptyList(),
132133
) {
133134
val escapedParamName = paramName.plusIfNotBlank(
134135
prefix = "`",

arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/dsl.kt

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ fun OpticsProcessorOptions.generateIsoDsl(ele: ADT, isoOptic: ValueClassDsl): Sn
4747
private fun OpticsProcessorOptions.processLensSyntax(ele: ADT, foci: List<Focus>, className: String, lensType: String, optionalType: String, traversalType: String): String = if (ele.typeParameters.isEmpty()) {
4848
foci.joinToString(separator = "\n") { focus ->
4949
"""
50-
|${ele.visibilityModifierName} $inlineText val <__S> $lensType<__S, ${ele.sourceClassName}>.${focus.escapedParamName}: $lensType<__S, ${focus.classNameWithParameters}> $inlineText get() = this + $className.${focus.escapedParamName}
51-
|${ele.visibilityModifierName} $inlineText val <__S> $optionalType<__S, ${ele.sourceClassName}>.${focus.escapedParamName}: $optionalType<__S, ${focus.classNameWithParameters}> $inlineText get() = this + $className.${focus.escapedParamName}
52-
|${ele.visibilityModifierName} $inlineText val <__S> $traversalType<__S, ${ele.sourceClassName}>.${focus.escapedParamName}: $traversalType<__S, ${focus.classNameWithParameters}> $inlineText get() = this + $className.${focus.escapedParamName}
50+
|${ele.visibilityModifierName} $inlineText val <__S> $lensType<__S, ${ele.sourceClassName}>.${focus.escapedParamName}: $lensType<__S, ${focus.targetClassNameWithParameters}> $inlineText get() = this + $className.${focus.escapedParamName}
51+
|${ele.visibilityModifierName} $inlineText val <__S> $optionalType<__S, ${ele.sourceClassName}>.${focus.escapedParamName}: $optionalType<__S, ${focus.targetClassNameWithParameters}> $inlineText get() = this + $className.${focus.escapedParamName}
52+
|${ele.visibilityModifierName} $inlineText val <__S> $traversalType<__S, ${ele.sourceClassName}>.${focus.escapedParamName}: $traversalType<__S, ${focus.targetClassNameWithParameters}> $inlineText get() = this + $className.${focus.escapedParamName}
5353
|
5454
""".trimMargin()
5555
}
@@ -58,9 +58,9 @@ private fun OpticsProcessorOptions.processLensSyntax(ele: ADT, foci: List<Focus>
5858
val joinedTypeParams = ele.typeParameters.joinToString(separator = ",")
5959
foci.joinToString(separator = "\n") { focus ->
6060
"""
61-
|${ele.visibilityModifierName} $inlineText fun <__S,$joinedTypeParams> $lensType<__S, $sourceClassNameWithParams>.${focus.escapedParamName}(): $lensType<__S, ${focus.classNameWithParameters}> = this + $className.${focus.escapedParamName}()
62-
|${ele.visibilityModifierName} $inlineText fun <__S,$joinedTypeParams> $optionalType<__S, $sourceClassNameWithParams>.${focus.escapedParamName}(): $optionalType<__S, ${focus.classNameWithParameters}> = this + $className.${focus.escapedParamName}()
63-
|${ele.visibilityModifierName} $inlineText fun <__S,$joinedTypeParams> $traversalType<__S, $sourceClassNameWithParams>.${focus.escapedParamName}(): $traversalType<__S, ${focus.classNameWithParameters}> = this + $className.${focus.escapedParamName}()
61+
|${ele.visibilityModifierName} $inlineText fun <__S,$joinedTypeParams> $lensType<__S, $sourceClassNameWithParams>.${focus.escapedParamName}(): $lensType<__S, ${focus.targetClassNameWithParameters}> = this + $className.${focus.escapedParamName}()
62+
|${ele.visibilityModifierName} $inlineText fun <__S,$joinedTypeParams> $optionalType<__S, $sourceClassNameWithParams>.${focus.escapedParamName}(): $optionalType<__S, ${focus.targetClassNameWithParameters}> = this + $className.${focus.escapedParamName}()
63+
|${ele.visibilityModifierName} $inlineText fun <__S,$joinedTypeParams> $traversalType<__S, $sourceClassNameWithParams>.${focus.escapedParamName}(): $traversalType<__S, ${focus.targetClassNameWithParameters}> = this + $className.${focus.escapedParamName}()
6464
|
6565
""".trimMargin()
6666
}
@@ -69,23 +69,24 @@ private fun OpticsProcessorOptions.processLensSyntax(ele: ADT, foci: List<Focus>
6969
private fun OpticsProcessorOptions.processPrismSyntax(ele: ADT, dsl: SealedClassDsl, className: String, optionalType: String, prismType: String, traversalType: String): String = if (ele.typeParameters.isEmpty()) {
7070
dsl.foci.joinToString(separator = "\n\n") { focus ->
7171
"""
72-
|${ele.visibilityModifierName} $inlineText val <__S> $optionalType<__S, ${ele.sourceClassName}>.${focus.escapedParamName}: $optionalType<__S, ${focus.classNameWithParameters}> $inlineText get() = this + $className.${focus.escapedParamName}
73-
|${ele.visibilityModifierName} $inlineText val <__S> $prismType<__S, ${ele.sourceClassName}>.${focus.escapedParamName}: $prismType<__S, ${focus.classNameWithParameters}> $inlineText get() = this + $className.${focus.escapedParamName}
74-
|${ele.visibilityModifierName} $inlineText val <__S> $traversalType<__S, ${ele.sourceClassName}>.${focus.escapedParamName}: $traversalType<__S, ${focus.classNameWithParameters}> $inlineText get() = this + $className.${focus.escapedParamName}
72+
|${ele.visibilityModifierName} $inlineText val <__S> $optionalType<__S, ${ele.sourceClassName}>.${focus.escapedParamName}: $optionalType<__S, ${focus.targetClassNameWithParameters}> $inlineText get() = this + $className.${focus.escapedParamName}
73+
|${ele.visibilityModifierName} $inlineText val <__S> $prismType<__S, ${ele.sourceClassName}>.${focus.escapedParamName}: $prismType<__S, ${focus.targetClassNameWithParameters}> $inlineText get() = this + $className.${focus.escapedParamName}
74+
|${ele.visibilityModifierName} $inlineText val <__S> $traversalType<__S, ${ele.sourceClassName}>.${focus.escapedParamName}: $traversalType<__S, ${focus.targetClassNameWithParameters}> $inlineText get() = this + $className.${focus.escapedParamName}
7575
|
7676
""".trimMargin()
7777
}
7878
} else {
7979
dsl.foci.joinToString(separator = "\n\n") { focus ->
8080
val sourceClassNameWithParams = focus.refinedType?.qualifiedString() ?: "${ele.sourceClassName}${ele.angledTypeParameterNames}"
81+
val allTypeParams = focus.refinedArguments.union(focus.targetTypeParameters.orEmpty())
8182
val joinedTypeParams = when {
82-
focus.refinedArguments.isEmpty() -> ""
83-
else -> focus.refinedArguments.joinToString(separator = ",")
83+
allTypeParams.isEmpty() -> ""
84+
else -> allTypeParams.joinToString(separator = ",")
8485
}
8586
"""
86-
|${ele.visibilityModifierName} $inlineText fun <__S,$joinedTypeParams> $optionalType<__S, $sourceClassNameWithParams>.${focus.escapedParamName}(): $optionalType<__S, ${focus.classNameWithParameters}> = this + $className.${focus.escapedParamName}()
87-
|${ele.visibilityModifierName} $inlineText fun <__S,$joinedTypeParams> $prismType<__S, $sourceClassNameWithParams>.${focus.escapedParamName}(): $prismType<__S, ${focus.classNameWithParameters}> = this + $className.${focus.escapedParamName}()
88-
|${ele.visibilityModifierName} $inlineText fun <__S,$joinedTypeParams> $traversalType<__S, $sourceClassNameWithParams>.${focus.escapedParamName}(): $traversalType<__S, ${focus.classNameWithParameters}> = this + $className.${focus.escapedParamName}()
87+
|${ele.visibilityModifierName} $inlineText fun <__S,$joinedTypeParams> $optionalType<__S, $sourceClassNameWithParams>.${focus.escapedParamName}(): $optionalType<__S, ${focus.targetClassNameWithParameters}> = this + $className.${focus.escapedParamName}()
88+
|${ele.visibilityModifierName} $inlineText fun <__S,$joinedTypeParams> $prismType<__S, $sourceClassNameWithParams>.${focus.escapedParamName}(): $prismType<__S, ${focus.targetClassNameWithParameters}> = this + $className.${focus.escapedParamName}()
89+
|${ele.visibilityModifierName} $inlineText fun <__S,$joinedTypeParams> $traversalType<__S, $sourceClassNameWithParams>.${focus.escapedParamName}(): $traversalType<__S, ${focus.targetClassNameWithParameters}> = this + $className.${focus.escapedParamName}()
8990
|
9091
""".trimMargin()
9192
}
@@ -94,11 +95,11 @@ private fun OpticsProcessorOptions.processPrismSyntax(ele: ADT, dsl: SealedClass
9495
private fun OpticsProcessorOptions.processIsoSyntax(ele: ADT, dsl: ValueClassDsl, className: String, isoType: String, lensType: String, optionalType: String, prismType: String, traversalType: String): String = if (ele.typeParameters.isEmpty()) {
9596
dsl.foci.joinToString(separator = "\n\n") { focus ->
9697
"""
97-
|${ele.visibilityModifierName} $inlineText val <__S> $isoType<__S, ${ele.sourceClassName}>.${focus.escapedParamName}: $isoType<__S, ${focus.classNameWithParameters}> $inlineText get() = this + $className.${focus.escapedParamName}
98-
|${ele.visibilityModifierName} $inlineText val <__S> $lensType<__S, ${ele.sourceClassName}>.${focus.escapedParamName}: $lensType<__S, ${focus.classNameWithParameters}> $inlineText get() = this + $className.${focus.escapedParamName}
99-
|${ele.visibilityModifierName} $inlineText val <__S> $optionalType<__S, ${ele.sourceClassName}>.${focus.escapedParamName}: $optionalType<__S, ${focus.classNameWithParameters}> $inlineText get() = this + $className.${focus.escapedParamName}
100-
|${ele.visibilityModifierName} $inlineText val <__S> $prismType<__S, ${ele.sourceClassName}>.${focus.escapedParamName}: $prismType<__S, ${focus.classNameWithParameters}> $inlineText get() = this + $className.${focus.escapedParamName}
101-
|${ele.visibilityModifierName} $inlineText val <__S> $traversalType<__S, ${ele.sourceClassName}>.${focus.escapedParamName}: $traversalType<__S, ${focus.classNameWithParameters}> $inlineText get() = this + $className.${focus.escapedParamName}
98+
|${ele.visibilityModifierName} $inlineText val <__S> $isoType<__S, ${ele.sourceClassName}>.${focus.escapedParamName}: $isoType<__S, ${focus.targetClassNameWithParameters}> $inlineText get() = this + $className.${focus.escapedParamName}
99+
|${ele.visibilityModifierName} $inlineText val <__S> $lensType<__S, ${ele.sourceClassName}>.${focus.escapedParamName}: $lensType<__S, ${focus.targetClassNameWithParameters}> $inlineText get() = this + $className.${focus.escapedParamName}
100+
|${ele.visibilityModifierName} $inlineText val <__S> $optionalType<__S, ${ele.sourceClassName}>.${focus.escapedParamName}: $optionalType<__S, ${focus.targetClassNameWithParameters}> $inlineText get() = this + $className.${focus.escapedParamName}
101+
|${ele.visibilityModifierName} $inlineText val <__S> $prismType<__S, ${ele.sourceClassName}>.${focus.escapedParamName}: $prismType<__S, ${focus.targetClassNameWithParameters}> $inlineText get() = this + $className.${focus.escapedParamName}
102+
|${ele.visibilityModifierName} $inlineText val <__S> $traversalType<__S, ${ele.sourceClassName}>.${focus.escapedParamName}: $traversalType<__S, ${focus.targetClassNameWithParameters}> $inlineText get() = this + $className.${focus.escapedParamName}
102103
|
103104
""".trimMargin()
104105
}
@@ -110,11 +111,11 @@ private fun OpticsProcessorOptions.processIsoSyntax(ele: ADT, dsl: ValueClassDsl
110111
else -> focus.refinedArguments.joinToString(separator = ",")
111112
}
112113
"""
113-
|${ele.visibilityModifierName} $inlineText fun <__S,$joinedTypeParams> $isoType<__S, $sourceClassNameWithParams>.${focus.escapedParamName}(): $isoType<__S, ${focus.classNameWithParameters}> = this + $className.${focus.escapedParamName}()
114-
|${ele.visibilityModifierName} $inlineText fun <__S,$joinedTypeParams> $lensType<__S, $sourceClassNameWithParams>.${focus.escapedParamName}(): $lensType<__S, ${focus.classNameWithParameters}> = this + $className.${focus.escapedParamName}()
115-
|${ele.visibilityModifierName} $inlineText fun <__S,$joinedTypeParams> $optionalType<__S, $sourceClassNameWithParams>.${focus.escapedParamName}(): $optionalType<__S, ${focus.classNameWithParameters}> = this + $className.${focus.escapedParamName}()
116-
|${ele.visibilityModifierName} $inlineText fun <__S,$joinedTypeParams> $prismType<__S, $sourceClassNameWithParams>.${focus.escapedParamName}(): $prismType<__S, ${focus.classNameWithParameters}> = this + $className.${focus.escapedParamName}()
117-
|${ele.visibilityModifierName} $inlineText fun <__S,$joinedTypeParams> $traversalType<__S, $sourceClassNameWithParams>.${focus.escapedParamName}(): $traversalType<__S, ${focus.classNameWithParameters}> = this + $className.${focus.escapedParamName}()
114+
|${ele.visibilityModifierName} $inlineText fun <__S,$joinedTypeParams> $isoType<__S, $sourceClassNameWithParams>.${focus.escapedParamName}(): $isoType<__S, ${focus.targetClassNameWithParameters}> = this + $className.${focus.escapedParamName}()
115+
|${ele.visibilityModifierName} $inlineText fun <__S,$joinedTypeParams> $lensType<__S, $sourceClassNameWithParams>.${focus.escapedParamName}(): $lensType<__S, ${focus.targetClassNameWithParameters}> = this + $className.${focus.escapedParamName}()
116+
|${ele.visibilityModifierName} $inlineText fun <__S,$joinedTypeParams> $optionalType<__S, $sourceClassNameWithParams>.${focus.escapedParamName}(): $optionalType<__S, ${focus.targetClassNameWithParameters}> = this + $className.${focus.escapedParamName}()
117+
|${ele.visibilityModifierName} $inlineText fun <__S,$joinedTypeParams> $prismType<__S, $sourceClassNameWithParams>.${focus.escapedParamName}(): $prismType<__S, ${focus.targetClassNameWithParameters}> = this + $className.${focus.escapedParamName}()
118+
|${ele.visibilityModifierName} $inlineText fun <__S,$joinedTypeParams> $traversalType<__S, $sourceClassNameWithParams>.${focus.escapedParamName}(): $traversalType<__S, ${focus.targetClassNameWithParameters}> = this + $className.${focus.escapedParamName}()
118119
|
119120
""".trimMargin()
120121
}

arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/lenses.kt

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,14 @@ private fun OpticsProcessorOptions.processElement(adt: ADT, foci: List<Focus>, l
2121
} = value)"""
2222
when {
2323
focus.subclasses.isNotEmpty() -> {
24-
"""when(${adt.sourceName}) {
24+
// there's a cast with generics
25+
val needsCast = focus.subclasses.any { it.contains("<*") }
26+
val suppress = if (needsCast) """@Suppress("UNCHECKED_CAST")""" else ""
27+
val cast = if (needsCast) " as $sourceClassNameWithParams" else ""
28+
"""$suppress
29+
|when(${adt.sourceName}) {
2530
|${focus.subclasses.joinToString(separator = "\n") { "is $it -> $setBodyCopy" }}
26-
|}
31+
|}$cast
2732
|
2833
""".trimMargin()
2934
}

arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/prism.kt

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,17 @@ internal fun OpticsProcessorOptions.generatePrisms(ele: ADT, target: PrismTarget
1313
private fun OpticsProcessorOptions.processElement(ele: ADT, foci: List<Focus>): String = foci.joinToString(separator = "\n\n") { focus ->
1414
val sourceClassNameWithParams =
1515
focus.refinedType?.qualifiedString() ?: "${ele.sourceClassName}${ele.angledTypeParameterNames}"
16+
val joinedTypeParams = focus.refinedArguments.union(focus.targetTypeParameters.orEmpty())
1617
val angledTypeParameters = when {
17-
focus.refinedArguments.isEmpty() -> ""
18-
else -> focus.refinedArguments.joinToString(prefix = "<", separator = ",", postfix = ">")
18+
joinedTypeParams.isEmpty() -> ""
19+
else -> joinedTypeParams.joinToString(prefix = "<", separator = ",", postfix = ">")
1920
}
2021
val firstLine = when {
2122
ele.typeParameters.isEmpty() ->
22-
"${ele.visibilityModifierName} $inlineText val ${ele.sourceClassName}.Companion.${focus.escapedParamName}: $Prism<${ele.sourceClassName}, ${focus.classNameWithParameters}> $inlineText get()"
23+
"${ele.visibilityModifierName} $inlineText val ${ele.sourceClassName}.Companion.${focus.escapedParamName}: $Prism<${ele.sourceClassName}, ${focus.targetClassNameWithParameters}> $inlineText get()"
2324

2425
else ->
25-
"${ele.visibilityModifierName} $inlineText fun $angledTypeParameters ${ele.sourceClassName}.Companion.${focus.escapedParamName}(): $Prism<$sourceClassNameWithParams, ${focus.classNameWithParameters}>"
26+
"${ele.visibilityModifierName} $inlineText fun $angledTypeParameters ${ele.sourceClassName}.Companion.${focus.escapedParamName}(): $Prism<$sourceClassNameWithParams, ${focus.targetClassNameWithParameters}>"
2627
}
2728
"$firstLine = $Prism.instanceOf()"
2829
}

arrow-libs/optics/arrow-optics-ksp-plugin/src/main/kotlin/arrow/optics/plugin/internals/processor.kt

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ internal fun evalAnnotatedPrismElement(
8585
subclass.simpleName.asString().replaceFirstChar { c -> c.lowercase(Locale.getDefault()) },
8686
subclass.superTypes.first().resolve(),
8787
onlyOneSealedSubclass = sealedSubclasses.size == 1,
88-
classNameWithParameters = subclass.qualifiedNameOrSimpleNameWithTypeParameters,
88+
targetClassNameWithParameters = subclass.qualifiedNameOrSimpleNameWithTypeParameters,
89+
targetTypeParameters = subclass.typeParameters.map { it.simpleName.asString() },
8990
)
9091
}
9192
}
@@ -105,6 +106,12 @@ internal val KSClassDeclaration.qualifiedNameOrSimpleNameWithTypeParameters: Str
105106
else -> "$qualifiedNameOrSimpleName<${typeParameters.joinToString { it.simpleName.asString() }}>"
106107
}
107108

109+
internal val KSClassDeclaration.qualifiedNameOrSimpleNameWithStars: String
110+
get() = when {
111+
typeParameters.isEmpty() -> qualifiedNameOrSimpleName
112+
else -> "$qualifiedNameOrSimpleName<${typeParameters.joinToString { "*" }}>"
113+
}
114+
108115
internal fun evalAnnotatedDataClass(
109116
element: KSClassDeclaration,
110117
errorMessage: String,
@@ -159,7 +166,7 @@ internal fun evalAnnotatedDataClass(
159166
Focus(
160167
fullName = type,
161168
paramName = name,
162-
subclasses = subclasses.map { it.qualifiedNameOrSimpleName }.toList(),
169+
subclasses = subclasses.map { it.qualifiedNameOrSimpleNameWithStars }.toList(),
163170
)
164171
}
165172
.toList()

arrow-libs/optics/arrow-optics-ksp-plugin/src/test/kotlin/arrow/optics/plugin/DSLTests.kt

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,4 +286,24 @@ class DSLTests {
286286
)
287287
compilationSucceeds(allWarningsAsErrors = false, contextParameters = false, source1, source2)
288288
}
289+
290+
@Test
291+
fun `Complicated hierarchy with generics (#3735)`() {
292+
"""
293+
|$`package`
294+
|$imports
295+
|
296+
|@optics
297+
|sealed interface Test<T> {
298+
| val value: String
299+
|
300+
| data class Test1<T>(override val value: String) : Test<T>
301+
| data class Test2(override val value: String) : Test<Int>
302+
| data class Test3<T, A>(override val value: String) : Test<T>
303+
| data class Test4<B>(override val value: String) : Test<List<B>>
304+
305+
| companion object
306+
|}
307+
""".compilationSucceeds()
308+
}
289309
}

0 commit comments

Comments
 (0)