|
18 | 18 | package test.org.apache.spark.sql;
|
19 | 19 |
|
20 | 20 | import java.io.Serializable;
|
| 21 | +import java.time.Instant; |
| 22 | +import java.time.LocalDate; |
21 | 23 | import java.util.*;
|
22 | 24 |
|
23 | 25 | import org.apache.spark.sql.*;
|
24 | 26 | import org.apache.spark.sql.catalyst.expressions.GenericRow;
|
| 27 | +import org.apache.spark.sql.catalyst.util.DateTimeUtils; |
| 28 | +import org.apache.spark.sql.catalyst.util.TimestampFormatter; |
| 29 | +import org.apache.spark.sql.internal.SQLConf; |
25 | 30 | import org.apache.spark.sql.types.DataTypes;
|
26 | 31 | import org.apache.spark.sql.types.StructType;
|
27 | 32 | import org.junit.*;
|
@@ -509,4 +514,95 @@ public void setId(Integer id) {
|
509 | 514 | this.id = id;
|
510 | 515 | }
|
511 | 516 | }
|
| 517 | + |
| 518 | + @Test |
| 519 | + public void testBeanWithLocalDateAndInstant() { |
| 520 | + String originConf = spark.conf().get(SQLConf.DATETIME_JAVA8API_ENABLED().key()); |
| 521 | + try { |
| 522 | + spark.conf().set(SQLConf.DATETIME_JAVA8API_ENABLED().key(), "true"); |
| 523 | + List<Row> inputRows = new ArrayList<>(); |
| 524 | + List<LocalDateInstantRecord> expectedRecords = new ArrayList<>(); |
| 525 | + |
| 526 | + for (long idx = 0 ; idx < 5 ; idx++) { |
| 527 | + Row row = createLocalDateInstantRow(idx); |
| 528 | + inputRows.add(row); |
| 529 | + expectedRecords.add(createLocalDateInstantRecord(row)); |
| 530 | + } |
| 531 | + |
| 532 | + Encoder<LocalDateInstantRecord> encoder = Encoders.bean(LocalDateInstantRecord.class); |
| 533 | + |
| 534 | + StructType schema = new StructType() |
| 535 | + .add("localDateField", DataTypes.DateType) |
| 536 | + .add("instantField", DataTypes.TimestampType); |
| 537 | + |
| 538 | + Dataset<Row> dataFrame = spark.createDataFrame(inputRows, schema); |
| 539 | + Dataset<LocalDateInstantRecord> dataset = dataFrame.as(encoder); |
| 540 | + |
| 541 | + List<LocalDateInstantRecord> records = dataset.collectAsList(); |
| 542 | + |
| 543 | + Assert.assertEquals(expectedRecords, records); |
| 544 | + } finally { |
| 545 | + spark.conf().set(SQLConf.DATETIME_JAVA8API_ENABLED().key(), originConf); |
| 546 | + } |
| 547 | + } |
| 548 | + |
| 549 | + public static final class LocalDateInstantRecord { |
| 550 | + private String localDateField; |
| 551 | + private String instantField; |
| 552 | + |
| 553 | + public LocalDateInstantRecord() { } |
| 554 | + |
| 555 | + public String getLocalDateField() { |
| 556 | + return localDateField; |
| 557 | + } |
| 558 | + |
| 559 | + public void setLocalDateField(String localDateField) { |
| 560 | + this.localDateField = localDateField; |
| 561 | + } |
| 562 | + |
| 563 | + public String getInstantField() { |
| 564 | + return instantField; |
| 565 | + } |
| 566 | + |
| 567 | + public void setInstantField(String instantField) { |
| 568 | + this.instantField = instantField; |
| 569 | + } |
| 570 | + |
| 571 | + @Override |
| 572 | + public boolean equals(Object o) { |
| 573 | + if (this == o) return true; |
| 574 | + if (o == null || getClass() != o.getClass()) return false; |
| 575 | + LocalDateInstantRecord that = (LocalDateInstantRecord) o; |
| 576 | + return Objects.equals(localDateField, that.localDateField) && |
| 577 | + Objects.equals(instantField, that.instantField); |
| 578 | + } |
| 579 | + |
| 580 | + @Override |
| 581 | + public int hashCode() { |
| 582 | + return Objects.hash(localDateField, instantField); |
| 583 | + } |
| 584 | + |
| 585 | + @Override |
| 586 | + public String toString() { |
| 587 | + return com.google.common.base.Objects.toStringHelper(this) |
| 588 | + .add("localDateField", localDateField) |
| 589 | + .add("instantField", instantField) |
| 590 | + .toString(); |
| 591 | + } |
| 592 | + } |
| 593 | + |
| 594 | + private static Row createLocalDateInstantRow(Long index) { |
| 595 | + Object[] values = new Object[] { LocalDate.ofEpochDay(42), Instant.ofEpochSecond(42) }; |
| 596 | + return new GenericRow(values); |
| 597 | + } |
| 598 | + |
| 599 | + private static LocalDateInstantRecord createLocalDateInstantRecord(Row recordRow) { |
| 600 | + LocalDateInstantRecord record = new LocalDateInstantRecord(); |
| 601 | + record.setLocalDateField(String.valueOf(recordRow.getLocalDate(0))); |
| 602 | + Instant instant = recordRow.getInstant(1); |
| 603 | + TimestampFormatter formatter = TimestampFormatter.getFractionFormatter( |
| 604 | + DateTimeUtils.getZoneId(SQLConf.get().sessionLocalTimeZone())); |
| 605 | + record.setInstantField(formatter.format(DateTimeUtils.instantToMicros(instant))); |
| 606 | + return record; |
| 607 | + } |
512 | 608 | }
|
0 commit comments