Skip to content

Commit 8f18b05

Browse files
jayantsing-dbclaude
andcommitted
[PECOBLR-1172] Add configuration and protocol selection for SEA Phase 1
This commit implements the foundational configuration support for protocol selection between Thrift and REST protocols as part of Statement Execution API (SEA) Phase 1. Changes: - Added ExecutionProtocol field to UserConfig (defaults to "thrift") - Added WarehouseID field to UserConfig (required for REST protocol) - Added MaxPollInterval and PollBackoffMultiplier fields to Config - Implemented DSN parsing for protocol/executionProtocol parameters - Implemented DSN parsing for warehouse_id/warehouseId parameters - Added validation to require warehouse_id when using REST protocol - Updated DeepCopy methods to include new fields - Updated WithDefaults to set proper default values: - ExecutionProtocol: "thrift" (backward compatible) - MaxPollInterval: 60 seconds - PollBackoffMultiplier: 2.0 - Added comprehensive unit tests for all new functionality The implementation supports both snake_case and camelCase parameter names for user convenience (protocol/executionProtocol, warehouse_id/warehouseId). Related design doc: statement-execution-api-design-go.md 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent 845334e commit 8f18b05

File tree

2 files changed

+262
-47
lines changed

2 files changed

+262
-47
lines changed

internal/config/config.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ type Config struct {
3030
TLSConfig *tls.Config // nil disables TLS
3131
ArrowConfig
3232
PollInterval time.Duration
33+
MaxPollInterval time.Duration // maximum polling interval for exponential backoff
34+
PollBackoffMultiplier float64 // multiplier for exponential backoff
3335
ClientTimeout time.Duration // max time the http request can last
3436
PingTimeout time.Duration // max time allowed for ping
3537
CanUseMultipleCatalogs bool
@@ -68,6 +70,8 @@ func (c *Config) DeepCopy() *Config {
6870
TLSConfig: c.TLSConfig.Clone(),
6971
ArrowConfig: c.ArrowConfig.DeepCopy(),
7072
PollInterval: c.PollInterval,
73+
MaxPollInterval: c.MaxPollInterval,
74+
PollBackoffMultiplier: c.PollBackoffMultiplier,
7175
ClientTimeout: c.ClientTimeout,
7276
PingTimeout: c.PingTimeout,
7377
CanUseMultipleCatalogs: c.CanUseMultipleCatalogs,
@@ -101,6 +105,8 @@ type UserConfig struct {
101105
Transport http.RoundTripper
102106
UseLz4Compression bool
103107
EnableMetricViewMetadata bool
108+
ExecutionProtocol string // "thrift" (default) or "rest"
109+
WarehouseID string // required when ExecutionProtocol is "rest"
104110
CloudFetchConfig
105111
}
106112

@@ -143,6 +149,8 @@ func (ucfg UserConfig) DeepCopy() UserConfig {
143149
Transport: ucfg.Transport,
144150
UseLz4Compression: ucfg.UseLz4Compression,
145151
EnableMetricViewMetadata: ucfg.EnableMetricViewMetadata,
152+
ExecutionProtocol: ucfg.ExecutionProtocol,
153+
WarehouseID: ucfg.WarehouseID,
146154
CloudFetchConfig: ucfg.CloudFetchConfig,
147155
}
148156
}
@@ -176,6 +184,9 @@ func (ucfg UserConfig) WithDefaults() UserConfig {
176184
if ucfg.RetryWaitMax == 0 {
177185
ucfg.RetryWaitMax = 30 * time.Second
178186
}
187+
if ucfg.ExecutionProtocol == "" {
188+
ucfg.ExecutionProtocol = "thrift"
189+
}
179190
ucfg.UseLz4Compression = false
180191
ucfg.CloudFetchConfig = CloudFetchConfig{}.WithDefaults()
181192

@@ -189,6 +200,8 @@ func WithDefaults() *Config {
189200
TLSConfig: &tls.Config{MinVersion: tls.VersionTLS12},
190201
ArrowConfig: ArrowConfig{}.WithDefaults(),
191202
PollInterval: 1 * time.Second,
203+
MaxPollInterval: 60 * time.Second,
204+
PollBackoffMultiplier: 2.0,
192205
ClientTimeout: 900 * time.Second,
193206
PingTimeout: 60 * time.Second,
194207
CanUseMultipleCatalogs: true,
@@ -282,6 +295,20 @@ func ParseDSN(dsn string) (UserConfig, error) {
282295
ucfg.EnableMetricViewMetadata = enableMetricViewMetadata
283296
}
284297

298+
// Execution protocol parameter (thrift or rest)
299+
if protocol, ok := params.extract("protocol"); ok {
300+
ucfg.ExecutionProtocol = protocol
301+
} else if protocol, ok := params.extract("executionProtocol"); ok {
302+
ucfg.ExecutionProtocol = protocol
303+
}
304+
305+
// Warehouse ID parameter (required for REST protocol)
306+
if warehouseID, ok := params.extract("warehouse_id"); ok {
307+
ucfg.WarehouseID = warehouseID
308+
} else if warehouseID, ok := params.extract("warehouseId"); ok {
309+
ucfg.WarehouseID = warehouseID
310+
}
311+
285312
// for timezone we do a case insensitive key match.
286313
// We use getNoCase because we want to leave timezone in the params so that it will also
287314
// be used as a session param.
@@ -298,6 +325,12 @@ func ParseDSN(dsn string) (UserConfig, error) {
298325
ucfg.SessionParams = sessionParams
299326
}
300327

328+
// Validate that warehouse_id is provided when using REST protocol
329+
if ucfg.ExecutionProtocol == "rest" && ucfg.WarehouseID == "" {
330+
return UserConfig{}, dbsqlerrint.NewRequestError(context.TODO(), dbsqlerr.ErrInvalidDSNFormat,
331+
errors.New("warehouse_id is required when using REST protocol (protocol=rest)"))
332+
}
333+
301334
return ucfg, err
302335
}
303336

0 commit comments

Comments
 (0)