Skip to content

Commit 89b146e

Browse files
Implement retryable flag for databricks error
Added IsRetryable() and RetryableAfter() methods to DatabricksError. Updated client to insert an internal retryable error instance into the error chain when failing with a 503/429. DatabricksError uses the presence of this error to set the retryable flag and retryAfter value. Added test to driver_e2e_test.go to check that DatabricksError has the correct values for IsRetryable() and RetryAfter() Signed-off-by: Raymond Cypher <[email protected]>
1 parent a38f4bb commit 89b146e

File tree

6 files changed

+184
-28
lines changed

6 files changed

+184
-28
lines changed

driver_e2e_test.go

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,74 @@ func TestRetries(t *testing.T) {
421421
require.ErrorContains(t, err, "after 1 attempt(s)")
422422
})
423423

424+
t.Run("a 429 or 503 should result in a retryable error", func(t *testing.T) {
425+
426+
_ = logger.SetLogLevel("debug")
427+
state := &callState{}
428+
// load basic responses
429+
loadTestData(t, "OpenSessionSuccess.json", &state.openSessionResp)
430+
loadTestData(t, "CloseSessionSuccess.json", &state.closeSessionResp)
431+
loadTestData(t, "CloseOperationSuccess.json", &state.closeOperationResp)
432+
433+
ts := getServer(state)
434+
435+
defer ts.Close()
436+
r, err := url.Parse(ts.URL)
437+
require.NoError(t, err)
438+
port, err := strconv.Atoi(r.Port())
439+
require.NoError(t, err)
440+
441+
connector, err := NewConnector(
442+
WithServerHostname("localhost"),
443+
WithHTTPPath("/429-5-retries"),
444+
WithPort(port),
445+
WithRetries(2, 10*time.Millisecond, 1*time.Second),
446+
)
447+
require.NoError(t, err)
448+
db := sql.OpenDB(connector)
449+
defer db.Close()
450+
451+
state.executeStatementResp = cli_service.TExecuteStatementResp{}
452+
loadTestData(t, "ExecuteStatement1.json", &state.executeStatementResp)
453+
454+
err = db.Ping()
455+
require.ErrorContains(t, err, "after 3 attempt(s)")
456+
457+
// The error chain should contain a databricks request error
458+
b := errors.Is(err, dbsqlerr.RequestError)
459+
require.True(t, b)
460+
var re dbsqlerr.DBRequestError
461+
b = errors.As(err, &re)
462+
require.True(t, b)
463+
require.NotNil(t, re)
464+
require.True(t, re.IsRetryable())
465+
require.Equal(t, "retry after header value", re.RetryAfter())
466+
467+
connector2, err := NewConnector(
468+
WithServerHostname("localhost"),
469+
WithHTTPPath("/503-5-retries"),
470+
WithPort(port),
471+
WithRetries(2, 10*time.Millisecond, 1*time.Second),
472+
)
473+
require.NoError(t, err)
474+
db2 := sql.OpenDB(connector2)
475+
defer db.Close()
476+
477+
state.executeStatementResp = cli_service.TExecuteStatementResp{}
478+
loadTestData(t, "ExecuteStatement1.json", &state.executeStatementResp)
479+
480+
err = db2.Ping()
481+
require.ErrorContains(t, err, "after 3 attempt(s)")
482+
483+
// The error chain should contain a databricks request error
484+
b = errors.Is(err, dbsqlerr.RequestError)
485+
require.True(t, b)
486+
b = errors.As(err, &re)
487+
require.True(t, b)
488+
require.NotNil(t, re)
489+
require.True(t, re.IsRetryable())
490+
})
491+
424492
}
425493

426494
// TODO: add tests for x-databricks headers

errors/errors.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ type DBError interface {
5959

6060
// Underlying causative error. May be nil.
6161
Cause() error
62+
63+
IsRetryable() bool
64+
65+
RetryAfter() string
6266
}
6367

