From 5a113cb0a41a389dff18d56221221787a12dd0c0 Mon Sep 17 00:00:00 2001 From: Shivam Raj Date: Thu, 16 Oct 2025 00:53:55 +0530 Subject: [PATCH 1/2] Add metric view metadata support Signed-off-by: Shivam Raj --- connector.go | 28 ++++++++++++- connector_test.go | 33 +++++++++++++++ internal/config/config.go | 84 ++++++++++++++++++++++----------------- 3 files changed, 107 insertions(+), 38 deletions(-) diff --git a/connector.go b/connector.go index 21a5f178..6e685346 100644 --- a/connector.go +++ b/connector.go @@ -41,10 +41,28 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { if err != nil { return nil, dbsqlerrint.NewDriverError(ctx, dbsqlerr.ErrThriftClient, err) } + + // Prepare session configuration + sessionParams := c.cfg.SessionParams + if c.cfg.EnableMetricViewMetadata { + // Make a copy of session params and add metric view metadata config + if sessionParams == nil { + sessionParams = make(map[string]string) + } else { + // Create a copy to avoid modifying the original + paramsCopy := make(map[string]string, len(sessionParams)+1) + for k, v := range sessionParams { + paramsCopy[k] = v + } + sessionParams = paramsCopy + } + sessionParams["spark.sql.thriftserver.metadata.metricview.enabled"] = "true" + } + protocolVersion := int64(c.cfg.ThriftProtocolVersion) session, err := tclient.OpenSession(ctx, &cli_service.TOpenSessionReq{ ClientProtocolI64: &protocolVersion, - Configuration: c.cfg.SessionParams, + Configuration: sessionParams, InitialNamespace: &cli_service.TNamespace{ CatalogName: catalogName, SchemaName: schemaName, @@ -265,6 +283,14 @@ func WithMaxDownloadThreads(numThreads int) ConnOption { } } +// WithEnableMetricViewMetadata enables metric view metadata support. Default is false. +// When enabled, adds spark.sql.thriftserver.metadata.metricview.enabled=true to session configuration. +func WithEnableMetricViewMetadata(enable bool) ConnOption { + return func(c *config.Config) { + c.EnableMetricViewMetadata = enable + } +} + // Setup of Oauth M2m authentication func WithClientCredentials(clientID, clientSecret string) ConnOption { return func(c *config.Config) { diff --git a/connector_test.go b/connector_test.go index 2e0e126b..57554b98 100644 --- a/connector_test.go +++ b/connector_test.go @@ -213,6 +213,39 @@ func TestNewConnector(t *testing.T) { require.True(t, ok) require.True(t, internalClient.TLSClientConfig.InsecureSkipVerify) }) + + t.Run("Connector test WithEnableMetricViewMetadata enabled", func(t *testing.T) { + host := "databricks-host" + accessToken := "token" + httpPath := "http-path" + con, err := NewConnector( + WithServerHostname(host), + WithAccessToken(accessToken), + WithHTTPPath(httpPath), + WithEnableMetricViewMetadata(true), + ) + assert.Nil(t, err) + + coni, ok := con.(*connector) + require.True(t, ok) + assert.True(t, coni.cfg.EnableMetricViewMetadata) + }) + + t.Run("Connector test WithEnableMetricViewMetadata disabled by default", func(t *testing.T) { + host := "databricks-host" + accessToken := "token" + httpPath := "http-path" + con, err := NewConnector( + WithServerHostname(host), + WithAccessToken(accessToken), + WithHTTPPath(httpPath), + ) + assert.Nil(t, err) + + coni, ok := con.(*connector) + require.True(t, ok) + assert.False(t, coni.cfg.EnableMetricViewMetadata) + }) } type mockRoundTripper struct{} diff --git a/internal/config/config.go b/internal/config/config.go index 09946c91..67437a9c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -82,24 +82,25 @@ func (c *Config) DeepCopy() *Config { // UserConfig is the set of configurations exposed to users type UserConfig struct { - Protocol string - Host string // from databricks UI - Port int // from databricks UI - HTTPPath string // from databricks UI - Catalog string - Schema string - Authenticator auth.Authenticator - AccessToken string // from databricks UI - MaxRows int // max rows per page - QueryTimeout time.Duration // Timeout passed to server for query processing - UserAgentEntry string - Location *time.Location - SessionParams map[string]string - RetryWaitMin time.Duration - RetryWaitMax time.Duration - RetryMax int - Transport http.RoundTripper - UseLz4Compression bool + Protocol string + Host string // from databricks UI + Port int // from databricks UI + HTTPPath string // from databricks UI + Catalog string + Schema string + Authenticator auth.Authenticator + AccessToken string // from databricks UI + MaxRows int // max rows per page + QueryTimeout time.Duration // Timeout passed to server for query processing + UserAgentEntry string + Location *time.Location + SessionParams map[string]string + RetryWaitMin time.Duration + RetryWaitMax time.Duration + RetryMax int + Transport http.RoundTripper + UseLz4Compression bool + EnableMetricViewMetadata bool CloudFetchConfig } @@ -123,25 +124,26 @@ func (ucfg UserConfig) DeepCopy() UserConfig { } return UserConfig{ - Protocol: ucfg.Protocol, - Host: ucfg.Host, - Port: ucfg.Port, - HTTPPath: ucfg.HTTPPath, - Catalog: ucfg.Catalog, - Schema: ucfg.Schema, - Authenticator: ucfg.Authenticator, - AccessToken: ucfg.AccessToken, - MaxRows: ucfg.MaxRows, - QueryTimeout: ucfg.QueryTimeout, - UserAgentEntry: ucfg.UserAgentEntry, - Location: loccp, - SessionParams: sessionParams, - RetryWaitMin: ucfg.RetryWaitMin, - RetryWaitMax: ucfg.RetryWaitMax, - RetryMax: ucfg.RetryMax, - Transport: ucfg.Transport, - UseLz4Compression: ucfg.UseLz4Compression, - CloudFetchConfig: ucfg.CloudFetchConfig, + Protocol: ucfg.Protocol, + Host: ucfg.Host, + Port: ucfg.Port, + HTTPPath: ucfg.HTTPPath, + Catalog: ucfg.Catalog, + Schema: ucfg.Schema, + Authenticator: ucfg.Authenticator, + AccessToken: ucfg.AccessToken, + MaxRows: ucfg.MaxRows, + QueryTimeout: ucfg.QueryTimeout, + UserAgentEntry: ucfg.UserAgentEntry, + Location: loccp, + SessionParams: sessionParams, + RetryWaitMin: ucfg.RetryWaitMin, + RetryWaitMax: ucfg.RetryWaitMax, + RetryMax: ucfg.RetryMax, + Transport: ucfg.Transport, + UseLz4Compression: ucfg.UseLz4Compression, + EnableMetricViewMetadata: ucfg.EnableMetricViewMetadata, + CloudFetchConfig: ucfg.CloudFetchConfig, } } @@ -272,6 +274,14 @@ func ParseDSN(dsn string) (UserConfig, error) { ucfg.MaxDownloadThreads = numThreads } + // Metric View Metadata parameter + if enableMetricViewMetadata, ok, err := params.extractAsBool("enableMetricViewMetadata"); ok { + if err != nil { + return UserConfig{}, err + } + ucfg.EnableMetricViewMetadata = enableMetricViewMetadata + } + // for timezone we do a case insensitive key match. // We use getNoCase because we want to leave timezone in the params so that it will also // be used as a session param. From 3edd436012beba48158163a456ae54bceaad059d Mon Sep 17 00:00:00 2001 From: Shivam Raj Date: Fri, 24 Oct 2025 17:20:17 +0530 Subject: [PATCH 2/2] Simplify session params handling Address PR review comments by simplifying the session configuration logic. Always create a new map and copy existing params regardless of whether metric view metadata is enabled. This makes the code cleaner and more consistent: - No nil check needed (ranging over nil map is safe in Go) - Consistent behavior in all cases - More readable code Resolves reviewer comments from @vikrantpuppala, @copilot-pull-request-reviewer, and @gopalldb Signed-off-by: Shivam Raj --- connector.go | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/connector.go b/connector.go index 6e685346..53908b4c 100644 --- a/connector.go +++ b/connector.go @@ -43,19 +43,12 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { } // Prepare session configuration - sessionParams := c.cfg.SessionParams + sessionParams := make(map[string]string) + for k, v := range c.cfg.SessionParams { + sessionParams[k] = v + } + if c.cfg.EnableMetricViewMetadata { - // Make a copy of session params and add metric view metadata config - if sessionParams == nil { - sessionParams = make(map[string]string) - } else { - // Create a copy to avoid modifying the original - paramsCopy := make(map[string]string, len(sessionParams)+1) - for k, v := range sessionParams { - paramsCopy[k] = v - } - sessionParams = paramsCopy - } sessionParams["spark.sql.thriftserver.metadata.metricview.enabled"] = "true" }