Skip to content

Commit 1bc6723

Browse files
MaxGekk authored and cloud-fan committed
[SPARK-27344][SQL][TEST] Support the LocalDate and Instant classes in Java Bean encoders
## What changes were proposed in this pull request? - Added new test for Java Bean encoder of the classes: `java.time.LocalDate` and `java.time.Instant`. - Updated comment for `Encoders.bean` - New Row getters: `getLocalDate` and `getInstant` - Extended `inferDataType` to infer types for `java.time.LocalDate` -> `DateType` and `java.time.Instant` -> `TimestampType`. ## How was this patch tested? By `JavaBeanDeserializationSuite` Closes apache#24273 from MaxGekk/bean-instant-localdate. Lead-authored-by: Maxim Gekk <[email protected]> Co-authored-by: Maxim Gekk <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 3286bff commit 1bc6723

File tree

4 files changed

+113
-1
lines changed

4 files changed

+113
-1
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ object Encoders {
149149
* - boxed types: Boolean, Integer, Double, etc.
150150
* - String
151151
* - java.math.BigDecimal, java.math.BigInteger
152-
* - time related: java.sql.Date, java.sql.Timestamp
152+
* - time related: java.sql.Date, java.sql.Timestamp, java.time.LocalDate, java.time.Instant
153153
* - collection types: only array and java.util.List currently, map support is in progress
154154
* - nested java bean.
155155
*

sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,13 +269,27 @@ trait Row extends Serializable {
269269
*/
270270
def getDate(i: Int): java.sql.Date = getAs[java.sql.Date](i)
271271

272+
/**
273+
* Returns the value at position i of date type as java.time.LocalDate.
274+
*
275+
* @throws ClassCastException when data type does not match.
276+
*/
277+
def getLocalDate(i: Int): java.time.LocalDate = getAs[java.time.LocalDate](i)
278+
272279
/**
273280
* Returns the value at position i of timestamp type as java.sql.Timestamp.
274281
*
275282
* @throws ClassCastException when data type does not match.
276283
*/
277284
def getTimestamp(i: Int): java.sql.Timestamp = getAs[java.sql.Timestamp](i)
278285

286+
/**
287+
* Returns the value at position i of timestamp type as java.time.Instant.
288+
*
289+
* @throws ClassCastException when data type does not match.
290+
*/
291+
def getInstant(i: Int): java.time.Instant = getAs[java.time.Instant](i)
292+
279293
/**
280294
* Returns the value at position i of array type as a Scala Seq.
281295
*

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,9 @@ object JavaTypeInference {
102102

103103
case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType.SYSTEM_DEFAULT, true)
104104
case c: Class[_] if c == classOf[java.math.BigInteger] => (DecimalType.BigIntDecimal, true)
105+
case c: Class[_] if c == classOf[java.time.LocalDate] => (DateType, true)
105106
case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
107+
case c: Class[_] if c == classOf[java.time.Instant] => (TimestampType, true)
106108
case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
107109

108110
case _ if typeToken.isArray =>

sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,15 @@
1818
package test.org.apache.spark.sql;
1919

2020
import java.io.Serializable;
21+
import java.time.Instant;
22+
import java.time.LocalDate;
2123
import java.util.*;
2224

2325
import org.apache.spark.sql.*;
2426
import org.apache.spark.sql.catalyst.expressions.GenericRow;
27+
import org.apache.spark.sql.catalyst.util.DateTimeUtils;
28+
import org.apache.spark.sql.catalyst.util.TimestampFormatter;
29+
import org.apache.spark.sql.internal.SQLConf;
2530
import org.apache.spark.sql.types.DataTypes;
2631
import org.apache.spark.sql.types.StructType;
2732
import org.junit.*;
@@ -509,4 +514,95 @@ public void setId(Integer id) {
509514
this.id = id;
510515
}
511516
}
517+
518+
@Test
519+
public void testBeanWithLocalDateAndInstant() {
520+
String originConf = spark.conf().get(SQLConf.DATETIME_JAVA8API_ENABLED().key());
521+
try {
522+
spark.conf().set(SQLConf.DATETIME_JAVA8API_ENABLED().key(), "true");
523+
List<Row> inputRows = new ArrayList<>();
524+
List<LocalDateInstantRecord> expectedRecords = new ArrayList<>();
525+
526+
for (long idx = 0 ; idx < 5 ; idx++) {
527+
Row row = createLocalDateInstantRow(idx);
528+
inputRows.add(row);
529+
expectedRecords.add(createLocalDateInstantRecord(row));
530+
}
531+
532+
Encoder<LocalDateInstantRecord> encoder = Encoders.bean(LocalDateInstantRecord.class);
533+
534+
StructType schema = new StructType()
535+
.add("localDateField", DataTypes.DateType)
536+
.add("instantField", DataTypes.TimestampType);
537+
538+
Dataset<Row> dataFrame = spark.createDataFrame(inputRows, schema);
539+
Dataset<LocalDateInstantRecord> dataset = dataFrame.as(encoder);
540+
541+
List<LocalDateInstantRecord> records = dataset.collectAsList();
542+
543+
Assert.assertEquals(expectedRecords, records);
544+
} finally {
545+
spark.conf().set(SQLConf.DATETIME_JAVA8API_ENABLED().key(), originConf);
546+
}
547+
}
548+
549+
public static final class LocalDateInstantRecord {
550+
private String localDateField;
551+
private String instantField;
552+
553+
public LocalDateInstantRecord() { }
554+
555+
public String getLocalDateField() {
556+
return localDateField;
557+
}
558+
559+
public void setLocalDateField(String localDateField) {
560+
this.localDateField = localDateField;
561+
}
562+
563+
public String getInstantField() {
564+
return instantField;
565+
}
566+
567+
public void setInstantField(String instantField) {
568+
this.instantField = instantField;
569+
}
570+
571+
@Override
572+
public boolean equals(Object o) {
573+
if (this == o) return true;
574+
if (o == null || getClass() != o.getClass()) return false;
575+
LocalDateInstantRecord that = (LocalDateInstantRecord) o;
576+
return Objects.equals(localDateField, that.localDateField) &&
577+
Objects.equals(instantField, that.instantField);
578+
}
579+
580+
@Override
581+
public int hashCode() {
582+
return Objects.hash(localDateField, instantField);
583+
}
584+
585+
@Override
586+
public String toString() {
587+
return com.google.common.base.Objects.toStringHelper(this)
588+
.add("localDateField", localDateField)
589+
.add("instantField", instantField)
590+
.toString();
591+
}
592+
}
593+
594+
private static Row createLocalDateInstantRow(Long index) {
595+
Object[] values = new Object[] { LocalDate.ofEpochDay(42), Instant.ofEpochSecond(42) };
596+
return new GenericRow(values);
597+
}
598+
599+
private static LocalDateInstantRecord createLocalDateInstantRecord(Row recordRow) {
600+
LocalDateInstantRecord record = new LocalDateInstantRecord();
601+
record.setLocalDateField(String.valueOf(recordRow.getLocalDate(0)));
602+
Instant instant = recordRow.getInstant(1);
603+
TimestampFormatter formatter = TimestampFormatter.getFractionFormatter(
604+
DateTimeUtils.getZoneId(SQLConf.get().sessionLocalTimeZone()));
605+
record.setInstantField(formatter.format(DateTimeUtils.instantToMicros(instant)));
606+
return record;
607+
}
512608
}

0 commit comments

Comments
 (0)