@@ -37,10 +37,11 @@ var _ driver.RowsColumnTypeDatabaseTypeName = (*rows)(nil)
3737var _ driver.RowsColumnTypeNullable = (* rows )(nil )
3838var _ driver.RowsColumnTypeLength = (* rows )(nil )
3939
40- var errRowsFetchPriorToStart = "unable to fetch row page prior to start of results"
41- var errRowsNoSchemaAvailable = "no schema in result set metadata response"
42- var errRowsNoClient = "instance of Rows missing client"
43- var errRowsNilRows = "nil Rows instance"
40+ var errRowsFetchPriorToStart = "databricks: unable to fetch row page prior to start of results"
41+ var errRowsNoSchemaAvailable = "databricks: no schema in result set metadata response"
42+ var errRowsNoClient = "databricks: instance of Rows missing client"
43+ var errRowsNilRows = "databricks: nil Rows instance"
44+ var errRowsParseValue = "databricks: unable to parse %s value '%s' from column %s"
4445
4546func NewRows (connID string , corrId string , client cli_service.TCLIService , opHandle * cli_service.TOperationHandle , pageSize int64 , location * time.Location , directResults * cli_service.TSparkDirectResults ) driver.Rows {
4647 r := & rows {
@@ -240,7 +241,7 @@ var (
240241 scanTypeString = reflect .TypeOf ("" )
241242 scanTypeDateTime = reflect .TypeOf (time.Time {})
242243 scanTypeRawBytes = reflect .TypeOf (sql.RawBytes {})
243- scanTypeUnknown = reflect .TypeOf (new (interface {} ))
244+ scanTypeUnknown = reflect .TypeOf (new (any ))
244245)
245246
246247func getScanType (column * cli_service.TColumnDesc ) reflect.Type {
@@ -454,13 +455,12 @@ func (r *rows) getPageStartRowNum() int64 {
454455 return r .fetchResults .GetResults ().GetStartRowOffset ()
455456}
456457
457- const (
458- // TimestampFormat is JDBC compliant timestamp format
459- TimestampFormat = "2006-01-02 15:04:05.999999999"
460- DateFormat = "2006-01-02"
461- )
458+ var dateTimeFormats map [string ]string = map [string ]string {
459+ "TIMESTAMP" : "2006-01-02 15:04:05.999999999" ,
460+ "DATE" : "2006-01-02" ,
461+ }
462462
463- func value (tColumn * cli_service.TColumn , tColumnDesc * cli_service.TColumnDesc , rowNum int64 , location * time.Location ) (val interface {} , err error ) {
463+ func value (tColumn * cli_service.TColumn , tColumnDesc * cli_service.TColumnDesc , rowNum int64 , location * time.Location ) (val any , err error ) {
464464 if location == nil {
465465 location = time .UTC
466466 }
@@ -469,17 +469,7 @@ func value(tColumn *cli_service.TColumn, tColumnDesc *cli_service.TColumnDesc, r
469469 dbtype := strings .TrimSuffix (entry .Type .String (), "_TYPE" )
470470 if tVal := tColumn .GetStringVal (); tVal != nil && ! isNull (tVal .Nulls , rowNum ) {
471471 val = tVal .Values [rowNum ]
472- if dbtype == "TIMESTAMP" {
473- t , err := time .ParseInLocation (TimestampFormat , val .(string ), location )
474- if err == nil {
475- val = t
476- }
477- } else if dbtype == "DATE" {
478- t , err := time .ParseInLocation (DateFormat , val .(string ), location )
479- if err == nil {
480- val = t
481- }
482- }
472+ val , err = handleDateTime (val , dbtype , tColumnDesc .ColumnName , location )
483473 } else if tVal := tColumn .GetByteVal (); tVal != nil && ! isNull (tVal .Nulls , rowNum ) {
484474 val = tVal .Values [rowNum ]
485475 } else if tVal := tColumn .GetI16Val (); tVal != nil && ! isNull (tVal .Nulls , rowNum ) {
@@ -499,6 +489,21 @@ func value(tColumn *cli_service.TColumn, tColumnDesc *cli_service.TColumnDesc, r
499489 return val , err
500490}
501491
492+ // handleDateTime will convert the passed val to a time.Time value if necessary
493+ func handleDateTime (val any , dbType , columnName string , location * time.Location ) (any , error ) {
494+ // if there is a date/time format corresponding to the column type we need to
495+ // convert to time.Time
496+ if format , ok := dateTimeFormats [dbType ]; ok {
497+ t , err := parseInLocation (format , val .(string ), location )
498+ if err != nil {
499+ err = wrapErrf (err , errRowsParseValue , dbType , val , columnName )
500+ }
501+ return t , err
502+ }
503+
504+ return val , nil
505+ }
506+
502507func isNull (nulls []byte , position int64 ) bool {
503508 index := position / 8
504509 if int64 (len (nulls )) > index {
@@ -540,3 +545,57 @@ func getNRows(rs *cli_service.TRowSet) int64 {
540545 }
541546 return 0
542547}
548+
549+ // parseInLocation parses a date/time string in the given format and using the provided
550+ // location.
551+ // This is, essentially, a wrapper around time.ParseInLocation to handle negative year
552+ // values
553+ func parseInLocation (format , dateTimeString string , loc * time.Location ) (time.Time , error ) {
554+ // we want to handle dates with negative year values and currently we only
555+ // support formats that start with the year so we can just strip a leading minus
556+ // sign
557+ var isNegative bool
558+ dateTimeString , isNegative = stripLeadingNegative (dateTimeString )
559+
560+ date , err := time .ParseInLocation (format , dateTimeString , loc )
561+ if err != nil {
562+ return time.Time {}, err
563+ }
564+
565+ if isNegative {
566+ date = date .AddDate (- 2 * date .Year (), 0 , 0 )
567+ }
568+
569+ return date , nil
570+ }
571+
572+ // stripLeadingNegative will remove a leading ascii or unicode minus
573+ // if present. The possibly shortened string is returned and a flag indicating if
574+ // the string was altered
575+ func stripLeadingNegative (dateTimeString string ) (string , bool ) {
576+ if dateStartsWithNegative (dateTimeString ) {
577+ // strip leading rune from dateTimeString
578+ // using range because it is supposed to be faster than utf8.DecodeRuneInString
579+ for i := range dateTimeString {
580+ if i > 0 {
581+ return dateTimeString [i :], true
582+ }
583+ }
584+ }
585+
586+ return dateTimeString , false
587+ }
588+
589+ // ISO 8601 allows for both the ascii and unicode characters for minus
590+ const (
591+ // unicode minus sign
592+ uMinus string = "\u2212 "
593+ // ascii hyphen/minus
594+ aMinus string = "\x2D "
595+ )
596+
597+ // dateStartsWithNegative returns true if the string starts with
598+ // a minus sign
599+ func dateStartsWithNegative (val string ) bool {
600+ return strings .HasPrefix (val , aMinus ) || strings .HasPrefix (val , uMinus )
601+ }
0 commit comments