Skip to content

Commit 7d16776

Browse files
mike0sv authored and srowen committed
[SPARK-21255][SQL][WIP] Fixed NPE when creating encoder for enum
## What changes were proposed in this pull request? Fixed NPE when creating encoder for enum. When you try to create an encoder for Enum type (or bean with enum property) via Encoders.bean(...), it fails with NullPointerException at TypeToken:495. I did a little research and it turns out, that in JavaTypeInference following code ``` def getJavaBeanReadableProperties(beanClass: Class[_]): Array[PropertyDescriptor] = { val beanInfo = Introspector.getBeanInfo(beanClass) beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") .filter(_.getReadMethod != null) } ``` filters out properties named "class", because we wouldn't want to serialize that. But enum types have another property of type Class named "declaringClass", which we are trying to inspect recursively. Eventually we try to inspect ClassLoader class, which has property "defaultAssertionStatus" with no read method, which leads to NPE at TypeToken:495. I added property name "declaringClass" to filtering to resolve this. ## How was this patch tested? Unit test in JavaDatasetSuite which creates an encoder for enum Author: mike <[email protected]> Author: Mikhail Sveshnikov <[email protected]> Closes apache#18488 from mike0sv/enum-support.
1 parent f3676d6 commit 7d16776

File tree

4 files changed

+131
-4
lines changed

4 files changed

+131
-4
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.objects._
3232
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
3333
import org.apache.spark.sql.types._
3434
import org.apache.spark.unsafe.types.UTF8String
35+
import org.apache.spark.util.Utils
3536

