go/bulk_ingest.go (270 changes: 77 additions & 193 deletions)
@@ -15,102 +15,64 @@
package databricks

import (
"bytes"
"context"
"database/sql/driver"
"fmt"
"strings"

"github.com/adbc-drivers/driverbase-go/driverbase"
"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow-go/v18/arrow"
"github.com/apache/arrow-go/v18/arrow/array"
)

// executeIngest performs bulk insert using parameterized INSERT statements
func (s *statementImpl) executeIngest(ctx context.Context) (int64, error) {
if s.boundStream == nil {
return -1, s.ErrorHelper.Errorf(adbc.StatusInvalidState, "no data bound for ingestion")
}

defer func() {
s.boundStream.Release()
s.boundStream = nil
}()

opts := &s.bulkIngestOptions

tableName := buildTableName(opts.CatalogName, opts.SchemaName, opts.TableName)

if err := s.createTableIfNeeded(ctx, tableName, s.boundStream.Schema(), opts); err != nil {
return -1, err
}

insertSQL, err := buildInsertSQL(tableName, s.boundStream.Schema())
if err != nil {
return -1, err
}

totalRows := int64(0)
params := make([]driver.NamedValue, s.boundStream.Schema().NumFields())

for s.boundStream.Next() {
recordBatch := s.boundStream.RecordBatch()

for rowIdx := range int(recordBatch.NumRows()) {
// Extract Go values from Arrow columns
for colIdx := range int(recordBatch.NumCols()) {
arr := recordBatch.Column(colIdx)
val, err := extractGoValue(arr, rowIdx)
if err != nil {
return totalRows, s.ErrorHelper.Errorf(adbc.StatusInternal, "failed to extract go value: %v", err)
}
params[colIdx].Value = val
}

// Use ExecContext directly instead of PrepareContext because Databricks doesn't do server-side statement preparation
result, err := s.conn.conn.ExecContext(ctx, insertSQL, valuesToInterfaces(params)...)
if err != nil {
return totalRows, s.ErrorHelper.Errorf(adbc.StatusInternal, "failed to execute the query: %v", err)
}
// databricksBulkIngest implements driverbase.BulkIngestImpl for the
// Databricks Staging + COPY INTO pattern. It uploads Parquet files to a
// Unity Catalog Volume via the Databricks Files API, then uses COPY INTO
// to load them into the target table.
type databricksBulkIngest struct {
conn *connectionImpl
stagingClient *stagingClient
errorHelper *driverbase.ErrorHelper
options *driverbase.BulkIngestOptions
}

rows, _ := result.RowsAffected()
totalRows += rows
}
}
// pendingCopy tracks a file uploaded to staging that is ready for COPY INTO.
// It implements driverbase.BulkIngestPendingCopy.
type pendingCopy struct {
path string
rows int64
}

if err := s.boundStream.Err(); err != nil {
return totalRows, s.ErrorHelper.Errorf(adbc.StatusInternal, "stream error: %v", err)
}
func (p *pendingCopy) String() string { return p.path }
func (p *pendingCopy) Rows() int64 { return p.rows }

return totalRows, nil
// CreateSink returns an in-memory buffer for Parquet data to be written to.
func (bi *databricksBulkIngest) CreateSink(ctx context.Context, options *driverbase.BulkIngestOptions) (driverbase.BulkIngestSink, error) {
return &driverbase.BufferBulkIngestSink{}, nil
}
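For orientation, here is a rough sketch of how these hooks compose for a single chunk. The actual call order, retries, and Parquet encoding are driven by the driverbase bulk-ingest machinery, not by this file; the helper name and the BulkIngestPendingUpload struct-literal fields below only mirror the usage visible in Upload and are otherwise assumptions.

// ingestOneChunkSketch is illustrative only and not part of this change.
func ingestOneChunkSketch(ctx context.Context, bi *databricksBulkIngest, schema *arrow.Schema, rows int64) error {
	// Ensure the target table exists (create it if missing, keep it if present).
	if err := bi.CreateTable(ctx, schema, driverbase.BulkIngestTableExistsIgnore, driverbase.BulkIngestTableMissingCreate); err != nil {
		return err
	}
	// Obtain an in-memory sink; driverbase writes the chunk into it as Parquet.
	sink, err := bi.CreateSink(ctx, bi.options)
	if err != nil {
		return err
	}
	// ... Parquet encoding of the Arrow record batches happens here ...
	// Upload the buffered file to the staging volume, load it with COPY INTO,
	// then delete the staged file.
	pending, err := bi.Upload(ctx, driverbase.BulkIngestPendingUpload{Data: sink, Rows: rows})
	if err != nil {
		return err
	}
	if err := bi.Copy(ctx, pending); err != nil {
		return err
	}
	return bi.Delete(ctx, pending)
}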

// createTableIfNeeded creates/drops table based on ingest mode
func (s *statementImpl) createTableIfNeeded(ctx context.Context, tableName string, schema *arrow.Schema, opts *driverbase.BulkIngestOptions) error {
switch opts.Mode {
case adbc.OptionValueIngestModeCreate:
return s.createTable(ctx, tableName, schema, false)
// CreateTable creates or drops/recreates the target table based on the
// specified table existence and missing behaviors.
func (bi *databricksBulkIngest) CreateTable(ctx context.Context, schema *arrow.Schema, ifTableExists driverbase.BulkIngestTableExistsBehavior, ifTableMissing driverbase.BulkIngestTableMissingBehavior) error {
tableName := buildTableName(bi.options.CatalogName, bi.options.SchemaName, bi.options.TableName)

case adbc.OptionValueIngestModeCreateAppend:
return s.createTable(ctx, tableName, schema, true)

case adbc.OptionValueIngestModeReplace:
if ifTableExists == driverbase.BulkIngestTableExistsDrop {
dropSQL := fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName)
if _, err := s.conn.conn.ExecContext(ctx, dropSQL); err != nil {
return s.ErrorHelper.Errorf(adbc.StatusInternal, "failed to drop the table: %v", err)
if _, err := bi.conn.conn.ExecContext(ctx, dropSQL); err != nil {
return bi.errorHelper.Errorf(adbc.StatusInternal, "failed to drop table: %v", err)
}
return s.createTable(ctx, tableName, schema, false)

case adbc.OptionValueIngestModeAppend:
return nil
}

default:
return s.ErrorHelper.Errorf(adbc.StatusInvalidArgument, "invalid ingest mode: %s", opts.Mode)
if ifTableMissing == driverbase.BulkIngestTableMissingCreate {
ifNotExists := ifTableExists == driverbase.BulkIngestTableExistsIgnore
return bi.createTableDDL(ctx, tableName, schema, ifNotExists)
}

return nil
}

// createTable generates and executes CREATE TABLE DDL
func (s *statementImpl) createTable(ctx context.Context, tableName string, schema *arrow.Schema, ifNotExists bool) error {
// createTableDDL generates and executes CREATE TABLE DDL from an Arrow schema.
func (bi *databricksBulkIngest) createTableDDL(ctx context.Context, tableName string, schema *arrow.Schema, ifNotExists bool) error {
var sql strings.Builder
sql.WriteString("CREATE TABLE ")
if ifNotExists {
@@ -132,48 +94,59 @@ func (s *statementImpl) createTable(ctx context.Context, tableName string, schema *arrow.Schema, ifNotExists bool) error {
}
sql.WriteString(")")

_, err := s.conn.conn.ExecContext(ctx, sql.String())
_, err := bi.conn.conn.ExecContext(ctx, sql.String())
if err != nil {
return s.ErrorHelper.Errorf(adbc.StatusInternal, "failed to create table: %v", err)
return bi.errorHelper.Errorf(adbc.StatusInternal, "failed to create table: %v", err)
}
return nil
}
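As a hedged illustration: for a schema with an int64 column `id` and a string column `name`, targeting a table main.default.events (hypothetical names), the generated DDL would look roughly like the line below. The exact type keywords come from arrowTypeToDatabricksType and the identifier quoting from the portion of this function collapsed in the diff, so treat the details as indicative only.

CREATE TABLE IF NOT EXISTS main.default.events (`id` BIGINT, `name` STRING)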

// buildInsertSQL generates parameterized INSERT statement
func buildInsertSQL(tableName string, schema *arrow.Schema) (string, error) {
var sql strings.Builder
// Upload uploads the Parquet data from a buffer to the staging volume.
func (bi *databricksBulkIngest) Upload(ctx context.Context, chunk driverbase.BulkIngestPendingUpload) (driverbase.BulkIngestPendingCopy, error) {
buf, ok := chunk.Data.(*driverbase.BufferBulkIngestSink)
if !ok {
return nil, bi.errorHelper.Errorf(adbc.StatusInternal, "unexpected sink type: %T", chunk.Data)
}

sql.WriteString("INSERT INTO ")
sql.WriteString(tableName)
sql.WriteString(" (")
path, err := bi.stagingClient.generateFileName()
if err != nil {
return nil, bi.errorHelper.Errorf(adbc.StatusInternal, "failed to generate staging file name: %v", err)
}

for i, field := range schema.Fields() {
if i > 0 {
sql.WriteString(", ")
}
sql.WriteString(quoteIdentifier(field.Name))
if err := bi.stagingClient.Upload(ctx, path, bytes.NewReader(buf.Bytes())); err != nil {
return nil, bi.errorHelper.Errorf(adbc.StatusIO, "failed to upload staging file: %v", err)
}

sql.WriteString(") VALUES (")
return &pendingCopy{path: path, rows: chunk.Rows}, nil
}

for i, field := range schema.Fields() {
if i > 0 {
sql.WriteString(", ")
}
// Copy executes COPY INTO to load a staged Parquet file into the target table.
func (bi *databricksBulkIngest) Copy(ctx context.Context, chunk driverbase.BulkIngestPendingCopy) error {
tableName := buildTableName(bi.options.CatalogName, bi.options.SchemaName, bi.options.TableName)

if field.Type.ID() == arrow.FIXED_SIZE_BINARY {
// Use UNHEX() to convert hex string to binary
sql.WriteString("UNHEX(?)")
} else {
sql.WriteString("?")
}
// The path from pendingCopy has the form "Volumes/catalog/schema/volume/prefix/file.parquet".
// COPY INTO expects the path with a leading slash: '/Volumes/...'
copySQL := fmt.Sprintf(
"COPY INTO %s FROM '/%s' FILEFORMAT = PARQUET",
tableName, chunk.String(),
)

_, err := bi.conn.conn.ExecContext(ctx, copySQL)
if err != nil {
return bi.errorHelper.Errorf(adbc.StatusInternal, "COPY INTO failed: %v", err)
}
return nil
}
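For example, with a target table main.default.events and a staged file at Volumes/main/default/staging/adbc-ingest/part-0.parquet (hypothetical names), the Sprintf above produces the statement below, modulo any quoting applied inside buildTableName, whose body is partly collapsed in this diff:

COPY INTO main.default.events FROM '/Volumes/main/default/staging/adbc-ingest/part-0.parquet' FILEFORMAT = PARQUET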

sql.WriteString(")")
return sql.String(), nil
// Delete removes the staging file after it has been copied into the target table.
func (bi *databricksBulkIngest) Delete(ctx context.Context, chunk driverbase.BulkIngestPendingCopy) error {
if err := bi.stagingClient.Delete(ctx, chunk.String()); err != nil {
return bi.errorHelper.Errorf(adbc.StatusIO, "failed to delete staging file: %v", err)
}
return nil
}

// buildTableName constructs catalog.schema.table name
// buildTableName constructs a fully qualified catalog.schema.table name.
func buildTableName(catalog, schema, table string) string {
parts := []string{}
if catalog != "" {
@@ -186,102 +159,13 @@ func buildTableName(catalog, schema, table string) string {
return strings.Join(parts, ".")
}

// quoteIdentifier quotes a Databricks identifier with backticks
// quoteIdentifier quotes a Databricks identifier with backticks.
func quoteIdentifier(id string) string {
escaped := strings.ReplaceAll(id, "`", "``")
return fmt.Sprintf("`%s`", escaped)
}
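A minimal in-package test sketch of the escaping behavior (a hypothetical *_test.go file, not part of this change):

package databricks

import "testing"

func TestQuoteIdentifier(t *testing.T) {
	// A backtick inside the identifier is doubled, and the whole name is wrapped in backticks.
	if got := quoteIdentifier("order`id"); got != "`order``id`" {
		t.Fatalf("unexpected quoting: %q", got)
	}
	// Identifiers without backticks are simply wrapped.
	if got := quoteIdentifier("name"); got != "`name`" {
		t.Fatalf("unexpected quoting: %q", got)
	}
}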

// valuesToInterfaces converts driver.NamedValue slice to []any for ExecContext
func valuesToInterfaces(params []driver.NamedValue) []any {
result := make([]any, len(params))
for i, p := range params {
result[i] = p.Value
}
return result
}

// extractGoValue extracts a Go value from an Arrow array at the given index
func extractGoValue(arr arrow.Array, idx int) (any, error) {
if arr.IsNull(idx) {
return nil, nil
}

switch arr.DataType().ID() {
case arrow.BOOL:
return arr.(*array.Boolean).Value(idx), nil

case arrow.INT8:
return int64(arr.(*array.Int8).Value(idx)), nil
case arrow.INT16:
return int64(arr.(*array.Int16).Value(idx)), nil
case arrow.INT32:
return int64(arr.(*array.Int32).Value(idx)), nil
case arrow.INT64:
// https://github.com/databricks/databricks-sql-go/issues/315
// databricks-sql-go incorrectly maps int64 to SqlInteger (INT) instead of SqlBigInt (BIGINT)
// Pass as string to preserve full range
return fmt.Sprintf("%d", arr.(*array.Int64).Value(idx)), nil

case arrow.UINT8:
return int64(arr.(*array.Uint8).Value(idx)), nil
case arrow.UINT16:
return int64(arr.(*array.Uint16).Value(idx)), nil
case arrow.UINT32:
// https://github.com/databricks/databricks-sql-go/issues/315
// databricks-sql-go incorrectly maps int64 to SqlInteger (INT)
// Pass as string to preserve full uint32 range
return fmt.Sprintf("%d", arr.(*array.Uint32).Value(idx)), nil
case arrow.UINT64:
// Pass as string to preserve full uint64 range (may still overflow if > int64 max)
return fmt.Sprintf("%d", arr.(*array.Uint64).Value(idx)), nil

case arrow.FLOAT32:
return float64(arr.(*array.Float32).Value(idx)), nil
case arrow.FLOAT64:
// https://github.com/databricks/databricks-sql-go/issues/314
// databricks-sql-go has a bug where float64 is treated as SqlFloat instead of SqlDouble
// causing precision loss. Pass as string to preserve full precision.
val := arr.(*array.Float64).Value(idx)
return fmt.Sprintf("%.17g", val), nil

case arrow.STRING:
return arr.(*array.String).Value(idx), nil
case arrow.LARGE_STRING:
return arr.(*array.LargeString).Value(idx), nil
case arrow.STRING_VIEW:
return arr.(*array.StringView).Value(idx), nil

case arrow.BINARY:
return arr.(*array.Binary).Value(idx), nil
case arrow.LARGE_BINARY:
return arr.(*array.LargeBinary).Value(idx), nil
case arrow.BINARY_VIEW:
return arr.(*array.BinaryView).Value(idx), nil
case arrow.FIXED_SIZE_BINARY:
// Convert to hex string for use with UNHEX() SQL function
return fmt.Sprintf("%x", arr.(*array.FixedSizeBinary).Value(idx)), nil

case arrow.DATE32:
return arr.(*array.Date32).Value(idx).ToTime(), nil
case arrow.DATE64:
return arr.(*array.Date64).Value(idx).ToTime(), nil

case arrow.TIMESTAMP:
ts := arr.DataType().(*arrow.TimestampType)
return arr.(*array.Timestamp).Value(idx).ToTime(ts.Unit), nil

case arrow.DECIMAL128:
dec := arr.(*array.Decimal128)
// Return as string, databricks-sql-go will infer DECIMAL type
return dec.ValueStr(idx), nil

default:
return nil, fmt.Errorf("unsupported Arrow type: %s", arr.DataType())
}
}

// arrowTypeToDatabricksType maps Arrow types to Databricks DDL types for CREATE TABLE
// arrowTypeToDatabricksType maps Arrow types to Databricks DDL types for CREATE TABLE.
func arrowTypeToDatabricksType(dt arrow.DataType) string {
switch dt.ID() {
case arrow.BOOL: