diff --git a/connector.go b/connector.go index 21a5f178..53908b4c 100644 --- a/connector.go +++ b/connector.go @@ -41,10 +41,21 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { if err != nil { return nil, dbsqlerrint.NewDriverError(ctx, dbsqlerr.ErrThriftClient, err) } + + // Prepare session configuration + sessionParams := make(map[string]string) + for k, v := range c.cfg.SessionParams { + sessionParams[k] = v + } + + if c.cfg.EnableMetricViewMetadata { + sessionParams["spark.sql.thriftserver.metadata.metricview.enabled"] = "true" + } + protocolVersion := int64(c.cfg.ThriftProtocolVersion) session, err := tclient.OpenSession(ctx, &cli_service.TOpenSessionReq{ ClientProtocolI64: &protocolVersion, - Configuration: c.cfg.SessionParams, + Configuration: sessionParams, InitialNamespace: &cli_service.TNamespace{ CatalogName: catalogName, SchemaName: schemaName, @@ -265,6 +276,14 @@ func WithMaxDownloadThreads(numThreads int) ConnOption { } } +// WithEnableMetricViewMetadata enables metric view metadata support. Default is false. +// When enabled, adds spark.sql.thriftserver.metadata.metricview.enabled=true to session configuration. +func WithEnableMetricViewMetadata(enable bool) ConnOption { + return func(c *config.Config) { + c.EnableMetricViewMetadata = enable + } +} + // Setup of Oauth M2m authentication func WithClientCredentials(clientID, clientSecret string) ConnOption { return func(c *config.Config) { diff --git a/connector_test.go b/connector_test.go index 2e0e126b..57554b98 100644 --- a/connector_test.go +++ b/connector_test.go @@ -213,6 +213,39 @@ func TestNewConnector(t *testing.T) { require.True(t, ok) require.True(t, internalClient.TLSClientConfig.InsecureSkipVerify) }) + + t.Run("Connector test WithEnableMetricViewMetadata enabled", func(t *testing.T) { + host := "databricks-host" + accessToken := "token" + httpPath := "http-path" + con, err := NewConnector( + WithServerHostname(host), + WithAccessToken(accessToken), + WithHTTPPath(httpPath), + WithEnableMetricViewMetadata(true), + ) + assert.Nil(t, err) + + coni, ok := con.(*connector) + require.True(t, ok) + assert.True(t, coni.cfg.EnableMetricViewMetadata) + }) + + t.Run("Connector test WithEnableMetricViewMetadata disabled by default", func(t *testing.T) { + host := "databricks-host" + accessToken := "token" + httpPath := "http-path" + con, err := NewConnector( + WithServerHostname(host), + WithAccessToken(accessToken), + WithHTTPPath(httpPath), + ) + assert.Nil(t, err) + + coni, ok := con.(*connector) + require.True(t, ok) + assert.False(t, coni.cfg.EnableMetricViewMetadata) + }) } type mockRoundTripper struct{} diff --git a/internal/config/config.go b/internal/config/config.go index 09946c91..67437a9c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -82,24 +82,25 @@ func (c *Config) DeepCopy() *Config { // UserConfig is the set of configurations exposed to users type UserConfig struct { - Protocol string - Host string // from databricks UI - Port int // from databricks UI - HTTPPath string // from databricks UI - Catalog string - Schema string - Authenticator auth.Authenticator - AccessToken string // from databricks UI - MaxRows int // max rows per page - QueryTimeout time.Duration // Timeout passed to server for query processing - UserAgentEntry string - Location *time.Location - SessionParams map[string]string - RetryWaitMin time.Duration - RetryWaitMax time.Duration - RetryMax int - Transport http.RoundTripper - UseLz4Compression bool + Protocol string + Host string // from databricks UI + Port int // from databricks UI + HTTPPath string // from databricks UI + Catalog string + Schema string + Authenticator auth.Authenticator + AccessToken string // from databricks UI + MaxRows int // max rows per page + QueryTimeout time.Duration // Timeout passed to server for query processing + UserAgentEntry string + Location *time.Location + SessionParams map[string]string + RetryWaitMin time.Duration + RetryWaitMax time.Duration + RetryMax int + Transport http.RoundTripper + UseLz4Compression bool + EnableMetricViewMetadata bool CloudFetchConfig } @@ -123,25 +124,26 @@ func (ucfg UserConfig) DeepCopy() UserConfig { } return UserConfig{ - Protocol: ucfg.Protocol, - Host: ucfg.Host, - Port: ucfg.Port, - HTTPPath: ucfg.HTTPPath, - Catalog: ucfg.Catalog, - Schema: ucfg.Schema, - Authenticator: ucfg.Authenticator, - AccessToken: ucfg.AccessToken, - MaxRows: ucfg.MaxRows, - QueryTimeout: ucfg.QueryTimeout, - UserAgentEntry: ucfg.UserAgentEntry, - Location: loccp, - SessionParams: sessionParams, - RetryWaitMin: ucfg.RetryWaitMin, - RetryWaitMax: ucfg.RetryWaitMax, - RetryMax: ucfg.RetryMax, - Transport: ucfg.Transport, - UseLz4Compression: ucfg.UseLz4Compression, - CloudFetchConfig: ucfg.CloudFetchConfig, + Protocol: ucfg.Protocol, + Host: ucfg.Host, + Port: ucfg.Port, + HTTPPath: ucfg.HTTPPath, + Catalog: ucfg.Catalog, + Schema: ucfg.Schema, + Authenticator: ucfg.Authenticator, + AccessToken: ucfg.AccessToken, + MaxRows: ucfg.MaxRows, + QueryTimeout: ucfg.QueryTimeout, + UserAgentEntry: ucfg.UserAgentEntry, + Location: loccp, + SessionParams: sessionParams, + RetryWaitMin: ucfg.RetryWaitMin, + RetryWaitMax: ucfg.RetryWaitMax, + RetryMax: ucfg.RetryMax, + Transport: ucfg.Transport, + UseLz4Compression: ucfg.UseLz4Compression, + EnableMetricViewMetadata: ucfg.EnableMetricViewMetadata, + CloudFetchConfig: ucfg.CloudFetchConfig, } } @@ -272,6 +274,14 @@ func ParseDSN(dsn string) (UserConfig, error) { ucfg.MaxDownloadThreads = numThreads } + // Metric View Metadata parameter + if enableMetricViewMetadata, ok, err := params.extractAsBool("enableMetricViewMetadata"); ok { + if err != nil { + return UserConfig{}, err + } + ucfg.EnableMetricViewMetadata = enableMetricViewMetadata + } + // for timezone we do a case insensitive key match. // We use getNoCase because we want to leave timezone in the params so that it will also // be used as a session param.