Skip to content

Commit 1b3bf1d

Browse files
committed
Support erased generics in toDataFrame conversion
1 parent 24fdb1b commit 1b3bf1d

File tree

3 files changed

+78
-25
lines changed

3 files changed

+78
-25
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/DataColumn.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,6 @@ public val AnyCol.indices: IntRange get() = indices()
133133

134134
public val AnyCol.type: KType get() = type()
135135
public val AnyCol.kind: ColumnKind get() = kind()
136-
public val AnyCol.typeClass: KClass<*> get() = type.classifier as KClass<*>
136+
public val AnyCol.typeClass: KClass<*> get() = type.classifier as? KClass<*> ?: error("Cannot cast ${type.classifier?.javaClass} to a ${KClass::class}. Column $name: $type")
137137

138138
public fun AnyBaseCol.indices(): IntRange = 0 until size()

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/toDataFrame.kt

Lines changed: 43 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import kotlin.reflect.full.isSubclassOf
2929
import kotlin.reflect.full.memberProperties
3030
import kotlin.reflect.full.withNullability
3131
import kotlin.reflect.jvm.javaField
32+
import kotlin.reflect.typeOf
3233

3334
internal val valueTypes = setOf(
3435
String::class,
@@ -163,38 +164,56 @@ internal fun convertToDataFrame(
163164
}
164165
}
165166

166-
val type = property.returnType
167-
val kclass = (type.classifier as KClass<*>)
167+
val returnType = property.returnType.let { type ->
168+
if (type.classifier is KClass<*>) {
169+
type
170+
} else {
171+
typeOf<Any>()
172+
}
173+
}
174+
val kclass = (returnType.classifier as KClass<*>)
168175
when {
169176
hasExceptions -> DataColumn.createWithTypeInference(it.columnName, values, nullable)
170-
preserveClasses.contains(kclass) || preserveProperties.contains(property) || (maxDepth <= 0 && !type.shouldBeConvertedToFrameColumn() && !type.shouldBeConvertedToColumnGroup()) || kclass.isValueType ->
171-
DataColumn.createValueColumn(it.columnName, values, property.returnType.withNullability(nullable))
177+
kclass == Any::class || preserveClasses.contains(kclass) || preserveProperties.contains(property) || (maxDepth <= 0 && !returnType.shouldBeConvertedToFrameColumn() && !returnType.shouldBeConvertedToColumnGroup()) || kclass.isValueType ->
178+
DataColumn.createValueColumn(it.columnName, values, returnType.withNullability(nullable))
172179
kclass == DataFrame::class && !nullable -> DataColumn.createFrameColumn(it.columnName, values as List<AnyFrame>)
173180
kclass == DataRow::class -> DataColumn.createColumnGroup(it.columnName, (values as List<AnyRow>).concat())
174181
kclass.isSubclassOf(Iterable::class) -> {
175-
val elementType = type.projectUpTo(Iterable::class).arguments.firstOrNull()?.type
176-
if (elementType == null) DataColumn.createValueColumn(
177-
it.columnName,
178-
values,
179-
property.returnType.withNullability(nullable)
180-
)
181-
else {
182-
val elementClass = (elementType.classifier as KClass<*>)
183-
if (elementClass.isValueType) {
184-
val listType = getListType(elementType).withNullability(nullable)
185-
val listValues = values.map {
186-
(it as? Iterable<*>)?.asList()
182+
val elementType = returnType.projectUpTo(Iterable::class).arguments.firstOrNull()?.type
183+
if (elementType == null) {
184+
DataColumn.createValueColumn(
185+
it.columnName,
186+
values,
187+
returnType.withNullability(nullable)
188+
)
189+
} else {
190+
val elementClass = (elementType.classifier as? KClass<*>)
191+
192+
when {
193+
elementClass == null -> {
194+
val listValues = values.map {
195+
(it as? Iterable<*>)?.asList()
196+
}
197+
198+
DataColumn.createWithTypeInference(it.columnName, listValues)
199+
}
200+
elementClass.isValueType -> {
201+
val listType = getListType(elementType).withNullability(nullable)
202+
val listValues = values.map {
203+
(it as? Iterable<*>)?.asList()
204+
}
205+
DataColumn.createValueColumn(it.columnName, listValues, listType)
187206
}
188-
DataColumn.createValueColumn(it.columnName, listValues, listType)
189-
} else {
190-
val frames = values.map {
191-
if (it == null) DataFrame.empty()
192-
else {
193-
require(it is Iterable<*>)
194-
convertToDataFrame(it, elementClass, emptyList(), excludes, preserveClasses, preserveProperties, maxDepth - 1)
207+
else -> {
208+
val frames = values.map {
209+
if (it == null) DataFrame.empty()
210+
else {
211+
require(it is Iterable<*>)
212+
convertToDataFrame(it, elementClass, emptyList(), excludes, preserveClasses, preserveProperties, maxDepth - 1)
213+
}
195214
}
215+
DataColumn.createFrameColumn(it.columnName, frames)
196216
}
197-
DataColumn.createFrameColumn(it.columnName, frames)
198217
}
199218
}
200219
}

tests/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/toDataFrame.kt

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,4 +149,38 @@ class CreateDataFrameTests {
149149
val child2 = child1.asColumnGroup()[Child1::child]
150150
child2.kind shouldBe ColumnKind.Value
151151
}
152+
153+
@Test
154+
fun inferredTypeForPropertyWithGenericIterableType() {
155+
class Container<E>(val data: Set<E>)
156+
157+
val element = Container(setOf(1))
158+
val value = listOf(element).toDataFrame(maxDepth = 10)
159+
160+
value["data"].type() shouldBe typeOf<List<Int>>()
161+
}
162+
163+
@Test
164+
fun inferredNullableTypeForPropertyWithGenericIterableType() {
165+
class Container<E>(val data: List<E>)
166+
167+
val element = Container(listOf(1, null))
168+
val value = listOf(element).toDataFrame(maxDepth = 10)
169+
170+
value["data"].type() shouldBe typeOf<List<Int?>>()
171+
}
172+
173+
@Suppress("unused")
174+
@Test
175+
fun treatErasedGenericAsAny() {
176+
class IncompatibleVersionErrorData<T>(val expected: T, val actual: T)
177+
class DeserializedContainerSource(val incompatibility: IncompatibleVersionErrorData<*>)
178+
val functions = listOf(DeserializedContainerSource(IncompatibleVersionErrorData(1, 2)))
179+
180+
val df = functions.toDataFrame(maxDepth = 2)
181+
182+
val col = df.getColumnGroup(DeserializedContainerSource::incompatibility)
183+
col[IncompatibleVersionErrorData<*>::actual].type() shouldBe typeOf<Any>()
184+
col[IncompatibleVersionErrorData<*>::expected].type() shouldBe typeOf<Any>()
185+
}
152186
}

0 commit comments

Comments
 (0)