Skip to content

Commit a00980d

Browse files
Handle parsing negative years in dates (#71)
time.Time can represent dates with a negative year, however the time.ParseInLocation function does not handle parsing date strings with a negative year. - Added our own parseInLocation function which handles parsing with negative years. - Added unit tests for date/time parsing. Signed-off-by: Raymond Cypher <[email protected]>
1 parent 0b7f22f commit a00980d

File tree

2 files changed

+144
-26
lines changed

2 files changed

+144
-26
lines changed

rows.go

Lines changed: 81 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,11 @@ var _ driver.RowsColumnTypeDatabaseTypeName = (*rows)(nil)
3737
var _ driver.RowsColumnTypeNullable = (*rows)(nil)
3838
var _ driver.RowsColumnTypeLength = (*rows)(nil)
3939

40-
var errRowsFetchPriorToStart = "unable to fetch row page prior to start of results"
41-
var errRowsNoSchemaAvailable = "no schema in result set metadata response"
42-
var errRowsNoClient = "instance of Rows missing client"
43-
var errRowsNilRows = "nil Rows instance"
40+
var errRowsFetchPriorToStart = "databricks: unable to fetch row page prior to start of results"
41+
var errRowsNoSchemaAvailable = "databricks: no schema in result set metadata response"
42+
var errRowsNoClient = "databricks: instance of Rows missing client"
43+
var errRowsNilRows = "databricks: nil Rows instance"
44+
var errRowsParseValue = "databricks: unable to parse %s value '%s' from column %s"
4445

4546
func NewRows(connID string, corrId string, client cli_service.TCLIService, opHandle *cli_service.TOperationHandle, pageSize int64, location *time.Location, directResults *cli_service.TSparkDirectResults) driver.Rows {
4647
r := &rows{
@@ -240,7 +241,7 @@ var (
240241
scanTypeString = reflect.TypeOf("")
241242
scanTypeDateTime = reflect.TypeOf(time.Time{})
242243
scanTypeRawBytes = reflect.TypeOf(sql.RawBytes{})
243-
scanTypeUnknown = reflect.TypeOf(new(interface{}))
244+
scanTypeUnknown = reflect.TypeOf(new(any))
244245
)
245246

246247
func getScanType(column *cli_service.TColumnDesc) reflect.Type {
@@ -454,13 +455,12 @@ func (r *rows) getPageStartRowNum() int64 {
454455
return r.fetchResults.GetResults().GetStartRowOffset()
455456
}
456457

457-
const (
458-
// TimestampFormat is JDBC compliant timestamp format
459-
TimestampFormat = "2006-01-02 15:04:05.999999999"
460-
DateFormat = "2006-01-02"
461-
)
458+
var dateTimeFormats map[string]string = map[string]string{
459+
"TIMESTAMP": "2006-01-02 15:04:05.999999999",
460+
"DATE": "2006-01-02",
461+
}
462462

463-
func value(tColumn *cli_service.TColumn, tColumnDesc *cli_service.TColumnDesc, rowNum int64, location *time.Location) (val interface{}, err error) {
463+
func value(tColumn *cli_service.TColumn, tColumnDesc *cli_service.TColumnDesc, rowNum int64, location *time.Location) (val any, err error) {
464464
if location == nil {
465465
location = time.UTC
466466
}
@@ -469,17 +469,7 @@ func value(tColumn *cli_service.TColumn, tColumnDesc *cli_service.TColumnDesc, r
469469
dbtype := strings.TrimSuffix(entry.Type.String(), "_TYPE")
470470
if tVal := tColumn.GetStringVal(); tVal != nil && !isNull(tVal.Nulls, rowNum) {
471471
val = tVal.Values[rowNum]
472-
if dbtype == "TIMESTAMP" {
473-
t, err := time.ParseInLocation(TimestampFormat, val.(string), location)
474-
if err == nil {
475-
val = t
476-
}
477-
} else if dbtype == "DATE" {
478-
t, err := time.ParseInLocation(DateFormat, val.(string), location)
479-
if err == nil {
480-
val = t
481-
}
482-
}
472+
val, err = handleDateTime(val, dbtype, tColumnDesc.ColumnName, location)
483473
} else if tVal := tColumn.GetByteVal(); tVal != nil && !isNull(tVal.Nulls, rowNum) {
484474
val = tVal.Values[rowNum]
485475
} else if tVal := tColumn.GetI16Val(); tVal != nil && !isNull(tVal.Nulls, rowNum) {
@@ -499,6 +489,21 @@ func value(tColumn *cli_service.TColumn, tColumnDesc *cli_service.TColumnDesc, r
499489
return val, err
500490
}
501491

492+
// handleDateTime will convert the passed val to a time.Time value if necessary
493+
func handleDateTime(val any, dbType, columnName string, location *time.Location) (any, error) {
494+
// if there is a date/time format corresponding to the column type we need to
495+
// convert to time.Time
496+
if format, ok := dateTimeFormats[dbType]; ok {
497+
t, err := parseInLocation(format, val.(string), location)
498+
if err != nil {
499+
err = wrapErrf(err, errRowsParseValue, dbType, val, columnName)
500+
}
501+
return t, err
502+
}
503+
504+
return val, nil
505+
}
506+
502507
func isNull(nulls []byte, position int64) bool {
503508
index := position / 8
504509
if int64(len(nulls)) > index {
@@ -540,3 +545,57 @@ func getNRows(rs *cli_service.TRowSet) int64 {
540545
}
541546
return 0
542547
}
548+
549+
// parseInLocation parses a date/time string in the given format and using the provided
550+
// location.
551+
// This is, essentially, a wrapper around time.ParseInLocation to handle negative year
552+
// values
553+
func parseInLocation(format, dateTimeString string, loc *time.Location) (time.Time, error) {
554+
// we want to handle dates with negative year values and currently we only
555+
// support formats that start with the year so we can just strip a leading minus
556+
// sign
557+
var isNegative bool
558+
dateTimeString, isNegative = stripLeadingNegative(dateTimeString)
559+
560+
date, err := time.ParseInLocation(format, dateTimeString, loc)
561+
if err != nil {
562+
return time.Time{}, err
563+
}
564+
565+
if isNegative {
566+
date = date.AddDate(-2*date.Year(), 0, 0)
567+
}
568+
569+
return date, nil
570+
}
571+
572+
// stripLeadingNegative will remove a leading ascii or unicode minus
573+
// if present. The possibly shortened string is returned and a flag indicating if
574+
// the string was altered
575+
func stripLeadingNegative(dateTimeString string) (string, bool) {
576+
if dateStartsWithNegative(dateTimeString) {
577+
// strip leading rune from dateTimeString
578+
// using range because it is supposed to be faster than utf8.DecodeRuneInString
579+
for i := range dateTimeString {
580+
if i > 0 {
581+
return dateTimeString[i:], true
582+
}
583+
}
584+
}
585+
586+
return dateTimeString, false
587+
}
588+
589+
// ISO 8601 allows for both the ascii and unicode characters for minus
590+
const (
591+
// unicode minus sign
592+
uMinus string = "\u2212"
593+
// ascii hyphen/minus
594+
aMinus string = "\x2D"
595+
)
596+
597+
// dateStartsWithNegative returns true if the string starts with
598+
// a minus sign
599+
func dateStartsWithNegative(val string) bool {
600+
return strings.HasPrefix(val, aMinus) || strings.HasPrefix(val, uMinus)
601+
}

rows_test.go

Lines changed: 63 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@ import (
44
"context"
55
"database/sql/driver"
66
"errors"
7+
"fmt"
78
"io"
89
"math"
910
"reflect"
11+
"strings"
1012
"testing"
1113
"time"
1214

@@ -537,8 +539,8 @@ func TestNextNoDirectResults(t *testing.T) {
537539
row := make([]driver.Value, len(colNames))
538540

539541
err = rowSet.Next(row)
540-
timestamp, _ := time.Parse(TimestampFormat, "2021-07-01 05:43:28")
541-
date, _ := time.Parse(DateFormat, "2021-07-01")
542+
timestamp, _ := time.Parse(dateTimeFormats["TIMESTAMP"], "2021-07-01 05:43:28")
543+
date, _ := time.Parse(dateTimeFormats["DATE"], "2021-07-01")
542544
row0 := []driver.Value{
543545
true,
544546
driver.Value(nil),
@@ -592,8 +594,8 @@ func TestNextWithDirectResults(t *testing.T) {
592594

593595
err := rowSet.Next(row)
594596

595-
timestamp, _ := time.Parse(TimestampFormat, "2021-07-01 05:43:28")
596-
date, _ := time.Parse(DateFormat, "2021-07-01")
597+
timestamp, _ := time.Parse(dateTimeFormats["TIMESTAMP"], "2021-07-01 05:43:28")
598+
date, _ := time.Parse(dateTimeFormats["DATE"], "2021-07-01")
597599
row0 := []driver.Value{
598600
true,
599601
driver.Value(nil),
@@ -621,6 +623,63 @@ func TestNextWithDirectResults(t *testing.T) {
621623
assert.Equal(t, 1, fetchResultsCount)
622624
}
623625

626+
func TestHandlingDateTime(t *testing.T) {
627+
t.Run("should do nothing if data is not a date/time", func(t *testing.T) {
628+
val, err := handleDateTime("this is not a date", "STRING", "string_col", time.UTC)
629+
assert.Nil(t, err, "handleDateTime should do nothing if a column is not a date/time")
630+
assert.Equal(t, "this is not a date", val)
631+
})
632+
633+
t.Run("should error on invalid date/time value", func(t *testing.T) {
634+
_, err := handleDateTime("this is not a date", "DATE", "date_col", time.UTC)
635+
assert.NotNil(t, err)
636+
assert.True(t, strings.HasPrefix(err.Error(), fmt.Sprintf(errRowsParseValue, "DATE", "this is not a date", "date_col")))
637+
})
638+
639+
t.Run("should parse valid date", func(t *testing.T) {
640+
dt, err := handleDateTime("2006-12-22", "DATE", "date_col", time.UTC)
641+
assert.Nil(t, err)
642+
assert.Equal(t, time.Date(2006, 12, 22, 0, 0, 0, 0, time.UTC), dt)
643+
})
644+
645+
t.Run("should parse valid timestamp", func(t *testing.T) {
646+
dt, err := handleDateTime("2006-12-22 17:13:11.000001000", "TIMESTAMP", "timestamp_col", time.UTC)
647+
assert.Nil(t, err)
648+
assert.Equal(t, time.Date(2006, 12, 22, 17, 13, 11, 1000, time.UTC), dt)
649+
})
650+
651+
t.Run("should parse date with negative year", func(t *testing.T) {
652+
expectedTime := time.Date(-2006, 12, 22, 0, 0, 0, 0, time.UTC)
653+
dateStrings := []string{
654+
"-2006-12-22",
655+
"\u22122006-12-22",
656+
"\x2D2006-12-22",
657+
}
658+
659+
for _, s := range dateStrings {
660+
dt, err := handleDateTime(s, "DATE", "date_col", time.UTC)
661+
assert.Nil(t, err)
662+
assert.Equal(t, expectedTime, dt)
663+
}
664+
})
665+
666+
t.Run("should parse timestamp with negative year", func(t *testing.T) {
667+
expectedTime := time.Date(-2006, 12, 22, 17, 13, 11, 1000, time.UTC)
668+
669+
timestampStrings := []string{
670+
"-2006-12-22 17:13:11.000001000",
671+
"\u22122006-12-22 17:13:11.000001000",
672+
"\x2D2006-12-22 17:13:11.000001000",
673+
}
674+
675+
for _, s := range timestampStrings {
676+
dt, err := handleDateTime(s, "TIMESTAMP", "timestamp_col", time.UTC)
677+
assert.Nil(t, err)
678+
assert.Equal(t, expectedTime, dt)
679+
}
680+
})
681+
}
682+
624683
func TestGetScanType(t *testing.T) {
625684
var getMetadataCount, fetchResultsCount int
626685

0 commit comments

Comments
 (0)