Skip to content

Commit 8e30188

Browse files
feat(api): Allow API client to be configured with a retrying policy (#1284)
1 parent 0eda558 commit 8e30188

File tree

3 files changed

+74
-1
lines changed

3 files changed

+74
-1
lines changed

api/client.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import (
2929
"strings"
3030
"time"
3131

32+
"github.com/cenkalti/backoff/v4"
3233
"github.com/lacework/go-sdk/lwdomain"
3334
"github.com/pkg/errors"
3435
"go.uber.org/zap"
@@ -50,6 +51,7 @@ type Client struct {
5051
log *zap.Logger
5152
headers map[string]string
5253
callbacks LifecycleCallbacks
54+
retries *backoff.ExponentialBackOff
5355

5456
Policy *PolicyService
5557

@@ -191,6 +193,15 @@ func WithTimeout(timeout time.Duration) Option {
191193
})
192194
}
193195

196+
// WithRetries sets the retrying policy for API requests
197+
func WithRetries(retries *backoff.ExponentialBackOff) Option {
198+
return clientFunc(func(c *Client) error {
199+
c.log.Debug("setting up retrying policy", zap.Reflect("retries", retries))
200+
c.retries = retries
201+
return nil
202+
})
203+
}
204+
194205
// WithTransport changes the default transport to increase TLSHandshakeTimeout
195206
func WithTransport(transport *http.Transport) Option {
196207
return clientFunc(func(c *Client) error {
@@ -238,6 +249,11 @@ func (c *Client) URL() string {
238249
return c.baseURL.String()
239250
}
240251

252+
// Retries returns the retrying policy configured
253+
func (c *Client) Retries() *backoff.ExponentialBackOff {
254+
return c.retries
255+
}
256+
241257
// ValidAuth verifies that the client has valid authentication
242258
func (c *Client) ValidAuth() bool {
243259
return c.auth.token != ""

api/client_test.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"testing"
2626
"time"
2727

28+
"github.com/cenkalti/backoff/v4"
2829
"github.com/stretchr/testify/assert"
2930

3031
"github.com/lacework/go-sdk/api"
@@ -87,6 +88,7 @@ func TestNewClientWithOptions(t *testing.T) {
8788
api.WithExpirationTime(1800),
8889
api.WithApiV2(),
8990
api.WithTimeout(time.Minute*5),
91+
api.WithRetries(backoff.NewExponentialBackOff()),
9092
api.WithLogLevel("DEBUG"),
9193
api.WithHeader("User-Agent", "test-agent"),
9294
api.WithTokenFromKeys("KEY", "SECRET"), // this option has to be the last one
@@ -121,6 +123,7 @@ func TestCopyClientWithOptions(t *testing.T) {
121123
api.WithURL(fakeServer.URL()),
122124
api.WithExpirationTime(1800),
123125
api.WithTimeout(time.Minute*5),
126+
api.WithRetries(backoff.NewExponentialBackOff()),
124127
api.WithLogLevel("DEBUG"),
125128
api.WithHeader("User-Agent", "test-agent"),
126129
api.WithTokenFromKeys("KEY", "SECRET"), // this option has to be the last one
@@ -138,6 +141,7 @@ func TestCopyClientWithOptions(t *testing.T) {
138141
if assert.Nil(t, err) {
139142
assert.Equal(t, c.ApiVersion(), newExactClient.ApiVersion(), "copy client mismatch")
140143
assert.Equal(t, c.URL(), newExactClient.URL(), "copy client mismatch")
144+
assert.Equal(t, c.Retries(), newExactClient.Retries(), "copy retrying mistmatch")
141145
assert.True(t, newExactClient.ValidAuth())
142146
}
143147

@@ -151,12 +155,14 @@ func TestCopyClientWithOptions(t *testing.T) {
151155
api.WithExpirationTime(3600),
152156
api.WithApiV2(),
153157
api.WithTimeout(time.Minute*60), // LOL!
158+
api.WithRetries(nil),
154159
api.WithLogLevel("INFO"),
155160
api.WithOrgAccess(),
156161
)
157162
if assert.Nil(t, err) {
158163
assert.NotEqual(t, c.ApiVersion(), newModifiedClient.ApiVersion(), "copy modified client mismatch")
159164
assert.NotEqual(t, c.URL(), newModifiedClient.URL(), "copy modified client mismatch")
165+
assert.NotEqual(t, c.Retries(), newModifiedClient.Retries(), "copy modified retrying policy")
160166
assert.Equal(t, "v2", newModifiedClient.ApiVersion(), "copy modified API version should be v2")
161167
assert.Equal(t, "https://new.lacework.net/", newModifiedClient.URL(), "copy modified client mismatch")
162168
assert.True(t, newExactClient.ValidAuth())
@@ -232,3 +238,45 @@ func TestTLSHandshakeTimeout(t *testing.T) {
232238
_, err = clientWithTimeout.V2.AlertChannels.List()
233239
assert.NoError(t, err)
234240
}
241+
242+
func TestClientWithRetries(t *testing.T) {
243+
fakeServer := lacework.MockUnstartedServer()
244+
fakeServer.UseApiV2()
245+
apiPath := "AlertChannels"
246+
fakeServer.MockToken("TOKEN")
247+
fakeServer.Server.StartTLS()
248+
defer fakeServer.Close()
249+
250+
requestNumber := 0
251+
fakeServer.MockAPI(apiPath, func(w http.ResponseWriter, r *http.Request) {
252+
requestNumber += 1
253+
if requestNumber == 3 {
254+
w.WriteHeader(200)
255+
fmt.Fprintf(w, "{}")
256+
} else {
257+
w.WriteHeader(500)
258+
}
259+
})
260+
261+
transport := &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
262+
client, err := api.NewClient("test",
263+
api.WithApiV2(),
264+
api.WithToken("TOKEN"),
265+
api.WithURL(fakeServer.URL()),
266+
api.WithTransport(transport),
267+
)
268+
269+
_, err = client.V2.AlertChannels.List()
270+
assert.Error(t, err)
271+
272+
clientWithRetries, err := api.NewClient("test",
273+
api.WithApiV2(),
274+
api.WithToken("TOKEN"),
275+
api.WithURL(fakeServer.URL()),
276+
api.WithTransport(transport),
277+
api.WithRetries(backoff.NewExponentialBackOff()),
278+
)
279+
280+
_, err = clientWithRetries.V2.AlertChannels.List()
281+
assert.NoError(t, err)
282+
}

api/http.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
"net/http"
2727
"net/url"
2828

29+
"github.com/cenkalti/backoff/v4"
2930
"go.uber.org/zap"
3031
)
3132

@@ -143,7 +144,15 @@ func (c *Client) RequestDecoder(method, path string, body io.Reader, v interface
143144
return err
144145
}
145146

146-
res, err := c.DoDecoder(request, v)
147+
var res *http.Response
148+
if c.retries != nil {
149+
err = backoff.Retry(func() error {
150+
res, err = c.DoDecoder(request, v)
151+
return err
152+
}, c.retries)
153+
} else {
154+
res, err = c.DoDecoder(request, v)
155+
}
147156
if err != nil {
148157
return err
149158
}

0 commit comments

Comments
 (0)