Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,5 @@ _tmp*
.vscode/
__debug_bin
.DS_Store

.claude/
85 changes: 44 additions & 41 deletions connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,23 +50,24 @@ func TestNewConnector(t *testing.T) {
CloudFetchSpeedThresholdMbps: 0.1,
}
expectedUserConfig := config.UserConfig{
Host: host,
Port: port,
Protocol: "https",
AccessToken: accessToken,
Authenticator: &pat.PATAuth{AccessToken: accessToken},
HTTPPath: "/" + httpPath,
MaxRows: maxRows,
QueryTimeout: timeout,
Catalog: catalog,
Schema: schema,
UserAgentEntry: userAgentEntry,
SessionParams: sessionParams,
RetryMax: 10,
RetryWaitMin: 3 * time.Second,
RetryWaitMax: 60 * time.Second,
Transport: roundTripper,
CloudFetchConfig: expectedCloudFetchConfig,
Host: host,
Port: port,
Protocol: "https",
AccessToken: accessToken,
Authenticator: &pat.PATAuth{AccessToken: accessToken},
HTTPPath: "/" + httpPath,
MaxRows: maxRows,
QueryTimeout: timeout,
Catalog: catalog,
Schema: schema,
UserAgentEntry: userAgentEntry,
SessionParams: sessionParams,
RetryMax: 10,
RetryWaitMin: 3 * time.Second,
RetryWaitMax: 60 * time.Second,
Transport: roundTripper,
ExecutionProtocol: "thrift",
CloudFetchConfig: expectedCloudFetchConfig,
}
expectedCfg := config.WithDefaults()
expectedCfg.DriverVersion = DriverVersion
Expand Down Expand Up @@ -97,18 +98,19 @@ func TestNewConnector(t *testing.T) {
CloudFetchSpeedThresholdMbps: 0.1,
}
expectedUserConfig := config.UserConfig{
Host: host,
Port: port,
Protocol: "https",
AccessToken: accessToken,
Authenticator: &pat.PATAuth{AccessToken: accessToken},
HTTPPath: "/" + httpPath,
MaxRows: maxRows,
SessionParams: sessionParams,
RetryMax: 4,
RetryWaitMin: 1 * time.Second,
RetryWaitMax: 30 * time.Second,
CloudFetchConfig: expectedCloudFetchConfig,
Host: host,
Port: port,
Protocol: "https",
AccessToken: accessToken,
Authenticator: &pat.PATAuth{AccessToken: accessToken},
HTTPPath: "/" + httpPath,
MaxRows: maxRows,
SessionParams: sessionParams,
RetryMax: 4,
RetryWaitMin: 1 * time.Second,
RetryWaitMax: 30 * time.Second,
ExecutionProtocol: "thrift",
CloudFetchConfig: expectedCloudFetchConfig,
}
expectedCfg := config.WithDefaults()
expectedCfg.UserConfig = expectedUserConfig
Expand Down Expand Up @@ -139,18 +141,19 @@ func TestNewConnector(t *testing.T) {
CloudFetchSpeedThresholdMbps: 0.1,
}
expectedUserConfig := config.UserConfig{
Host: host,
Port: port,
Protocol: "https",
AccessToken: accessToken,
Authenticator: &pat.PATAuth{AccessToken: accessToken},
HTTPPath: "/" + httpPath,
MaxRows: maxRows,
SessionParams: sessionParams,
RetryMax: -1,
RetryWaitMin: 0,
RetryWaitMax: 0,
CloudFetchConfig: expectedCloudFetchConfig,
Host: host,
Port: port,
Protocol: "https",
AccessToken: accessToken,
Authenticator: &pat.PATAuth{AccessToken: accessToken},
HTTPPath: "/" + httpPath,
MaxRows: maxRows,
SessionParams: sessionParams,
RetryMax: -1,
RetryWaitMin: 0,
RetryWaitMax: 0,
ExecutionProtocol: "thrift",
CloudFetchConfig: expectedCloudFetchConfig,
}
expectedCfg := config.WithDefaults()
expectedCfg.DriverVersion = DriverVersion
Expand Down
33 changes: 33 additions & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ type Config struct {
TLSConfig *tls.Config // nil disables TLS
ArrowConfig
PollInterval time.Duration
MaxPollInterval time.Duration // maximum polling interval for exponential backoff
PollBackoffMultiplier float64 // multiplier for exponential backoff
ClientTimeout time.Duration // max time the http request can last
PingTimeout time.Duration // max time allowed for ping
CanUseMultipleCatalogs bool
Expand Down Expand Up @@ -68,6 +70,8 @@ func (c *Config) DeepCopy() *Config {
TLSConfig: c.TLSConfig.Clone(),
ArrowConfig: c.ArrowConfig.DeepCopy(),
PollInterval: c.PollInterval,
MaxPollInterval: c.MaxPollInterval,
PollBackoffMultiplier: c.PollBackoffMultiplier,
ClientTimeout: c.ClientTimeout,
PingTimeout: c.PingTimeout,
CanUseMultipleCatalogs: c.CanUseMultipleCatalogs,
Expand Down Expand Up @@ -101,6 +105,8 @@ type UserConfig struct {
Transport http.RoundTripper
UseLz4Compression bool
EnableMetricViewMetadata bool
ExecutionProtocol string // "thrift" (default) or "rest"
WarehouseID string // required when ExecutionProtocol is "rest"
CloudFetchConfig
}

Expand Down Expand Up @@ -143,6 +149,8 @@ func (ucfg UserConfig) DeepCopy() UserConfig {
Transport: ucfg.Transport,
UseLz4Compression: ucfg.UseLz4Compression,
EnableMetricViewMetadata: ucfg.EnableMetricViewMetadata,
ExecutionProtocol: ucfg.ExecutionProtocol,
WarehouseID: ucfg.WarehouseID,
CloudFetchConfig: ucfg.CloudFetchConfig,
}
}
Expand Down Expand Up @@ -176,6 +184,9 @@ func (ucfg UserConfig) WithDefaults() UserConfig {
if ucfg.RetryWaitMax == 0 {
ucfg.RetryWaitMax = 30 * time.Second
}
if ucfg.ExecutionProtocol == "" {
ucfg.ExecutionProtocol = "thrift"
}
ucfg.UseLz4Compression = false
ucfg.CloudFetchConfig = CloudFetchConfig{}.WithDefaults()

Expand All @@ -189,6 +200,8 @@ func WithDefaults() *Config {
TLSConfig: &tls.Config{MinVersion: tls.VersionTLS12},
ArrowConfig: ArrowConfig{}.WithDefaults(),
PollInterval: 1 * time.Second,
MaxPollInterval: 60 * time.Second,
PollBackoffMultiplier: 2.0,
ClientTimeout: 900 * time.Second,
PingTimeout: 60 * time.Second,
CanUseMultipleCatalogs: true,
Expand Down Expand Up @@ -282,6 +295,20 @@ func ParseDSN(dsn string) (UserConfig, error) {
ucfg.EnableMetricViewMetadata = enableMetricViewMetadata
}

// Execution protocol parameter (thrift or rest)
if protocol, ok := params.extract("protocol"); ok {
ucfg.ExecutionProtocol = protocol
} else if protocol, ok := params.extract("executionProtocol"); ok {
ucfg.ExecutionProtocol = protocol
}

// Warehouse ID parameter (required for REST protocol)
if warehouseID, ok := params.extract("warehouse_id"); ok {
ucfg.WarehouseID = warehouseID
} else if warehouseID, ok := params.extract("warehouseId"); ok {
ucfg.WarehouseID = warehouseID
}

// for timezone we do a case insensitive key match.
// We use getNoCase because we want to leave timezone in the params so that it will also
// be used as a session param.
Expand All @@ -298,6 +325,12 @@ func ParseDSN(dsn string) (UserConfig, error) {
ucfg.SessionParams = sessionParams
}

// Validate that warehouse_id is provided when using REST protocol
if ucfg.ExecutionProtocol == "rest" && ucfg.WarehouseID == "" {
return UserConfig{}, dbsqlerrint.NewRequestError(context.TODO(), dbsqlerr.ErrInvalidDSNFormat,
errors.New("warehouse_id is required when using REST protocol (protocol=rest)"))
}

return ucfg, err
}

Expand Down
Loading
Loading