diff --git a/go/connection.go b/go/connection.go index 6a3c43ef..be5ec0ce 100644 --- a/go/connection.go +++ b/go/connection.go @@ -45,6 +45,9 @@ type connectionImpl struct { // Database connection conn *sql.Conn + + // Arrow serialization options + useArrowNativeGeospatial bool } func (c *connectionImpl) Close() error { diff --git a/go/database.go b/go/database.go index 495b2453..81bd7f4e 100644 --- a/go/database.go +++ b/go/database.go @@ -80,6 +80,9 @@ type databaseImpl struct { oauthClientID string oauthClientSecret string oauthRefreshToken string + + // Arrow serialization options + useArrowNativeGeospatial bool } func (d *databaseImpl) resolveConnectionOptions() ([]dbsql.ConnOption, error) { @@ -148,6 +151,13 @@ func (d *databaseImpl) resolveConnectionOptions() ([]dbsql.ConnOption, error) { opts = append(opts, dbsql.WithMaxDownloadThreads(d.downloadThreadCount)) } + // Arrow-native geospatial serialization (SPARK-54232). + // When enabled, geometry/geography columns arrive as Struct + // instead of EWKT strings, enabling native geometry passthrough. + if d.useArrowNativeGeospatial { + opts = append(opts, dbsql.WithArrowNativeGeospatial(true)) + } + // TLS/SSL handling // Configure a custom transport with proper timeout settings when custom // TLS config is needed. These settings match the defaults from @@ -251,10 +261,11 @@ func (d *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) { } conn := &connectionImpl{ - ConnectionImplBase: driverbase.NewConnectionImplBase(&d.DatabaseImplBase), - catalog: d.catalog, - dbSchema: d.schema, - conn: c, + ConnectionImplBase: driverbase.NewConnectionImplBase(&d.DatabaseImplBase), + catalog: d.catalog, + dbSchema: d.schema, + conn: c, + useArrowNativeGeospatial: d.useArrowNativeGeospatial, } return driverbase.NewConnectionBuilder(conn). 
@@ -320,6 +331,11 @@ func (d *databaseImpl) GetOption(key string) (string, error) { return d.oauthClientSecret, nil case OptionOAuthRefreshToken: return d.oauthRefreshToken, nil + case OptionArrowNativeGeospatial: + if d.useArrowNativeGeospatial { + return adbc.OptionValueEnabled, nil + } + return adbc.OptionValueDisabled, nil default: return d.DatabaseImplBase.GetOption(key) } @@ -486,6 +502,18 @@ func (d *databaseImpl) SetOption(key, value string) error { d.oauthClientSecret = value case OptionOAuthRefreshToken: d.oauthRefreshToken = value + case OptionArrowNativeGeospatial: + switch value { + case adbc.OptionValueEnabled: + d.useArrowNativeGeospatial = true + case adbc.OptionValueDisabled, "": + d.useArrowNativeGeospatial = false + default: + return adbc.Error{ + Code: adbc.StatusInvalidArgument, + Msg: fmt.Sprintf("invalid value for %s: %s (expected 'true' or 'false')", OptionArrowNativeGeospatial, value), + } + } default: return d.DatabaseImplBase.SetOption(key, value) } diff --git a/go/driver.go b/go/driver.go index 3037355a..a20e26e6 100644 --- a/go/driver.go +++ b/go/driver.go @@ -67,6 +67,9 @@ const ( OptionOAuthClientSecret = "databricks.oauth.client_secret" OptionOAuthRefreshToken = "databricks.oauth.refresh_token" + // Arrow serialization options + OptionArrowNativeGeospatial = "databricks.arrow.native_geospatial" + // Default values DefaultPort = 443 DefaultSSLMode = "require" diff --git a/go/ipc_reader_adapter.go b/go/ipc_reader_adapter.go index 68381f29..98fbd5c1 100644 --- a/go/ipc_reader_adapter.go +++ b/go/ipc_reader_adapter.go @@ -45,13 +45,146 @@ type ipcReaderAdapter struct { currentReader *ipc.Reader currentRecord arrow.RecordBatch schema *arrow.Schema + rawSchema *arrow.Schema // original schema before geoarrow transform closed bool refCount int64 err error + + // geoarrow conversion: indices of geometry struct columns to flatten + geoColumnIndices []int + geoSchemaBuilt bool // whether geoarrow schema has been built from first batch +} + +// 
isGeometryStruct checks if a field is a Databricks geometry struct: +// Struct +func isGeometryStruct(field arrow.Field) bool { + st, ok := field.Type.(*arrow.StructType) + if !ok || st.NumFields() != 2 { + return false + } + f0 := st.Field(0) + f1 := st.Field(1) + return f0.Name == "srid" && f0.Type.ID() == arrow.INT32 && + f1.Name == "wkb" && f1.Type.ID() == arrow.BINARY +} + +// detectGeometryColumns finds geometry Struct columns in the schema. +func detectGeometryColumns(schema *arrow.Schema) []int { + var indices []int + for i, f := range schema.Fields() { + if isGeometryStruct(f) { + indices = append(indices, i) + } + } + return indices +} + +// buildGeoArrowSchemaWithoutCRS creates a geoarrow.wkb schema without CRS +// metadata. Used eagerly so the schema is available before the first Next(). +func buildGeoArrowSchemaWithoutCRS(schema *arrow.Schema, geoIndices []int) *arrow.Schema { + fields := schema.Fields() + newFields := make([]arrow.Field, len(fields)) + copy(newFields, fields) + + for _, idx := range geoIndices { + f := fields[idx] + newFields[idx] = arrow.Field{ + Name: f.Name, + Type: arrow.BinaryTypes.Binary, + Nullable: f.Nullable, + Metadata: arrow.MetadataFrom(map[string]string{ + "ARROW:extension:name": "geoarrow.wkb", + "ARROW:extension:metadata": "", + }), + } + } + + meta := schema.Metadata() + return arrow.NewSchema(newFields, &meta) +} + +// buildGeoArrowSchema creates a new schema with geometry Struct fields replaced +// by Binary fields with geoarrow.wkb extension metadata. The SRID from the +// first record batch is used to populate the CRS in the extension metadata. 
+func buildGeoArrowSchema(schema *arrow.Schema, geoIndices []int, rec arrow.RecordBatch) *arrow.Schema { + fields := schema.Fields() + newFields := make([]arrow.Field, len(fields)) + copy(newFields, fields) + + for _, idx := range geoIndices { + f := fields[idx] + + // Read SRID from first non-null row of this geometry column + srid := 0 + structArr := rec.Column(idx).(*array.Struct) + sridArr := structArr.Field(0) + for row := 0; row < sridArr.Len(); row++ { + if !sridArr.IsNull(row) { + srid = int(sridArr.(*array.Int32).Value(row)) + break + } + } + + // Build geoarrow.wkb extension metadata with CRS from SRID + extMeta := "" + if srid != 0 { + extMeta = fmt.Sprintf(`{"crs":"EPSG:%d"}`, srid, srid) + } + + newFields[idx] = arrow.Field{ + Name: f.Name, + Type: arrow.BinaryTypes.Binary, + Nullable: f.Nullable, + Metadata: arrow.MetadataFrom(map[string]string{ + "ARROW:extension:name": "geoarrow.wkb", + "ARROW:extension:metadata": extMeta, + }), + } + } + + meta := schema.Metadata() + return arrow.NewSchema(newFields, &meta) +} + +// transformRecordForGeoArrow extracts the wkb child from geometry struct +// columns and builds a new record with flat Binary columns. 
+func transformRecordForGeoArrow(rec arrow.RecordBatch, schema *arrow.Schema, geoIndices []int) arrow.RecordBatch { + if len(geoIndices) == 0 { + return rec + } + + geoSet := make(map[int]bool, len(geoIndices)) + for _, idx := range geoIndices { + geoSet[idx] = true + } + + cols := make([]arrow.Array, rec.NumCols()) + for i := 0; i < int(rec.NumCols()); i++ { + if geoSet[i] { + // Extract the "wkb" field (index 1) from the struct array + structArr := rec.Column(i).(*array.Struct) + wkbArr := structArr.Field(1) + wkbArr.Retain() + cols[i] = wkbArr + } else { + col := rec.Column(i) + col.Retain() + cols[i] = col + } + } + + newRec := array.NewRecord(schema, cols, rec.NumRows()) + + // Release our references to the columns + for _, col := range cols { + col.Release() + } + + return newRec } // newIPCReaderAdapter creates a RecordReader using direct IPC stream access -func newIPCReaderAdapter(ctx context.Context, rows driver.Rows) (array.RecordReader, error) { +func newIPCReaderAdapter(ctx context.Context, rows driver.Rows, useArrowNativeGeospatial bool) (array.RecordReader, error) { ipcRows, ok := rows.(dbsqlrows.Rows) if !ok { return nil, adbc.Error{ @@ -127,6 +260,20 @@ func newIPCReaderAdapter(ctx context.Context, rows driver.Rows) (array.RecordRea } } + // When Arrow-native geospatial is enabled, detect geometry Struct columns + // and build a geoarrow.wkb schema. The schema must be available before + // the first Next() call since consumers (e.g. adbc_scanner) read it + // upfront to create table columns. We build the schema eagerly with + // empty CRS metadata, then enrich it with the SRID from the first + // record batch when available. 
// handleGeoRecord enriches the geoarrow schema with CRS from the first batch,
// then transforms the record to flatten geometry struct columns.
//
// When no geometry columns were detected, the record passes through untouched.
// Otherwise, on the first batch only, r.schema (initially built with empty
// CRS metadata by buildGeoArrowSchemaWithoutCRS) is replaced with one whose
// geoarrow.wkb extension metadata carries the SRID observed in the data.
//
// NOTE(review): the pass-through path returns the reader-owned record, while
// the transform path returns a freshly allocated record (refcount 1) from
// transformRecordForGeoArrow; a caller that uniformly Retain()s the result
// (as Next() appears to) may over-retain the transformed record — verify the
// matching Release in the adapter's record lifecycle. TODO confirm.
func (r *ipcReaderAdapter) handleGeoRecord(rec arrow.RecordBatch) arrow.RecordBatch {
	if len(r.geoColumnIndices) == 0 {
		return rec
	}

	// On the first record batch, rebuild the schema with SRID-based CRS
	// from the actual data. This replaces the initial empty-CRS schema.
	if !r.geoSchemaBuilt {
		r.schema = buildGeoArrowSchema(r.rawSchema, r.geoColumnIndices, rec)
		r.geoSchemaBuilt = true
	}

	return transformRecordForGeoArrow(rec, r.schema, r.geoColumnIndices)
}
IPC stream interface (zero-copy) - reader, err := newIPCReaderAdapter(ctx, driverRows) + reader, err := newIPCReaderAdapter(ctx, driverRows, s.conn.useArrowNativeGeospatial) if err != nil { return nil, -1, s.ErrorHelper.Errorf(adbc.StatusInternal, "failed to create IPC reader adapter: %v", err) }