3637
/**
3738
* Type-inference utilities for POJOs and Java collections.
@@ -118,6 +119,10 @@ object JavaTypeInference {
118119
val (valueDataType, nullable) = inferDataType(valueType, seenTypeSet)
119120
(MapType(keyDataType, valueDataType, nullable), true)
120121

122+
case other if other.isEnum =>
123+
(StructType(Seq(StructField(typeToken.getRawType.getSimpleName,
124+
StringType, nullable = false))), true)
125+
121126
case other =>
122127
if (seenTypeSet.contains(other)) {
123128
throw new UnsupportedOperationException(
@@ -140,6 +145,7 @@ object JavaTypeInference {
140145
def getJavaBeanReadableProperties(beanClass: Class[_]): Array[PropertyDescriptor] = {
141146
val beanInfo = Introspector.getBeanInfo(beanClass)
142147
beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
148+
.filterNot(_.getName == "declaringClass")
143149
.filter(_.getReadMethod != null)
144150
}
145151

@@ -303,6 +309,11 @@ object JavaTypeInference {
303309
keyData :: valueData :: Nil,
304310
returnNullable = false)
305311

312+
case other if other.isEnum =>
313+
StaticInvoke(JavaTypeInference.getClass, ObjectType(other), "deserializeEnumName",
314+
expressions.Literal.create(other.getEnumConstants.apply(0), ObjectType(other))
315+
:: getPath :: Nil)
316+
306317
case other =>
307318
val properties = getJavaBeanReadableAndWritableProperties(other)
308319
val setters = properties.map { p =>
@@ -345,6 +356,30 @@ object JavaTypeInference {
345356
}
346357
}
347358

359+
/** Returns a mapping from enum value to int for given enum type */
360+
def enumSerializer[T <: Enum[T]](enum: Class[T]): T => UTF8String = {
361+
assert(enum.isEnum)
362+
inputObject: T =>
363+
UTF8String.fromString(inputObject.name())
364+
}
365+
366+
/** Returns value index for given enum type and value */
367+
def serializeEnumName[T <: Enum[T]](enum: UTF8String, inputObject: T): UTF8String = {
368+
enumSerializer(Utils.classForName(enum.toString).asInstanceOf[Class[T]])(inputObject)
369+
}
370+
371+
/** Returns a mapping from int to enum value for given enum type */
372+
def enumDeserializer[T <: Enum[T]](enum: Class[T]): InternalRow => T = {
373+
assert(enum.isEnum)
374+
value: InternalRow =>
375+
Enum.valueOf(enum, value.getUTF8String(0).toString)
376+
}
377+
378+
/** Returns enum value for given enum type and value index */
379+
def deserializeEnumName[T <: Enum[T]](typeDummy: T, inputObject: InternalRow): T = {
380+
enumDeserializer(typeDummy.getClass.asInstanceOf[Class[T]])(inputObject)
381+
}
382+
348383
private def serializerFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = {
349384

350385
def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = {
@@ -429,6 +464,11 @@ object JavaTypeInference {
429464
valueNullable = true
430465
)
431466

467+
case other if other.isEnum =>
468+
CreateNamedStruct(expressions.Literal("enum") ::
469+
StaticInvoke(JavaTypeInference.getClass, StringType, "serializeEnumName",
470+
expressions.Literal.create(other.getName, StringType) :: inputObject :: Nil) :: Nil)
471+
432472
case other =>
433473
val properties = getJavaBeanReadableAndWritableProperties(other)
434474
val nonNullOutput = CreateNamedStruct(properties.flatMap { p =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection
2828
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, NewInstance}
2929
import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
3030
import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LocalRelation}
31-
import org.apache.spark.sql.types.{BooleanType, ObjectType, StructField, StructType}
31+
import org.apache.spark.sql.types.{BooleanType, DataType, ObjectType, StringType, StructField, StructType}
3232
import org.apache.spark.util.Utils
3333

3434
/**
@@ -81,9 +81,19 @@ object ExpressionEncoder {
8181
ClassTag[T](cls))
8282
}
8383

84+
def javaEnumSchema[T](beanClass: Class[T]): DataType = {
85+
StructType(Seq(StructField("enum",
86+
StructType(Seq(StructField(beanClass.getSimpleName, StringType, nullable = false))),
87+
nullable = false)))
88+
}
89+
8490
// TODO: improve error message for java bean encoder.
8591
def javaBean[T](beanClass: Class[T]): ExpressionEncoder[T] = {
86-
val schema = JavaTypeInference.inferDataType(beanClass)._1
92+
val schema = if (beanClass.isEnum) {
93+
javaEnumSchema(beanClass)
94+
} else {
95+
JavaTypeInference.inferDataType(beanClass)._1
96+
}
8797
assert(schema.isInstanceOf[StructType])
8898

8999
val serializer = JavaTypeInference.serializerFor(beanClass)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,13 +154,13 @@ case class StaticInvoke(
154154
val evaluate = if (returnNullable) {
155155
if (ctx.defaultValue(dataType) == "null") {
156156
s"""
157-
${ev.value} = $callFunc;
157+
${ev.value} = (($javaType) ($callFunc));
158158
${ev.isNull} = ${ev.value} == null;
159159
"""
160160
} else {
161161
val boxedResult = ctx.freshName("boxedResult")
162162
s"""
163-
${ctx.boxedType(dataType)} $boxedResult = $callFunc;
163+
${ctx.boxedType(dataType)} $boxedResult = (($javaType) ($callFunc));
164164
${ev.isNull} = $boxedResult == null;
165165
if (!${ev.isNull}) {
166166
${ev.value} = $boxedResult;

sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1283,6 +1283,83 @@ public void test() {
12831283
ds.collectAsList();
12841284
}
12851285

1286+
public enum EnumBean {
1287+
A("www.elgoog.com"),
1288+
B("www.google.com");
1289+
1290+
private String url;
1291+
1292+
EnumBean(String url) {
1293+
this.url = url;
1294+
}
1295+
1296+
public String getUrl() {
1297+
return url;
1298+
}
1299+
1300+
public void setUrl(String url) {
1301+
this.url = url;
1302+
}
1303+
}
1304+
1305+
@Test
1306+
public void testEnum() {
1307+
List<EnumBean> data = Arrays.asList(EnumBean.B);
1308+
Encoder<EnumBean> encoder = Encoders.bean(EnumBean.class);
1309+
Dataset<EnumBean> ds = spark.createDataset(data, encoder);
1310+
Assert.assertEquals(ds.collectAsList(), data);
1311+
}
1312+
1313+
public static class BeanWithEnum {
1314+
EnumBean enumField;
1315+
String regularField;
1316+
1317+
public String getRegularField() {
1318+
return regularField;
1319+
}
1320+
1321+
public void setRegularField(String regularField) {
1322+
this.regularField = regularField;
1323+
}
1324+
1325+
public EnumBean getEnumField() {
1326+
return enumField;
1327+
}
1328+
1329+
public void setEnumField(EnumBean field) {
1330+
this.enumField = field;
1331+
}
1332+
1333+
public BeanWithEnum(EnumBean enumField, String regularField) {
1334+
this.enumField = enumField;
1335+
this.regularField = regularField;
1336+
}
1337+
1338+
public BeanWithEnum() {
1339+
}
1340+
1341+
public String toString() {
1342+
return "BeanWithEnum(" + enumField + ", " + regularField + ")";
1343+
}
1344+
1345+
public boolean equals(Object other) {
1346+
if (other instanceof BeanWithEnum) {
1347+
BeanWithEnum beanWithEnum = (BeanWithEnum) other;
1348+
return beanWithEnum.regularField.equals(regularField) && beanWithEnum.enumField.equals(enumField);
1349+
}
1350+
return false;
1351+
}
1352+
}
1353+
1354+
@Test
1355+
public void testBeanWithEnum() {
1356+
List<BeanWithEnum> data = Arrays.asList(new BeanWithEnum(EnumBean.A, "mira avenue"),
1357+
new BeanWithEnum(EnumBean.B, "flower boulevard"));
1358+
Encoder<BeanWithEnum> encoder = Encoders.bean(BeanWithEnum.class);
1359+
Dataset<BeanWithEnum> ds = spark.createDataset(data, encoder);
1360+
Assert.assertEquals(ds.collectAsList(), data);
1361+
}
1362+
12861363
public static class EmptyBean implements Serializable {}
12871364

12881365
@Test

0 commit comments

Comments
 (0)