Skip to content

Commit 191b9a7

Browse files
authored
Add sqs client retryer (#1099)
* Add sqs client retryer
1 parent ab11c08 commit 191b9a7

File tree

3 files changed

+167
-3
lines changed

3 files changed

+167
-3
lines changed

cmd/node-termination-handler.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ package main
1616
import (
1717
"context"
1818
"fmt"
19-
"github.com/aws/aws-node-termination-handler/pkg/monitor/asglifecycle"
2019
"os"
2120
"os/signal"
2221
"strings"
@@ -31,6 +30,7 @@ import (
3130
"github.com/aws/aws-node-termination-handler/pkg/interruptioneventstore"
3231
"github.com/aws/aws-node-termination-handler/pkg/logging"
3332
"github.com/aws/aws-node-termination-handler/pkg/monitor"
33+
"github.com/aws/aws-node-termination-handler/pkg/monitor/asglifecycle"
3434
"github.com/aws/aws-node-termination-handler/pkg/monitor/rebalancerecommendation"
3535
"github.com/aws/aws-node-termination-handler/pkg/monitor/scheduledevent"
3636
"github.com/aws/aws-node-termination-handler/pkg/monitor/spotitn"
@@ -43,7 +43,6 @@ import (
4343
"github.com/aws/aws-sdk-go/aws/session"
4444
"github.com/aws/aws-sdk-go/service/autoscaling"
4545
"github.com/aws/aws-sdk-go/service/ec2"
46-
"github.com/aws/aws-sdk-go/service/sqs"
4746
"github.com/rs/zerolog"
4847
"github.com/rs/zerolog/log"
4948
"k8s.io/apimachinery/pkg/util/wait"
@@ -223,7 +222,7 @@ func main() {
223222
QueueURL: nthConfig.QueueURL,
224223
InterruptionChan: interruptionChan,
225224
CancelChan: cancelChan,
226-
SQS: sqs.New(sess),
225+
SQS: sqsevent.GetSqsClient(sess),
227226
ASG: autoscaling.New(sess),
228227
EC2: ec2.New(sess),
229228
BeforeCompleteLifecycleAction: func() { <-time.After(completeLifecycleActionDelay) },
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"). You may
4+
// not use this file except in compliance with the License. A copy of the
5+
// License is located at
6+
//
7+
// http://aws.amazon.com/apache2.0/
8+
//
9+
// or in the "license" file accompanying this file. This file is distributed
10+
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11+
// express or implied. See the License for the specific language governing
12+
// permissions and limitations under the License.
13+
14+
package sqsevent
15+
16+
import (
17+
"strings"
18+
"time"
19+
20+
"github.com/aws/aws-sdk-go/aws"
21+
"github.com/aws/aws-sdk-go/aws/client"
22+
"github.com/aws/aws-sdk-go/aws/request"
23+
"github.com/aws/aws-sdk-go/aws/session"
24+
"github.com/aws/aws-sdk-go/service/sqs"
25+
)
26+
27+
type SqsRetryer struct {
28+
client.DefaultRetryer
29+
}
30+
31+
func (r SqsRetryer) ShouldRetry(req *request.Request) bool {
32+
return r.DefaultRetryer.ShouldRetry(req) ||
33+
(req.Error != nil && strings.Contains(req.Error.Error(), "connection reset"))
34+
}
35+
36+
func GetSqsClient(sess *session.Session) *sqs.SQS {
37+
return sqs.New(sess, &aws.Config{
38+
Retryer: SqsRetryer{
39+
DefaultRetryer: client.DefaultRetryer{
40+
// Monitor continuously monitors SQS for events every 2 seconds
41+
NumMaxRetries: client.DefaultRetryerMaxNumRetries,
42+
MinRetryDelay: client.DefaultRetryerMinRetryDelay,
43+
MaxRetryDelay: 1200 * time.Millisecond,
44+
MinThrottleDelay: client.DefaultRetryerMinThrottleDelay,
45+
MaxThrottleDelay: 1200 * time.Millisecond,
46+
},
47+
},
48+
})
49+
}
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"). You may
4+
// not use this file except in compliance with the License. A copy of the
5+
// License is located at
6+
//
7+
// http://aws.amazon.com/apache2.0/
8+
//
9+
// or in the "license" file accompanying this file. This file is distributed
10+
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11+
// express or implied. See the License for the specific language governing
12+
// permissions and limitations under the License.
13+
14+
package sqsevent_test
15+
16+
import (
17+
"fmt"
18+
"net"
19+
"testing"
20+
"time"
21+
22+
"github.com/aws/aws-node-termination-handler/pkg/monitor/sqsevent"
23+
h "github.com/aws/aws-node-termination-handler/pkg/test"
24+
"github.com/aws/aws-sdk-go/aws"
25+
"github.com/aws/aws-sdk-go/aws/awserr"
26+
"github.com/aws/aws-sdk-go/aws/client"
27+
"github.com/aws/aws-sdk-go/aws/request"
28+
"github.com/aws/aws-sdk-go/aws/session"
29+
)
30+
31+
type temporaryError struct {
32+
error
33+
temp bool
34+
}
35+
36+
func TestGetSqsClient(t *testing.T) {
37+
retryer := getSqsRetryer(t)
38+
39+
h.Equals(t, client.DefaultRetryerMaxNumRetries, retryer.NumMaxRetries)
40+
h.Equals(t, time.Duration(1200*time.Millisecond), retryer.MaxRetryDelay)
41+
}
42+
43+
func TestShouldRetry(t *testing.T) {
44+
retryer := getSqsRetryer(t)
45+
46+
testCases := []struct {
47+
name string
48+
req *request.Request
49+
shouldRetry bool
50+
}{
51+
{
52+
name: "AWS throttling error",
53+
req: &request.Request{
54+
Error: awserr.New("ThrottlingException", "Rate exceeded", nil),
55+
},
56+
shouldRetry: true,
57+
},
58+
{
59+
name: "AWS validation error",
60+
req: &request.Request{
61+
Error: awserr.New("ValidationError", "Invalid parameter", nil),
62+
},
63+
shouldRetry: false,
64+
},
65+
{
66+
name: "read connection reset by peer error",
67+
req: &request.Request{
68+
Error: &temporaryError{
69+
error: &net.OpError{
70+
Op: "read",
71+
Err: fmt.Errorf("read: connection reset by peer"),
72+
},
73+
temp: false,
74+
}},
75+
shouldRetry: true,
76+
},
77+
{
78+
name: "read unknown error",
79+
req: &request.Request{
80+
Error: &temporaryError{
81+
error: &net.OpError{
82+
Op: "read",
83+
Err: fmt.Errorf("read unknown error"),
84+
},
85+
temp: false,
86+
}},
87+
shouldRetry: false,
88+
},
89+
}
90+
91+
for _, tc := range testCases {
92+
t.Run(tc.name, func(t *testing.T) {
93+
result := retryer.ShouldRetry(tc.req)
94+
h.Equals(t, tc.shouldRetry, result)
95+
})
96+
}
97+
}
98+
99+
func getSqsRetryer(t *testing.T) sqsevent.SqsRetryer {
100+
sess, err := session.NewSession(&aws.Config{
101+
Region: aws.String("us-east-1"),
102+
})
103+
h.Ok(t, err)
104+
105+
sqsClient := sqsevent.GetSqsClient(sess)
106+
h.Assert(t, sqsClient.Client.Config.Region != nil, "Region should not be nil")
107+
h.Equals(t, "us-east-1", *sqsClient.Client.Config.Region)
108+
109+
retryer, ok := sqsClient.Client.Config.Retryer.(sqsevent.SqsRetryer)
110+
h.Assert(t, ok, "Retryer should be of type SqsRetryer")
111+
return retryer
112+
}
113+
114+
func (e *temporaryError) Temporary() bool {
115+
return e.temp
116+
}

0 commit comments

Comments
 (0)