Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 50 additions & 9 deletions go/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ type databaseImpl struct {
needsRefresh bool // Whether we need to re-initialize

// Connection parameters
uri string
serverHostname string
httpPath string
accessToken string
Expand Down Expand Up @@ -160,20 +161,29 @@ func (d *databaseImpl) resolveConnectionOptions() ([]dbsql.ConnOption, error) {
}

func (d *databaseImpl) initializeConnectionPool(ctx context.Context) (*sql.DB, error) {
opts, err := d.resolveConnectionOptions()
var db *sql.DB

if err != nil {
return nil, err
}
// Use URI if provided
if d.uri != "" {
var err error
db, err = sql.Open("databricks", d.uri)
if err != nil {
return nil, err
}
} else {
opts, err := d.resolveConnectionOptions()
if err != nil {
return nil, err
}

connector, err := dbsql.NewConnector(opts...)
connector, err := dbsql.NewConnector(opts...)
if err != nil {
return nil, err
}

if err != nil {
return nil, err
db = sql.OpenDB(connector)
}

db := sql.OpenDB(connector)

// Test the connection
if err := db.PingContext(ctx); err != nil {
err = errors.Join(db.Close())
Expand Down Expand Up @@ -238,6 +248,8 @@ func (d *databaseImpl) Close() error {

func (d *databaseImpl) GetOption(key string) (string, error) {
switch key {
case adbc.OptionKeyURI:
return d.uri, nil
case OptionServerHostname:
return d.serverHostname, nil
case OptionHTTPPath:
Expand Down Expand Up @@ -288,6 +300,25 @@ func (d *databaseImpl) GetOption(key string) (string, error) {
func (d *databaseImpl) SetOptions(options map[string]string) error {
// We need to re-initialize the db/connection pool if options change
d.needsRefresh = true

hasURI := false
hasOtherOptions := false

if _, ok := options[adbc.OptionKeyURI]; ok {
hasURI = true
}

if len(options) > 1 || (len(options) == 1 && !hasURI) {
hasOtherOptions = true
}

if hasURI && hasOtherOptions {
return adbc.Error{
Code: adbc.StatusInvalidArgument,
Msg: "cannot specify both URI and individual connection options",
}
}

for k, v := range options {
err := d.SetOption(k, v)
if err != nil {
Expand All @@ -301,6 +332,16 @@ func (d *databaseImpl) SetOption(key, value string) error {
// We need to re-initialize the db/connection pool if options change
d.needsRefresh = true
switch key {
case adbc.OptionKeyURI:
// Strip the databricks:// scheme since databricks-sql-go expects raw DSN format
if after, ok := strings.CutPrefix(value, "databricks://"); ok {
d.uri = after
Comment on lines +337 to +338
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually...do we need to also URL-decode the string? Or does the Databricks driver do that?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do not need to URL-decode the string, the Databricks driver handles it.
https://github.com/databricks/databricks-sql-go/blob/main/internal/config/config.go#L210

} else {
return adbc.Error{
Code: adbc.StatusInvalidArgument,
Msg: fmt.Sprintf("invalid URI scheme: expected 'databricks://', got '%s'", value),
}
}
case OptionServerHostname:
d.serverHostname = value
case OptionHTTPPath:
Expand Down
63 changes: 59 additions & 4 deletions go/docs/databricks.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ dbc install databricks

## Connecting

TODO: This section once https://github.com/apache/arrow-adbc/pull/3771 is merged here.

To connect, edit the `uri` option below to match your environment and run the following:

```python
Expand All @@ -43,7 +41,7 @@ from adbc_driver_manager import dbapi
conn = dbapi.connect(
driver="databricks",
db_kwargs = {
"uri": "TODO"
"uri": "databricks://token:dapi1234abcd5678efgh@dbc-a1b2345c-d6e7.cloud.databricks.com:443/sql/protocolv1/o/1234567890123456/1234-567890-abcdefgh"
}
)
```
Expand All @@ -52,7 +50,64 @@ Note: The example above is for Python using the [adbc-driver-manager](https://py

### Connection String Format

TODO: This section once https://github.com/apache/arrow-adbc/pull/3771 is merged here.
The Databricks URI syntax supports three primary forms:

1. Databricks personal access token authentication:

```
databricks://token:<personal-access-token>@<server-hostname>:<port-number>/<http-path>?<param1=value1>&<param2=value2>
```

Components:
- `scheme`: `databricks://` (required)
- `<personal-access-token>`: (required) Databricks personal access token.
- `<server-hostname>`: (required) Server Hostname value.
- `<port-number>`: (required) Port value, which is typically 443.
- `<http-path>`: (required) HTTP Path value.
- Query params: Databricks connection attributes. For the complete list of optional parameters, see [Databricks Optional Parameters](https://docs.databricks.com/aws/en/dev-tools/go-sql-driver#optional-parameters).


2. OAuth user-to-machine (U2M) authentication:

```
databricks://<server-hostname>:<port-number>/<http-path>?authType=OauthU2M&<param1=value1>&<param2=value2>
```

Components:
- `scheme`: `databricks://` (required)
- `<server-hostname>`: (required) Server Hostname value.
- `<port-number>`: (required) Port value, which is typically 443.
- `<http-path>`: (required) HTTP Path value.
- `authType=OauthU2M`: (required) Specifies OAuth user-to-machine authentication.
- Query params: Additional Databricks connection attributes. For the complete list of optional parameters, see [Databricks Optional Parameters](https://docs.databricks.com/aws/en/dev-tools/go-sql-driver#optional-parameters).

3. OAuth machine-to-machine (M2M) authentication:

```
databricks://<server-hostname>:<port-number>/<http-path>?authType=OAuthM2M&clientID=<client-id>&clientSecret=<client-secret>&<param1=value1>&<param2=value2>
```

Components:
- `scheme`: `databricks://` (required)
- `<server-hostname>`: (required) Server Hostname value.
- `<port-number>`: (required) Port value, which is typically 443.
- `<http-path>`: (required) HTTP Path value.
- `authType=OAuthM2M`: (required) Specifies OAuth machine-to-machine authentication.
- `<client-id>`: (required) Service principal's UUID or Application ID value.
- `<client-secret>`: (required) Secret value for the service principal's OAuth secret.
- Query params: Additional Databricks connection attributes. For the complete list of optional parameters, see [Databricks Optional Parameters](https://docs.databricks.com/aws/en/dev-tools/go-sql-driver#optional-parameters).

This follows the [Databricks SQL Driver for Go](https://docs.databricks.com/aws/en/dev-tools/go-sql-driver#connect-with-a-dsn-connection-string) format with the addition of the `databricks://` scheme.

:::{note}
Reserved characters in URI elements must be URI-encoded. For example, `@` becomes `%40`.
:::

Examples:

- `databricks://token:dapi1234abcd5678efgh@dbc-a1b2345c-d6e7.cloud.databricks.com:443/sql/protocolv1/o/1234567890123456/1234-567890-abcdefgh`
- `databricks://myworkspace.cloud.databricks.com:443/sql/1.0/warehouses/abc123def456?authType=OauthU2M`
- `databricks://myworkspace.cloud.databricks.com:443/sql/1.0/warehouses/abc123def456?authType=OAuthM2M&clientID=12345678-1234-1234-1234-123456789012&clientSecret=mysecret123`

## Feature & Type Support

Expand Down
54 changes: 53 additions & 1 deletion go/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ import (
"strings"
"testing"

"github.com/adbc-drivers/databricks/go/databricks"
"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow-adbc/go/adbc/driver/databricks"
"github.com/apache/arrow-adbc/go/adbc/validation"
"github.com/apache/arrow-go/v18/arrow"
"github.com/apache/arrow-go/v18/arrow/memory"
Expand All @@ -47,6 +47,7 @@ type DatabricksQuirks struct {
httpPath string
token string
port string
uri string // The URI to use for the test if set
}

func (d *DatabricksQuirks) SetupDriver(t *testing.T) adbc.Driver {
Expand All @@ -59,6 +60,12 @@ func (d *DatabricksQuirks) TearDownDriver(t *testing.T, _ adbc.Driver) {
}

func (d *DatabricksQuirks) DatabaseOptions() map[string]string {
if d.uri != "" {
return map[string]string{
adbc.OptionKeyURI: d.uri,
}
}

opts := map[string]string{
databricks.OptionServerHostname: d.hostname,
databricks.OptionHTTPPath: d.httpPath,
Expand Down Expand Up @@ -327,6 +334,18 @@ func withQuirks(t *testing.T, fn func(*DatabricksQuirks)) {
fn(q)
}

// withQuirksURI invokes fn with a DatabricksQuirks whose only populated field
// is uri, taken from the DATABRICKS_URI environment variable. The calling test
// is skipped when that variable is unset or empty.
func withQuirksURI(t *testing.T, fn func(*DatabricksQuirks)) {
	envURI := os.Getenv("DATABRICKS_URI")
	if len(envURI) == 0 {
		t.Skip("DATABRICKS_URI not defined, skipping URI tests")
	}

	fn(&DatabricksQuirks{uri: envURI})
}

func TestValidation(t *testing.T) {
withQuirks(t, func(q *DatabricksQuirks) {
suite.Run(t, &validation.DatabaseTests{Quirks: q})
Expand All @@ -341,6 +360,39 @@ func TestDatabricks(t *testing.T) {
})
}

// TestDatabricksWithURI checks that a database configured through the URI
// option alone can open a connection and execute a trivial query. The test
// runs via withQuirksURI, so it is skipped when DATABRICKS_URI is not set.
func TestDatabricksWithURI(t *testing.T) {
	withQuirksURI(t, func(q *DatabricksQuirks) {
		drv := q.SetupDriver(t)
		defer q.TearDownDriver(t, drv)

		db, err := drv.NewDatabase(q.DatabaseOptions())
		require.NoError(t, err)
		defer validation.CheckedClose(t, db)

		ctx := context.Background()
		cnxn, err := db.Open(ctx)
		require.NoError(t, err)
		defer validation.CheckedClose(t, cnxn)

		stmt, err := cnxn.NewStatement()
		require.NoError(t, err)
		defer validation.CheckedClose(t, stmt)

		require.NoError(t, stmt.SetSqlQuery("SELECT 1 as test_col"))
		rdr, _, err := stmt.ExecuteQuery(ctx)
		require.NoError(t, err)
		defer rdr.Release()

		// Abort immediately when no batch is available: continuing would use
		// the result of RecordBatch without a current batch (presumably nil,
		// which would panic below) and hide the reader's actual error.
		require.True(t, rdr.Next(), "expected one record batch, reader error: %v", rdr.Err())
		rec := rdr.RecordBatch()
		assert.Equal(t, int64(1), rec.NumRows())
		assert.Equal(t, int64(1), rec.NumCols())
		assert.Equal(t, "test_col", rec.ColumnName(0))

		// The query yields a single row, so the stream must end cleanly here.
		assert.False(t, rdr.Next())
		require.NoError(t, rdr.Err())
	})
}

// ---- Additional Tests --------------------

type DatabricksTests struct {
Expand Down
Loading