Skip to content

Commit 6327ea5

Browse files
cloud-fan authored and gatorsmile committed
[SPARK-21255][SQL] simplify encoder for java enum
## What changes were proposed in this pull request?

This is a follow-up for apache#18488, to simplify the code.

The major change is that a Java enum is now mapped to string type, instead of a struct type with a single string field.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <[email protected]>

Closes apache#19066 from cloud-fan/fix.
1 parent 8fcbda9 commit 6327ea5

File tree

4 files changed

+25
-63
lines changed

4 files changed

+25
-63
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala

Lines changed: 13 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,6 @@ import org.apache.spark.sql.catalyst.expressions.objects._
3232
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
3333
import org.apache.spark.sql.types._
3434
import org.apache.spark.unsafe.types.UTF8String
35-
import org.apache.spark.util.Utils
3635

3736
/**
3837
* Type-inference utilities for POJOs and Java collections.
@@ -120,8 +119,7 @@ object JavaTypeInference {
120119
(MapType(keyDataType, valueDataType, nullable), true)
121120

122121
case other if other.isEnum =>
123-
(StructType(Seq(StructField(typeToken.getRawType.getSimpleName,
124-
StringType, nullable = false))), true)
122+
(StringType, true)
125123

126124
case other =>
127125
if (seenTypeSet.contains(other)) {
@@ -310,9 +308,12 @@ object JavaTypeInference {
310308
returnNullable = false)
311309

312310
case other if other.isEnum =>
313-
StaticInvoke(JavaTypeInference.getClass, ObjectType(other), "deserializeEnumName",
314-
expressions.Literal.create(other.getEnumConstants.apply(0), ObjectType(other))
315-
:: getPath :: Nil)
311+
StaticInvoke(
312+
other,
313+
ObjectType(other),
314+
"valueOf",
315+
Invoke(getPath, "toString", ObjectType(classOf[String]), returnNullable = false) :: Nil,
316+
returnNullable = false)
316317

317318
case other =>
318319
val properties = getJavaBeanReadableAndWritableProperties(other)
@@ -356,30 +357,6 @@ object JavaTypeInference {
356357
}
357358
}
358359

359-
/** Returns a mapping from enum value to int for given enum type */
360-
def enumSerializer[T <: Enum[T]](enum: Class[T]): T => UTF8String = {
361-
assert(enum.isEnum)
362-
inputObject: T =>
363-
UTF8String.fromString(inputObject.name())
364-
}
365-
366-
/** Returns value index for given enum type and value */
367-
def serializeEnumName[T <: Enum[T]](enum: UTF8String, inputObject: T): UTF8String = {
368-
enumSerializer(Utils.classForName(enum.toString).asInstanceOf[Class[T]])(inputObject)
369-
}
370-
371-
/** Returns a mapping from int to enum value for given enum type */
372-
def enumDeserializer[T <: Enum[T]](enum: Class[T]): InternalRow => T = {
373-
assert(enum.isEnum)
374-
value: InternalRow =>
375-
Enum.valueOf(enum, value.getUTF8String(0).toString)
376-
}
377-
378-
/** Returns enum value for given enum type and value index */
379-
def deserializeEnumName[T <: Enum[T]](typeDummy: T, inputObject: InternalRow): T = {
380-
enumDeserializer(typeDummy.getClass.asInstanceOf[Class[T]])(inputObject)
381-
}
382-
383360
private def serializerFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = {
384361

385362
def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = {
@@ -465,9 +442,12 @@ object JavaTypeInference {
465442
)
466443

467444
case other if other.isEnum =>
468-
CreateNamedStruct(expressions.Literal("enum") ::
469-
StaticInvoke(JavaTypeInference.getClass, StringType, "serializeEnumName",
470-
expressions.Literal.create(other.getName, StringType) :: inputObject :: Nil) :: Nil)
445+
StaticInvoke(
446+
classOf[UTF8String],
447+
StringType,
448+
"fromString",
449+
Invoke(inputObject, "name", ObjectType(classOf[String]), returnNullable = false) :: Nil,
450+
returnNullable = false)
471451

472452
case other =>
473453
val properties = getJavaBeanReadableAndWritableProperties(other)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala

Lines changed: 2 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection
2828
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, NewInstance}
2929
import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
3030
import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LocalRelation}
31-
import org.apache.spark.sql.types.{BooleanType, DataType, ObjectType, StringType, StructField, StructType}
31+
import org.apache.spark.sql.types.{BooleanType, ObjectType, StructField, StructType}
3232
import org.apache.spark.util.Utils
3333

3434
/**
@@ -81,19 +81,9 @@ object ExpressionEncoder {
8181
ClassTag[T](cls))
8282
}
8383

84-
def javaEnumSchema[T](beanClass: Class[T]): DataType = {
85-
StructType(Seq(StructField("enum",
86-
StructType(Seq(StructField(beanClass.getSimpleName, StringType, nullable = false))),
87-
nullable = false)))
88-
}
89-
9084
// TODO: improve error message for java bean encoder.
9185
def javaBean[T](beanClass: Class[T]): ExpressionEncoder[T] = {
92-
val schema = if (beanClass.isEnum) {
93-
javaEnumSchema(beanClass)
94-
} else {
95-
JavaTypeInference.inferDataType(beanClass)._1
96-
}
86+
val schema = JavaTypeInference.inferDataType(beanClass)._1
9787
assert(schema.isInstanceOf[StructType])
9888

9989
val serializer = JavaTypeInference.serializerFor(beanClass)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -154,13 +154,13 @@ case class StaticInvoke(
154154
val evaluate = if (returnNullable) {
155155
if (ctx.defaultValue(dataType) == "null") {
156156
s"""
157-
${ev.value} = (($javaType) ($callFunc));
157+
${ev.value} = $callFunc;
158158
${ev.isNull} = ${ev.value} == null;
159159
"""
160160
} else {
161161
val boxedResult = ctx.freshName("boxedResult")
162162
s"""
163-
${ctx.boxedType(dataType)} $boxedResult = (($javaType) ($callFunc));
163+
${ctx.boxedType(dataType)} $boxedResult = $callFunc;
164164
${ev.isNull} = $boxedResult == null;
165165
if (!${ev.isNull}) {
166166
${ev.value} = $boxedResult;

sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java

Lines changed: 8 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -1283,13 +1283,13 @@ public void test() {
12831283
ds.collectAsList();
12841284
}
12851285

1286-
public enum EnumBean {
1286+
public enum MyEnum {
12871287
A("www.elgoog.com"),
12881288
B("www.google.com");
12891289

12901290
private String url;
12911291

1292-
EnumBean(String url) {
1292+
MyEnum(String url) {
12931293
this.url = url;
12941294
}
12951295

@@ -1302,16 +1302,8 @@ public void setUrl(String url) {
13021302
}
13031303
}
13041304

1305-
@Test
1306-
public void testEnum() {
1307-
List<EnumBean> data = Arrays.asList(EnumBean.B);
1308-
Encoder<EnumBean> encoder = Encoders.bean(EnumBean.class);
1309-
Dataset<EnumBean> ds = spark.createDataset(data, encoder);
1310-
Assert.assertEquals(ds.collectAsList(), data);
1311-
}
1312-
13131305
public static class BeanWithEnum {
1314-
EnumBean enumField;
1306+
MyEnum enumField;
13151307
String regularField;
13161308

13171309
public String getRegularField() {
@@ -1322,15 +1314,15 @@ public void setRegularField(String regularField) {
13221314
this.regularField = regularField;
13231315
}
13241316

1325-
public EnumBean getEnumField() {
1317+
public MyEnum getEnumField() {
13261318
return enumField;
13271319
}
13281320

1329-
public void setEnumField(EnumBean field) {
1321+
public void setEnumField(MyEnum field) {
13301322
this.enumField = field;
13311323
}
13321324

1333-
public BeanWithEnum(EnumBean enumField, String regularField) {
1325+
public BeanWithEnum(MyEnum enumField, String regularField) {
13341326
this.enumField = enumField;
13351327
this.regularField = regularField;
13361328
}
@@ -1353,8 +1345,8 @@ public boolean equals(Object other) {
13531345

13541346
@Test
13551347
public void testBeanWithEnum() {
1356-
List<BeanWithEnum> data = Arrays.asList(new BeanWithEnum(EnumBean.A, "mira avenue"),
1357-
new BeanWithEnum(EnumBean.B, "flower boulevard"));
1348+
List<BeanWithEnum> data = Arrays.asList(new BeanWithEnum(MyEnum.A, "mira avenue"),
1349+
new BeanWithEnum(MyEnum.B, "flower boulevard"));
13581350
Encoder<BeanWithEnum> encoder = Encoders.bean(BeanWithEnum.class);
13591351
Dataset<BeanWithEnum> ds = spark.createDataset(data, encoder);
13601352
Assert.assertEquals(ds.collectAsList(), data);

0 commit comments

Comments
 (0)