6468
// An error that is caused by an invalid request.
@@ -70,8 +74,6 @@ type DBRequestError interface {
7074
// A fault that is caused by Databricks services
7175
type DBDriverError interface {
7276
DBError
73-
74-
IsRetryable() bool
7577
}
7678

7779
// Any error that occurs after the SQL statement has been accepted (e.g. SQL syntax error).

internal/client/client.go

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,15 @@ func SprintGuid(bts []byte) string {
279279

280280
var retryableStatusCode = []int{http.StatusTooManyRequests, http.StatusServiceUnavailable}
281281

282+
func isRetryable(statusCode int) bool {
283+
for _, c := range retryableStatusCode {
284+
if c == statusCode {
285+
return true
286+
}
287+
}
288+
return false
289+
}
290+
282291
type Transport struct {
283292
Base *http.Transport
284293
Authr auth.Authenticator
@@ -321,14 +330,13 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
321330
if resp.StatusCode != http.StatusOK {
322331
reason := resp.Header.Get("X-Databricks-Reason-Phrase")
323332
terrmsg := resp.Header.Get("X-Thriftserver-Error-Message")
324-
for _, c := range retryableStatusCode {
325-
if c == resp.StatusCode {
326-
if terrmsg != "" {
327-
logger.Warn().Msg(terrmsg)
328-
}
329-
return resp, nil
333+
if isRetryable(resp.StatusCode) {
334+
if terrmsg != "" {
335+
logger.Warn().Msg(terrmsg)
330336
}
337+
return resp, nil
331338
}
339+
332340
if reason != "" {
333341
logger.Err(fmt.Errorf(reason)).Msg("non retryable error")
334342
return nil, errors.New(reason)
@@ -426,17 +434,25 @@ func errorHandler(resp *http.Response, err error, numTries int) (*http.Response,
426434
if err == nil {
427435
err = errors.New(fmt.Sprintf("request error after %d attempt(s)", numTries))
428436
}
429-
if resp != nil && resp.Header != nil {
430437

438+
if resp != nil {
439+
var orgid, reason, terrmsg, errmsg, retryAfter string
431440
// TODO @mattdeekay: convert these to specific error types
441+
if resp.Header != nil {
442+
orgid = resp.Header.Get("X-Databricks-Org-Id")
443+
reason = resp.Header.Get("X-Databricks-Reason-Phrase") // TODO note: shown on notebook
444+
terrmsg = resp.Header.Get("X-Thriftserver-Error-Message")
445+
errmsg = resp.Header.Get("x-databricks-error-or-redirect-message")
446+
retryAfter = resp.Header.Get("Retry-After")
447+
// TODO note: need to see if there's other headers
448+
}
449+
msg := fmt.Sprintf("orgId: %s, reason: %s, thriftErr: %s, err: %s", orgid, reason, terrmsg, errmsg)
432450

433-
orgid := resp.Header.Get("X-Databricks-Org-Id")
434-
reason := resp.Header.Get("X-Databricks-Reason-Phrase") // TODO note: shown on notebook
435-
terrmsg := resp.Header.Get("X-Thriftserver-Error-Message")
436-
errmsg := resp.Header.Get("x-databricks-error-or-redirect-message")
437-
// TODO note: need to see if there's other headers
451+
if isRetryable(resp.StatusCode) {
452+
err = dbsqlerrint.NewRetryableError(err, retryAfter)
453+
}
438454

439-
werr = errors.Wrapf(err, fmt.Sprintf("orgId: %s, reason: %s, thriftErr: %s, err: %s", orgid, reason, terrmsg, errmsg))
455+
werr = dbsqlerrint.WrapErr(err, msg)
440456
} else {
441457
werr = err
442458
}
@@ -464,11 +480,8 @@ func RetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, err
464480
// 429 Too Many Requests or 503 service unavailable is recoverable. Sometimes the server puts
465481
// a Retry-After response header to indicate when the server is
466482
// available to start processing request from client.
467-
468-
for _, c := range retryableStatusCode {
469-
if c == resp.StatusCode {
470-
return true, nil
471-
}
483+
if isRetryable(resp.StatusCode) {
484+
return true, nil
472485
}
473486

474487
return false, nil

internal/errors/err.go

Lines changed: 64 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,18 @@ import (
1010
"github.com/pkg/errors"
1111
)
1212

13+
// value to use with errors.Is() to determine if an error
14+
// chain contains a retryable error
15+
var RetryableError error = errors.New("Retryable Error")
16+
17+
// base databricks error
1318
type databricksError struct {
1419
err error
1520
correlationId string
1621
connectionId string
1722
errType string
23+
isRetryable bool
24+
retryAfter string
1825
}
1926

2027
var _ error = (*databricksError)(nil)
@@ -38,11 +45,25 @@ func newDatabricksError(ctx context.Context, msg string, err error) databricksEr
3845
err = errors.WithStack(err)
3946
}
4047

48+
// If the error chain contains an instance of retryableError
49+
// set the flag and retryAfter value.
50+
var retryable bool = false
51+
var retryAfter string
52+
if errors.Is(err, RetryableError) {
53+
retryable = true
54+
var re retryableError
55+
if ok := errors.As(err, &re); ok {
56+
retryAfter = re.retryAfter
57+
}
58+
}
59+
4160
return databricksError{
4261
err: err,
4362
correlationId: driverctx.CorrelationIdFromContext(ctx),
4463
connectionId: driverctx.ConnIdFromContext(ctx),
4564
errType: "unknown",
65+
isRetryable: retryable,
66+
retryAfter: retryAfter,
4667
}
4768
}
4869

@@ -75,10 +96,17 @@ func (e databricksError) Is(err error) bool {
7596
return err == dbsqlerr.DatabricksError
7697
}
7798

99+
func (e databricksError) IsRetryable() bool {
100+
return e.isRetryable
101+
}
102+
103+
func (e databricksError) RetryAfter() string {
104+
return e.retryAfter
105+
}
106+
78107
// driverError are issues with the driver or server, e.g. not supported operations, driver specific non-recoverable failures
79108
type driverError struct {
80109
databricksError
81-
isRetryable bool
82110
}
83111

84112
var _ dbsqlerr.DBDriverError = (*driverError)(nil)
@@ -91,14 +119,10 @@ func (e driverError) Unwrap() error {
91119
return e.err
92120
}
93121

94-
func (e driverError) IsRetryable() bool {
95-
return e.isRetryable
96-
}
97-
98122
func NewDriverError(ctx context.Context, msg string, err error) *driverError {
99123
dbErr := newDatabricksError(ctx, msg, err)
100124
dbErr.errType = "driver error"
101-
return &driverError{databricksError: dbErr, isRetryable: false}
125+
return &driverError{databricksError: dbErr}
102126
}
103127

104128
// requestError are errors caused by invalid requests, e.g. permission denied, warehouse not found
@@ -181,3 +205,37 @@ func WrapErrf(err error, format string, args ...interface{}) error {
181205
// wrap passed in error in errors with the formatted message and a stack trace
182206
return errors.Wrapf(err, format, args...)
183207
}
208+
209+
type retryableError struct {
210+
err error
211+
retryAfter string
212+
}
213+
214+
func (e retryableError) Is(err error) bool {
215+
return err == RetryableError
216+
}
217+
218+
func (e retryableError) Unwrap() error {
219+
return e.err
220+
}
221+
222+
func (e retryableError) Error() string {
223+
return fmt.Sprintf("databricks: retryableError: %s", e.err.Error())
224+
}
225+
226+
func (e retryableError) RetryAfter() string {
227+
return e.retryAfter
228+
}
229+
230+
func NewRetryableError(err error, retryAfter string) error {
231+
if err == nil {
232+
err = errors.New("")
233+
}
234+
235+
var st stackTracer
236+
if ok := errors.As(err, &st); !ok {
237+
err = errors.WithStack(err)
238+
}
239+
240+
return retryableError{err: err, retryAfter: retryAfter}
241+
}

internal/errors/err_test.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,13 @@ func TestDbSqlErrors(t *testing.T) {
9595
st := requestError.StackTrace()
9696
assert.NotNil(t, st)
9797

98+
// Get the underlying stackTracer instance, it should be
99+
// the original cause
98100
var str stackTracer
99101
ok := errors.As(requestError.Cause(), &str)
100102
assert.True(t, ok)
101-
assert.NotEqual(t, requestError, str)
103+
ss := str.StackTrace()
104+
assert.NotNil(t, ss)
102105
assert.Equal(t, cause, str)
103106

104107
cause = &boringError{}
@@ -107,9 +110,12 @@ func TestDbSqlErrors(t *testing.T) {
107110
st = requestError.StackTrace()
108111
assert.NotNil(t, st)
109112

113+
// Get the underlying stackTracer instance, it should not be
114+
// the original cause
110115
ok = errors.As(requestError.Cause(), &str)
111116
assert.True(t, ok)
112-
assert.NotEqual(t, requestError, str)
117+
ss = str.StackTrace()
118+
assert.NotNil(t, ss)
113119
assert.NotEqual(t, cause, str)
114120
})
115121

testserver.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,15 @@ func (h *thriftHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
4444
} else {
4545
h.count503_5_retries = 0
4646
}
47+
case "/429-5-retries":
48+
if h.count503_5_retries <= 5 {
49+
w.Header().Set("Retry-After", "retry after header value")
50+
w.WriteHeader(http.StatusServiceUnavailable)
51+
h.count503_5_retries++
52+
return
53+
} else {
54+
h.count503_5_retries = 0
55+
}
4756
}
4857

4958
thriftHandler := thrift.NewThriftHandlerFunc(h.processor, h.inPfactory, h.outPfactory)

0 commit comments

Comments
 (0)