@@ -103,41 +103,13 @@ func NewArrowRowScanner(resultSetMetadata *cli_service.TGetResultSetMetadataResp
 		arrowConfig = cfg.ArrowConfig
 	}
 
-	var arrowSchema *arrow.Schema
-	schemaBytes := resultSetMetadata.ArrowSchema
-	if schemaBytes == nil {
-		var err error
-		// convert the TTableSchema to an arrow Schema
-		arrowSchema, err = tTableSchemaToArrowSchema(resultSetMetadata.Schema, &arrowConfig)
-		if err != nil {
-			logger.Err(err).Msg(errArrowRowsConvertSchema)
-			return nil, dbsqlerrint.NewDriverError(ctx, errArrowRowsConvertSchema, err)
-		}
-
-		// serialize the arrow schema
-		schemaBytes, err = getArrowSchemaBytes(arrowSchema, ctx)
-		if err != nil {
-			logger.Err(err).Msg(errArrowRowsSerializeSchema)
-			return nil, dbsqlerrint.NewDriverError(ctx, errArrowRowsSerializeSchema, err)
-		}
-	} else {
-		br := &chunkedByteReader{chunks: [][]byte{schemaBytes}}
-		rdr, err := ipc.NewReader(br)
-		if err != nil {
-			return nil, dbsqlerrint.NewDriverError(ctx, errArrowRowsUnableToReadBatch, err)
-		}
-		defer rdr.Release()
-
-		arrowSchema = rdr.Schema()
+	schemaBytes, arrowSchema, metadataErr := tGetResultSetMetadataRespToArrowSchema(resultSetMetadata, arrowConfig, ctx, logger)
+	if metadataErr != nil {
+		return nil, metadataErr
 	}
 
-	// get the database type names for each column
-	colInfos := make([]colInfo, len(resultSetMetadata.Schema.Columns))
-	for i := range resultSetMetadata.Schema.Columns {
-		col := resultSetMetadata.Schema.Columns[i]
-		field := arrowSchema.Field(i)
-		colInfos[i] = colInfo{name: field.Name, arrowType: field.Type, dbType: rowscanner.GetDBType(col)}
-	}
+	// Create column info
+	colInfos := getColumnInfo(arrowSchema, resultSetMetadata.Schema)
 
 	// get the function for converting arrow timestamps to a time.Time
 	// time values from the server are returned as UTC with microsecond precision
@@ -553,6 +525,61 @@ func tColumnDescToArrowField(columnDesc *cli_service.TColumnDesc, arrowConfig *c
 	return arrowField, nil
 }
 
+// Build a slice of columnInfo using the arrow schema and the thrift schema
+func getColumnInfo(arrowSchema *arrow.Schema, schema *cli_service.TTableSchema) []colInfo {
+	if arrowSchema == nil || schema == nil {
+		return []colInfo{}
+	}
+
+	nFields := len(arrowSchema.Fields())
+	if len(schema.Columns) < nFields {
+		nFields = len(schema.Columns)
+	}
+
+	colInfos := make([]colInfo, nFields)
+	for i := 0; i < nFields; i++ {
+		col := schema.Columns[i]
+		field := arrowSchema.Field(i)
+		colInfos[i] = colInfo{name: field.Name, arrowType: field.Type, dbType: rowscanner.GetDBType(col)}
+	}
+
+	return colInfos
+}
+
+// Derive an arrow.Schema object and the corresponding serialized bytes from TGetResultSetMetadataResp
+func tGetResultSetMetadataRespToArrowSchema(resultSetMetadata *cli_service.TGetResultSetMetadataResp, arrowConfig config.ArrowConfig, ctx context.Context, logger *dbsqllog.DBSQLLogger) ([]byte, *arrow.Schema, dbsqlerr.DBError) {
+
+	var arrowSchema *arrow.Schema
+	schemaBytes := resultSetMetadata.ArrowSchema
+	if schemaBytes == nil {
+		var err error
+		// convert the TTableSchema to an arrow Schema
+		arrowSchema, err = tTableSchemaToArrowSchema(resultSetMetadata.Schema, &arrowConfig)
+		if err != nil {
+			logger.Err(err).Msg(errArrowRowsConvertSchema)
+			return nil, nil, dbsqlerrint.NewDriverError(ctx, errArrowRowsConvertSchema, err)
+		}
+
+		// serialize the arrow schema
+		schemaBytes, err = getArrowSchemaBytes(arrowSchema, ctx)
+		if err != nil {
+			logger.Err(err).Msg(errArrowRowsSerializeSchema)
+			return nil, nil, dbsqlerrint.NewDriverError(ctx, errArrowRowsSerializeSchema, err)
+		}
+	} else {
+		br := &chunkedByteReader{chunks: [][]byte{schemaBytes}}
+		rdr, err := ipc.NewReader(br)
+		if err != nil {
+			return nil, nil, dbsqlerrint.NewDriverError(ctx, errArrowRowsUnableToReadBatch, err)
+		}
+		defer rdr.Release()
+
+		arrowSchema = rdr.Schema()
+	}
+
+	return schemaBytes, arrowSchema, nil
+}
+
 type sparkRecordReader struct{}
 
 // Make sure sparkRecordReader fulfills the recordReader interface
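
For context (not part of the commit): the extracted tGetResultSetMetadataRespToArrowSchema helper either serializes a freshly converted arrow.Schema or reads one back from the bytes the server already supplied. Below is a minimal standalone sketch of that Arrow IPC schema round trip; the module version (v12), column names, and types are illustrative assumptions, not values taken from the driver.

package main

import (
	"bytes"
	"fmt"

	"github.com/apache/arrow/go/v12/arrow"
	"github.com/apache/arrow/go/v12/arrow/ipc"
)

func main() {
	// A small schema standing in for what tTableSchemaToArrowSchema would produce.
	schema := arrow.NewSchema([]arrow.Field{
		{Name: "id", Type: arrow.PrimitiveTypes.Int64},
		{Name: "name", Type: arrow.BinaryTypes.String},
	}, nil)

	// Serialize: closing an IPC stream writer without writing any record
	// batches emits just the schema message (plus the end-of-stream marker).
	var buf bytes.Buffer
	w := ipc.NewWriter(&buf, ipc.WithSchema(schema))
	if err := w.Close(); err != nil {
		panic(err)
	}

	// Deserialize: ipc.NewReader parses the schema message up front, so the
	// schema can be recovered without reading any record batches.
	rdr, err := ipc.NewReader(&buf)
	if err != nil {
		panic(err)
	}
	defer rdr.Release()

	fmt.Println(rdr.Schema().String())
}

This mirrors the two branches of the helper: the nil-ArrowSchema branch builds and serializes a schema, while the other branch feeds the server-provided bytes to ipc.NewReader and takes rdr.Schema().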