Skip to content

Commit 350ea35

Browse files
Implemented IsRetryable() and RetryAfter() for DatabricksError (#119)
Added IsRetryable and RetryAfter functions to DBError interface. Added an internal error type for retryable errors. Updated client to insert a retryable error instance into the error chain.
2 parents 36b12cd + 014b68e commit 350ea35

File tree

12 files changed

+260
-55
lines changed

12 files changed

+260
-55
lines changed

connector.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
6767
setStmt := fmt.Sprintf("SET `%s` = `%s`;", k, v)
6868
_, err := conn.ExecContext(ctx, setStmt, []driver.NamedValue{})
6969
if err != nil {
70-
return nil, err
70+
return nil, dbsqlerrint.NewExecutionError(ctx, fmt.Sprintf("error setting session param: %s", setStmt), err, nil)
7171
}
7272
log.Info().Msgf("set session parameter: param=%s value=%s", k, v)
7373
}

doc.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,12 +175,15 @@ Example usage:
175175
if errors.Is(err, dbsqlerr.ExecutionError) {
176176
var execErr dbsqlerr.DBExecutionError
177177
if ok := errors.As(err, &execErr); ok {
178-
fmt.Printf("%s, corrId: %s, connId: %s, queryId: %s, sqlState: %s",
178+
fmt.Printf("%s, corrId: %s, connId: %s, queryId: %s, sqlState: %s, isRetryable: %t, retryAfter: %f seconds",
179179
execErr.Error(),
180180
execErr.CorrelationId(),
181181
execErr.ConnectionId(),
182182
execErr.QueryId(),
183-
execErr.SqlState())
183+
execErr.SqlState(),
184+
execErr.IsRetryable(),
185+
execErr.RetryAfter().Seconds(),
186+
)
184187
}
185188
}
186189
...

driver_e2e_test.go

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,74 @@ func TestRetries(t *testing.T) {
421421
require.ErrorContains(t, err, "after 1 attempt(s)")
422422
})
423423

424+
t.Run("a 429 or 503 should result in a retryable error", func(t *testing.T) {
425+
426+
_ = logger.SetLogLevel("debug")
427+
state := &callState{}
428+
// load basic responses
429+
loadTestData(t, "OpenSessionSuccess.json", &state.openSessionResp)
430+
loadTestData(t, "CloseSessionSuccess.json", &state.closeSessionResp)
431+
loadTestData(t, "CloseOperationSuccess.json", &state.closeOperationResp)
432+
433+
ts := getServer(state)
434+
435+
defer ts.Close()
436+
r, err := url.Parse(ts.URL)
437+
require.NoError(t, err)
438+
port, err := strconv.Atoi(r.Port())
439+
require.NoError(t, err)
440+
441+
connector, err := NewConnector(
442+
WithServerHostname("localhost"),
443+
WithHTTPPath("/429-5-retries"),
444+
WithPort(port),
445+
WithRetries(2, 10*time.Millisecond, 1*time.Second),
446+
)
447+
require.NoError(t, err)
448+
db := sql.OpenDB(connector)
449+
defer db.Close()
450+
451+
state.executeStatementResp = cli_service.TExecuteStatementResp{}
452+
loadTestData(t, "ExecuteStatement1.json", &state.executeStatementResp)
453+
454+
err = db.Ping()
455+
require.ErrorContains(t, err, "after 3 attempt(s)")
456+
457+
// The error chain should contain a databricks request error
458+
b := errors.Is(err, dbsqlerr.RequestError)
459+
require.True(t, b)
460+
var re dbsqlerr.DBRequestError
461+
b = errors.As(err, &re)
462+
require.True(t, b)
463+
require.NotNil(t, re)
464+
require.True(t, re.IsRetryable())
465+
require.Equal(t, 12*time.Second, re.RetryAfter())
466+
467+
connector2, err := NewConnector(
468+
WithServerHostname("localhost"),
469+
WithHTTPPath("/503-5-retries"),
470+
WithPort(port),
471+
WithRetries(2, 10*time.Millisecond, 1*time.Second),
472+
)
473+
require.NoError(t, err)
474+
db2 := sql.OpenDB(connector2)
475+
defer db.Close()
476+
477+
state.executeStatementResp = cli_service.TExecuteStatementResp{}
478+
loadTestData(t, "ExecuteStatement1.json", &state.executeStatementResp)
479+
480+
err = db2.Ping()
481+
require.ErrorContains(t, err, "after 3 attempt(s)")
482+
483+
// The error chain should contain a databricks request error
484+
b = errors.Is(err, dbsqlerr.RequestError)
485+
require.True(t, b)
486+
b = errors.As(err, &re)
487+
require.True(t, b)
488+
require.NotNil(t, re)
489+
require.True(t, re.IsRetryable())
490+
})
491+
424492
}
425493

