Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,21 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
if err != nil {
return nil, dbsqlerrint.NewDriverError(ctx, dbsqlerr.ErrThriftClient, err)
}

// Prepare session configuration
sessionParams := make(map[string]string)
for k, v := range c.cfg.SessionParams {
sessionParams[k] = v
}

if c.cfg.EnableMetricViewMetadata {
sessionParams["spark.sql.thriftserver.metadata.metricview.enabled"] = "true"
}

protocolVersion := int64(c.cfg.ThriftProtocolVersion)
session, err := tclient.OpenSession(ctx, &cli_service.TOpenSessionReq{
ClientProtocolI64: &protocolVersion,
Configuration: c.cfg.SessionParams,
Configuration: sessionParams,
InitialNamespace: &cli_service.TNamespace{
CatalogName: catalogName,
SchemaName: schemaName,
Expand Down Expand Up @@ -265,6 +276,14 @@ func WithMaxDownloadThreads(numThreads int) ConnOption {
}
}

// WithEnableMetricViewMetadata enables metric view metadata support. Default is false.
// When enabled, adds spark.sql.thriftserver.metadata.metricview.enabled=true to session configuration.
func WithEnableMetricViewMetadata(enable bool) ConnOption {
return func(c *config.Config) {
c.EnableMetricViewMetadata = enable
}
}

// Setup of Oauth M2m authentication
func WithClientCredentials(clientID, clientSecret string) ConnOption {
return func(c *config.Config) {
Expand Down
33 changes: 33 additions & 0 deletions connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,39 @@ func TestNewConnector(t *testing.T) {
require.True(t, ok)
require.True(t, internalClient.TLSClientConfig.InsecureSkipVerify)
})

t.Run("Connector test WithEnableMetricViewMetadata enabled", func(t *testing.T) {
host := "databricks-host"
accessToken := "token"
httpPath := "http-path"
con, err := NewConnector(
WithServerHostname(host),
WithAccessToken(accessToken),
WithHTTPPath(httpPath),
WithEnableMetricViewMetadata(true),
)
assert.Nil(t, err)

coni, ok := con.(*connector)
require.True(t, ok)
assert.True(t, coni.cfg.EnableMetricViewMetadata)
})

t.Run("Connector test WithEnableMetricViewMetadata disabled by default", func(t *testing.T) {
host := "databricks-host"
accessToken := "token"
httpPath := "http-path"
con, err := NewConnector(
WithServerHostname(host),
WithAccessToken(accessToken),
WithHTTPPath(httpPath),
)
assert.Nil(t, err)

coni, ok := con.(*connector)
require.True(t, ok)
assert.False(t, coni.cfg.EnableMetricViewMetadata)
})
}

type mockRoundTripper struct{}
Expand Down
84 changes: 47 additions & 37 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,24 +82,25 @@ func (c *Config) DeepCopy() *Config {

// UserConfig is the set of configurations exposed to users
type UserConfig struct {
Protocol string
Host string // from databricks UI
Port int // from databricks UI
HTTPPath string // from databricks UI
Catalog string
Schema string
Authenticator auth.Authenticator
AccessToken string // from databricks UI
MaxRows int // max rows per page
QueryTimeout time.Duration // Timeout passed to server for query processing
UserAgentEntry string
Location *time.Location
SessionParams map[string]string
RetryWaitMin time.Duration
RetryWaitMax time.Duration
RetryMax int
Transport http.RoundTripper
UseLz4Compression bool
Protocol string
Host string // from databricks UI
Port int // from databricks UI
HTTPPath string // from databricks UI
Catalog string
Schema string
Authenticator auth.Authenticator
AccessToken string // from databricks UI
MaxRows int // max rows per page
QueryTimeout time.Duration // Timeout passed to server for query processing
UserAgentEntry string
Location *time.Location
SessionParams map[string]string
RetryWaitMin time.Duration
RetryWaitMax time.Duration
RetryMax int
Transport http.RoundTripper
UseLz4Compression bool
EnableMetricViewMetadata bool
CloudFetchConfig
}

Expand All @@ -123,25 +124,26 @@ func (ucfg UserConfig) DeepCopy() UserConfig {
}

return UserConfig{
Protocol: ucfg.Protocol,
Host: ucfg.Host,
Port: ucfg.Port,
HTTPPath: ucfg.HTTPPath,
Catalog: ucfg.Catalog,
Schema: ucfg.Schema,
Authenticator: ucfg.Authenticator,
AccessToken: ucfg.AccessToken,
MaxRows: ucfg.MaxRows,
QueryTimeout: ucfg.QueryTimeout,
UserAgentEntry: ucfg.UserAgentEntry,
Location: loccp,
SessionParams: sessionParams,
RetryWaitMin: ucfg.RetryWaitMin,
RetryWaitMax: ucfg.RetryWaitMax,
RetryMax: ucfg.RetryMax,
Transport: ucfg.Transport,
UseLz4Compression: ucfg.UseLz4Compression,
CloudFetchConfig: ucfg.CloudFetchConfig,
Protocol: ucfg.Protocol,
Host: ucfg.Host,
Port: ucfg.Port,
HTTPPath: ucfg.HTTPPath,
Catalog: ucfg.Catalog,
Schema: ucfg.Schema,
Authenticator: ucfg.Authenticator,
AccessToken: ucfg.AccessToken,
MaxRows: ucfg.MaxRows,
QueryTimeout: ucfg.QueryTimeout,
UserAgentEntry: ucfg.UserAgentEntry,
Location: loccp,
SessionParams: sessionParams,
RetryWaitMin: ucfg.RetryWaitMin,
RetryWaitMax: ucfg.RetryWaitMax,
RetryMax: ucfg.RetryMax,
Transport: ucfg.Transport,
UseLz4Compression: ucfg.UseLz4Compression,
EnableMetricViewMetadata: ucfg.EnableMetricViewMetadata,
CloudFetchConfig: ucfg.CloudFetchConfig,
}
}

Expand Down Expand Up @@ -272,6 +274,14 @@ func ParseDSN(dsn string) (UserConfig, error) {
ucfg.MaxDownloadThreads = numThreads
}

// Metric View Metadata parameter
if enableMetricViewMetadata, ok, err := params.extractAsBool("enableMetricViewMetadata"); ok {
if err != nil {
return UserConfig{}, err
}
ucfg.EnableMetricViewMetadata = enableMetricViewMetadata
}

// for timezone we do a case insensitive key match.
// We use getNoCase because we want to leave timezone in the params so that it will also
// be used as a session param.
Expand Down
Loading