426494
// TODO: add tests for x-databricks headers

errors/errors.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
package errors
22

3-
import "github.com/pkg/errors"
3+
import (
4+
"time"
5+
6+
"github.com/pkg/errors"
7+
)
48

59
// Error messages
610
const (
@@ -59,6 +63,10 @@ type DBError interface {
5963

6064
// Underlying causative error. May be nil.
6165
Cause() error
66+
67+
IsRetryable() bool
68+
69+
RetryAfter() time.Duration
6270
}
6371

6472
// An error that is caused by an invalid request.
@@ -70,8 +78,6 @@ type DBRequestError interface {
7078
// A fault that is caused by Databricks services
7179
type DBDriverError interface {
7280
DBError
73-
74-
IsRetryable() bool
7581
}
7682

7783
// Any error that occurs after the SQL statement has been accepted (e.g. SQL syntax error).

examples/error/main.go

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,34 @@ func main() {
7171
}
7272
ctx = driverctx.NewContextWithQueryIdCallback(ctx, queryIdCallback)
7373

74-
rows, err1 := db.QueryContext(ctx, `select * from default.intervals`)
75-
fmt.Printf("conn Id: %s, query Id: %s\n", connId, queryId)
74+
var rows *sql.Rows
75+
maxRetries := 3
76+
shouldTry := true
77+
78+
// We want to retry running the query if an error is returned where IsRetryable() is true up
79+
// to the maximum number of retries.
80+
for i := 0; i < maxRetries && shouldTry; i++ {
81+
var err1 error
82+
var wait time.Duration
83+
84+
rows, err1 = db.QueryContext(ctx, `select * from default.Intervals`)
85+
86+
// Check if the error is retryable and if there is a wait before
87+
// trying again.
88+
if shouldTry, wait = isRetryable(err1); shouldTry {
89+
fmt.Printf("query failed, retrying after %f seconds", wait.Seconds())
90+
time.Sleep(wait)
91+
} else {
92+
// handle the error, which may be nil
93+
handleErr(err1)
94+
}
95+
}
7696

77-
handleErr(err1)
97+
// At this point the query has completed successfully, unless every attempt failed with a retryable error
7898
defer rows.Close()
7999

100+
fmt.Printf("conn Id: %s, query Id: %s\n", connId, queryId)
101+
80102
colNames, _ := rows.Columns()
81103
for i := range colNames {
82104
fmt.Printf("%d: %s\n", i, colNames[i])
@@ -91,6 +113,8 @@ func main() {
91113

92114
}
93115

116+
// If the error is not nil, extract Databricks-specific error information and then
117+
// terminate the program.
94118
func handleErr(err error) {
95119
if err == nil {
96120
return
@@ -155,3 +179,14 @@ func getQueryIdAndSQLState(err error) (queryId, sqlState string) {
155179

156180
return
157181
}
182+
183+
// Use errors.As to extract a DBError from the error chain and return the associated
184+
// values for isRetryable and retryAfter
185+
func isRetryable(err error) (isRetryable bool, retryAfter time.Duration) {
186+
var dbErr dbsqlerr.DBError
187+
if errors.As(err, &dbErr) {
188+
isRetryable = dbErr.IsRetryable()
189+
retryAfter = dbErr.RetryAfter()
190+
}
191+
return
192+
}

internal/client/client.go

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,15 @@ func SprintGuid(bts []byte) string {
279279

280280
var retryableStatusCode = []int{http.StatusTooManyRequests, http.StatusServiceUnavailable}
281281

282+
func isRetryable(statusCode int) bool {
283+
for _, c := range retryableStatusCode {
284+
if c == statusCode {
285+
return true
286+
}
287+
}
288+
return false
289+
}
290+
282291
type Transport struct {
283292
Base *http.Transport
284293
Authr auth.Authenticator
@@ -321,14 +330,13 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
321330
if resp.StatusCode != http.StatusOK {
322331
reason := resp.Header.Get("X-Databricks-Reason-Phrase")
323332
terrmsg := resp.Header.Get("X-Thriftserver-Error-Message")
324-
for _, c := range retryableStatusCode {
325-
if c == resp.StatusCode {
326-
if terrmsg != "" {
327-
logger.Warn().Msg(terrmsg)
328-
}
329-
return resp, nil
333+
if isRetryable(resp.StatusCode) {
334+
if terrmsg != "" {
335+
logger.Warn().Msg(terrmsg)
330336
}
337+
return resp, nil
331338
}
339+
332340
if reason != "" {
333341
logger.Err(fmt.Errorf(reason)).Msg("non retryable error")
334342
return nil, errors.New(reason)
@@ -426,17 +434,25 @@ func errorHandler(resp *http.Response, err error, numTries int) (*http.Response,
426434
if err == nil {
427435
err = errors.New(fmt.Sprintf("request error after %d attempt(s)", numTries))
428436
}
429-
if resp != nil && resp.Header != nil {
430437

438+
if resp != nil {
439+
var orgid, reason, terrmsg, errmsg, retryAfter string
431440
// TODO @mattdeekay: convert these to specific error types
441+
if resp.Header != nil {
442+
orgid = resp.Header.Get("X-Databricks-Org-Id")
443+
reason = resp.Header.Get("X-Databricks-Reason-Phrase") // TODO note: shown on notebook
444+
terrmsg = resp.Header.Get("X-Thriftserver-Error-Message")
445+
errmsg = resp.Header.Get("x-databricks-error-or-redirect-message")
446+
retryAfter = resp.Header.Get("Retry-After")
447+
// TODO note: need to see if there's other headers
448+
}
449+
msg := fmt.Sprintf("orgId: %s, reason: %s, thriftErr: %s, err: %s", orgid, reason, terrmsg, errmsg)
432450

433-
orgid := resp.Header.Get("X-Databricks-Org-Id")
434-
reason := resp.Header.Get("X-Databricks-Reason-Phrase") // TODO note: shown on notebook
435-
terrmsg := resp.Header.Get("X-Thriftserver-Error-Message")
436-
errmsg := resp.Header.Get("x-databricks-error-or-redirect-message")
437-
// TODO note: need to see if there's other headers
451+
if isRetryable(resp.StatusCode) {
452+
err = dbsqlerrint.NewRetryableError(err, retryAfter)
453+
}
438454

439-
werr = errors.Wrapf(err, fmt.Sprintf("orgId: %s, reason: %s, thriftErr: %s, err: %s", orgid, reason, terrmsg, errmsg))
455+
werr = dbsqlerrint.WrapErr(err, msg)
440456
} else {
441457
werr = err
442458
}
@@ -464,11 +480,8 @@ func RetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, err
464480
// 429 Too Many Requests or 503 service unavailable is recoverable. Sometimes the server puts
465481
// a Retry-After response header to indicate when the server is
466482
// available to start processing requests from the client.
467-
468-
for _, c := range retryableStatusCode {
469-
if c == resp.StatusCode {
470-
return true, nil
471-
}
483+
if isRetryable(resp.StatusCode) {
484+
return true, nil
472485
}
473486

474487
return false, nil

internal/config/config.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@ import (
99
"strings"
1010
"time"
1111

12+
dbsqlerr "github.com/databricks/databricks-sql-go/errors"
1213
"github.com/pkg/errors"
1314

1415
"github.com/databricks/databricks-sql-go/auth"
1516
"github.com/databricks/databricks-sql-go/auth/noop"
1617
"github.com/databricks/databricks-sql-go/auth/pat"
17-
dbsqlerr "github.com/databricks/databricks-sql-go/errors"
1818
"github.com/databricks/databricks-sql-go/internal/cli_service"
1919
dbsqlerrint "github.com/databricks/databricks-sql-go/internal/errors"
2020
"github.com/databricks/databricks-sql-go/logger"

0 commit comments

Comments
 (0)