Skip to content

Commit a980671

Browse files
authored
Merge pull request #912 from mackerelio/update-aws-sdk-go-v2
replace to aws-sdk-go-v2
2 parents 398f905 + e178383 commit a980671

File tree

6 files changed

+159
-96
lines changed

6 files changed

+159
-96
lines changed

check-aws-cloudwatch-logs/lib/check-aws-cloudwatch-logs.go

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,9 @@ import (
1313
"strings"
1414
"time"
1515

16-
"github.com/aws/aws-sdk-go/aws"
17-
"github.com/aws/aws-sdk-go/aws/session"
18-
"github.com/aws/aws-sdk-go/service/cloudwatchlogs"
19-
"github.com/aws/aws-sdk-go/service/cloudwatchlogs/cloudwatchlogsiface"
16+
"github.com/aws/aws-sdk-go-v2/aws"
17+
"github.com/aws/aws-sdk-go-v2/config"
18+
"github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs"
2019
"github.com/jessevdk/go-flags"
2120

2221
"github.com/mackerelio/checkers"
@@ -50,15 +49,15 @@ func Do() {
5049
}
5150

5251
type awsCloudwatchLogsPlugin struct {
53-
Service cloudwatchlogsiface.CloudWatchLogsAPI
52+
Service cloudwatchlogs.FilterLogEventsAPIClient
5453
StateFile string
5554
*logOpts
5655
}
5756

58-
func newCloudwatchLogsPlugin(opts *logOpts, args []string) (*awsCloudwatchLogsPlugin, error) {
57+
func newCloudwatchLogsPlugin(ctx context.Context, opts *logOpts, args []string) (*awsCloudwatchLogsPlugin, error) {
5958
var err error
6059
p := &awsCloudwatchLogsPlugin{logOpts: opts}
61-
p.Service, err = createService(opts)
60+
p.Service, err = createService(ctx, opts)
6261
if err != nil {
6362
return nil, err
6463
}
@@ -94,20 +93,22 @@ func getStateFile(stateDir, logGroupName, logStreamNamePrefix string, args []str
9493
)
9594
}
9695

97-
func createAWSConfig(opts *logOpts) *aws.Config {
98-
conf := aws.NewConfig()
96+
func createCloudwatchlogsOptions(opts *logOpts) (optFns []func(*cloudwatchlogs.Options)) {
9997
if opts.MaxRetries > 0 {
100-
return conf.WithMaxRetries(opts.MaxRetries)
98+
optFns = append(optFns, func(o *cloudwatchlogs.Options) {
99+
o.RetryMaxAttempts = opts.MaxRetries
100+
})
101101
}
102-
return conf
102+
return
103103
}
104104

105-
func createService(opts *logOpts) (*cloudwatchlogs.CloudWatchLogs, error) {
106-
sess, err := session.NewSession()
105+
func createService(ctx context.Context, opts *logOpts) (*cloudwatchlogs.Client, error) {
106+
cfg, err := config.LoadDefaultConfig(ctx)
107107
if err != nil {
108108
return nil, err
109109
}
110-
return cloudwatchlogs.New(sess, createAWSConfig(opts)), nil
110+
111+
return cloudwatchlogs.NewFromConfig(cfg, createCloudwatchlogsOptions(opts)...), nil
111112
}
112113

113114
type logState struct {
@@ -134,9 +135,12 @@ func (p *awsCloudwatchLogsPlugin) collect(ctx context.Context, now time.Time) ([
134135
if p.LogStreamNamePrefix != "" {
135136
input.LogStreamNamePrefix = aws.String(p.LogStreamNamePrefix)
136137
}
137-
err = p.Service.FilterLogEventsPages(input, func(output *cloudwatchlogs.FilterLogEventsOutput, lastPage bool) bool {
138-
if ctx.Err() != nil {
139-
return false
138+
139+
paginator := cloudwatchlogs.NewFilterLogEventsPaginator(p.Service, input)
140+
for paginator.HasMorePages() {
141+
output, err := paginator.NextPage(ctx)
142+
if err != nil {
143+
return nil, err
140144
}
141145
for _, event := range output.Events {
142146
messages = append(messages, *event.Message)
@@ -145,18 +149,11 @@ func (p *awsCloudwatchLogsPlugin) collect(ctx context.Context, now time.Time) ([
145149
}
146150
}
147151
s.NextToken = output.NextToken
148-
if lastPage {
149-
s.NextToken = nil
150-
}
151-
err = p.saveState(s)
152-
if err != nil {
153-
return false
152+
153+
if err = p.saveState(s); err != nil {
154+
return nil, err
154155
}
155156
time.Sleep(150 * time.Millisecond)
156-
return true
157-
})
158-
if err != nil {
159-
return nil, err
160157
}
161158
return messages, nil
162159
}
@@ -220,7 +217,7 @@ func run(ctx context.Context, args []string) *checkers.Checker {
220217
if err != nil {
221218
os.Exit(1)
222219
}
223-
p, err := newCloudwatchLogsPlugin(opts, args)
220+
p, err := newCloudwatchLogsPlugin(ctx, opts, args)
224221
if err != nil {
225222
return checkers.Unknown(fmt.Sprint(err))
226223
}

check-aws-cloudwatch-logs/lib/check-aws-cloudwatch-logs_test.go

Lines changed: 40 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -7,41 +7,44 @@ import (
77
"context"
88
"encoding/json"
99
"fmt"
10-
"io/ioutil"
1110
"os"
11+
"path/filepath"
12+
"strconv"
1213
"testing"
1314
"time"
1415

1516
"github.com/jessevdk/go-flags"
1617
"github.com/mackerelio/checkers"
1718
"github.com/stretchr/testify/assert"
1819

19-
"github.com/aws/aws-sdk-go/aws"
20-
"github.com/aws/aws-sdk-go/service/cloudwatchlogs"
21-
"github.com/aws/aws-sdk-go/service/cloudwatchlogs/cloudwatchlogsiface"
20+
"github.com/aws/aws-sdk-go-v2/aws"
21+
"github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs"
22+
"github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs/types"
2223
)
2324

2425
type mockAWSCloudWatchLogsClient struct {
25-
cloudwatchlogsiface.CloudWatchLogsAPI
26+
cloudwatchlogs.FilterLogEventsAPIClient
2627
outputs []*cloudwatchlogs.FilterLogEventsOutput
2728
}
2829

29-
func (c *mockAWSCloudWatchLogsClient) FilterLogEventsPages(input *cloudwatchlogs.FilterLogEventsInput, fn func(*cloudwatchlogs.FilterLogEventsOutput, bool) bool) error {
30-
for i, output := range c.outputs {
31-
lastPage := i == len(c.outputs)-1
32-
if !fn(output, lastPage) {
33-
break
34-
}
30+
func (c *mockAWSCloudWatchLogsClient) FilterLogEvents(ctx context.Context, input *cloudwatchlogs.FilterLogEventsInput, _ ...func(*cloudwatchlogs.Options)) (*cloudwatchlogs.FilterLogEventsOutput, error) {
31+
if ctx.Err() != nil {
32+
return nil, ctx.Err()
33+
}
34+
35+
var pageNo = 0
36+
if input.NextToken != nil {
37+
pageNo, _ = strconv.Atoi(*input.NextToken)
3538
}
36-
return nil
39+
return c.outputs[pageNo], nil
3740
}
3841

39-
func createMockService() cloudwatchlogsiface.CloudWatchLogsAPI {
42+
func createMockService() cloudwatchlogs.FilterLogEventsAPIClient {
4043
return &mockAWSCloudWatchLogsClient{
4144
outputs: []*cloudwatchlogs.FilterLogEventsOutput{
4245
{
4346
NextToken: aws.String("1"),
44-
Events: []*cloudwatchlogs.FilteredLogEvent{
47+
Events: []types.FilteredLogEvent{
4548
{
4649
EventId: aws.String("event-id-0"),
4750
Message: aws.String("message-0"),
@@ -56,7 +59,7 @@ func createMockService() cloudwatchlogsiface.CloudWatchLogsAPI {
5659
},
5760
{
5861
NextToken: aws.String("2"),
59-
Events: []*cloudwatchlogs.FilteredLogEvent{
62+
Events: []types.FilteredLogEvent{
6063
{
6164
EventId: aws.String("event-id-2"),
6265
Message: aws.String("message-2"),
@@ -75,7 +78,7 @@ func createMockService() cloudwatchlogsiface.CloudWatchLogsAPI {
7578
},
7679
},
7780
{
78-
Events: []*cloudwatchlogs.FilteredLogEvent{
81+
Events: []types.FilteredLogEvent{
7982
{
8083
EventId: aws.String("event-id-5"),
8184
Message: aws.String("message-5"),
@@ -88,13 +91,11 @@ func createMockService() cloudwatchlogsiface.CloudWatchLogsAPI {
8891
}
8992

9093
func Test_cloudwatchLogsPlugin_collect(t *testing.T) {
91-
file, _ := ioutil.TempFile("", "check-cloudwatch-logs-test-collect")
92-
os.Remove(file.Name())
93-
file.Close()
94-
defer os.Remove(file.Name())
94+
stateFile := filepath.Join(t.TempDir(), "check-cloudwatch-logs-test-collect")
95+
9596
p := &awsCloudwatchLogsPlugin{
9697
Service: createMockService(),
97-
StateFile: file.Name(),
98+
StateFile: stateFile,
9899
logOpts: &logOpts{
99100
LogGroupName: "test-group",
100101
},
@@ -104,7 +105,7 @@ func Test_cloudwatchLogsPlugin_collect(t *testing.T) {
104105
messages, err := p.collect(context.Background(), time.Unix(0, 0))
105106
assert.Equal(t, err, nil, "err should be nil")
106107
assert.Equal(t, len(messages), 6)
107-
cnt, _ := ioutil.ReadFile(file.Name())
108+
cnt, _ := os.ReadFile(stateFile)
108109
var s logState
109110
json.NewDecoder(bytes.NewReader(cnt)).Decode(&s)
110111
assert.Equal(t, s, logState{StartTime: aws.Int64(5 + 1)})
@@ -115,7 +116,7 @@ func Test_cloudwatchLogsPlugin_collect(t *testing.T) {
115116
cancel()
116117

117118
messages, err := p.collect(ctx, time.Unix(0, 0))
118-
assert.Equal(t, err, nil, "err should be nil")
119+
assert.NotEqual(t, err, nil, "err should be someting")
119120
assert.Equal(t, len(messages), 0)
120121
})
121122
}
@@ -221,25 +222,32 @@ func Test_cloudwatchLogsPlugin_options(t *testing.T) {
221222
}
222223
}
223224

224-
func Test_createAWSConfig(t *testing.T) {
225+
func Test_createCloudwatchlogsOptions(t *testing.T) {
225226
tests := []struct {
226-
opts *logOpts
227-
want *aws.Config
227+
opts *logOpts
228+
want cloudwatchlogs.Options
229+
length int
228230
}{
229231
{
230-
opts: &logOpts{MaxRetries: 0},
231-
want: aws.NewConfig(),
232+
opts: &logOpts{MaxRetries: 0},
233+
length: 0,
232234
},
233235
{
234-
opts: &logOpts{MaxRetries: 1},
235-
want: aws.NewConfig().WithMaxRetries(1),
236+
opts: &logOpts{MaxRetries: 1},
237+
want: cloudwatchlogs.Options{RetryMaxAttempts: 1},
238+
length: 1,
236239
},
237240
}
238241

239242
for i, tt := range tests {
240243
t.Run(fmt.Sprintf("case:%d", i), func(t *testing.T) {
241-
res := createAWSConfig(tt.opts)
242-
assert.Equal(t, tt.want, res)
244+
res := createCloudwatchlogsOptions(tt.opts)
245+
assert.Equal(t, len(res), tt.length)
246+
if tt.length > 0 {
247+
opts := cloudwatchlogs.Options{}
248+
res[0](&opts)
249+
assert.Equal(t, tt.want, opts)
250+
}
243251
})
244252
}
245253
}

check-aws-sqs-queue-size/lib/check-aws-sqs-queue-size.go

Lines changed: 37 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,31 @@
11
package checkawssqsqueuesize
22

33
import (
4+
"context"
45
"errors"
56
"fmt"
67
"os"
8+
"os/signal"
79
"strconv"
810

9-
"github.com/aws/aws-sdk-go/aws"
10-
"github.com/aws/aws-sdk-go/aws/credentials"
11-
"github.com/aws/aws-sdk-go/aws/session"
12-
"github.com/aws/aws-sdk-go/service/sqs"
11+
"github.com/aws/aws-sdk-go-v2/aws"
12+
"github.com/aws/aws-sdk-go-v2/config"
13+
"github.com/aws/aws-sdk-go-v2/credentials"
14+
"github.com/aws/aws-sdk-go-v2/service/sqs"
15+
"github.com/aws/aws-sdk-go-v2/service/sqs/types"
1316
"github.com/jessevdk/go-flags"
1417
"github.com/mackerelio/checkers"
1518
)
1619

20+
// overwritten with syscall.SIGTERM on unix environment (see check-log_unix.go)
21+
var defaultSignal = os.Interrupt
22+
1723
// Do the plugin
1824
func Do() {
19-
ckr := run(os.Args[1:])
25+
ctx, stop := signal.NotifyContext(context.Background(), defaultSignal)
26+
defer stop()
27+
28+
ckr := run(ctx, os.Args[1:])
2029
ckr.Name = "SQSQueueSize"
2130
ckr.Exit()
2231
}
@@ -32,65 +41,68 @@ var opts struct {
3241

3342
const sqsAttributeOfQueueSize = "ApproximateNumberOfMessages"
3443

35-
func createService(region, awsAccessKeyID, awsSecretAccessKey string) (*sqs.SQS, error) {
36-
sess, err := session.NewSession()
37-
if err != nil {
38-
return nil, err
39-
}
40-
41-
config := aws.NewConfig()
44+
func createService(ctx context.Context, region, awsAccessKeyID, awsSecretAccessKey string) (*sqs.Client, error) {
45+
var opts []func(*config.LoadOptions) error
4246
if awsAccessKeyID != "" && awsSecretAccessKey != "" {
43-
config = config.WithCredentials(credentials.NewStaticCredentials(awsAccessKeyID, awsSecretAccessKey, ""))
47+
opts = append(opts, config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(awsAccessKeyID, awsSecretAccessKey, "")))
4448
}
4549
if region != "" {
46-
config = config.WithRegion(region)
50+
opts = append(opts, config.WithRegion(region))
4751
}
48-
return sqs.New(sess, config), nil
52+
53+
cfg, err := config.LoadDefaultConfig(ctx, opts...)
54+
if err != nil {
55+
return nil, err
56+
}
57+
58+
return sqs.NewFromConfig(cfg), nil
4959
}
5060

51-
func getSqsQueueSize(region, awsAccessKeyID, awsSecretAccessKey, queueName string) (int, error) {
52-
sqsClient, err := createService(region, awsAccessKeyID, awsSecretAccessKey)
61+
func getSqsQueueSize(ctx context.Context, region, awsAccessKeyID, awsSecretAccessKey, queueName string) (int, error) {
62+
sqsClient, err := createService(ctx, region, awsAccessKeyID, awsSecretAccessKey)
5363
if err != nil {
5464
return -1, err
5565
}
5666

5767
// Get queue url
58-
q, err := sqsClient.GetQueueUrl(&sqs.GetQueueUrlInput{
68+
q, err := sqsClient.GetQueueUrl(ctx, &sqs.GetQueueUrlInput{
5969
QueueName: aws.String(queueName),
6070
})
6171
if err != nil {
6272
return -1, err
6373
}
6474

6575
// Get queue attribute
66-
attr, err := sqsClient.GetQueueAttributes(&sqs.GetQueueAttributesInput{
67-
QueueUrl: q.QueueUrl,
68-
AttributeNames: []*string{aws.String(sqsAttributeOfQueueSize)},
76+
attr, err := sqsClient.GetQueueAttributes(ctx, &sqs.GetQueueAttributesInput{
77+
QueueUrl: q.QueueUrl,
78+
AttributeNames: []types.QueueAttributeName{
79+
types.QueueAttributeNameApproximateNumberOfMessages,
80+
},
6981
})
7082
if err != nil {
7183
return -1, err
7284
}
7385

7486
// Queue size
7587
sizeStr, ok := attr.Attributes[sqsAttributeOfQueueSize]
76-
if !ok || sizeStr == nil {
88+
if !ok {
7789
return -1, errors.New("attribute not found")
7890
}
79-
size, err := strconv.Atoi(*sizeStr)
91+
size, err := strconv.Atoi(sizeStr)
8092
if err != nil {
8193
return -1, err
8294
}
8395

8496
return size, nil
8597
}
8698

87-
func run(args []string) *checkers.Checker {
99+
func run(ctx context.Context, args []string) *checkers.Checker {
88100
_, err := flags.ParseArgs(&opts, args)
89101
if err != nil {
90102
os.Exit(1)
91103
}
92104

93-
size, err := getSqsQueueSize(opts.Region, opts.AccessKeyID, opts.SecretAccessKey, opts.QueueName)
105+
size, err := getSqsQueueSize(ctx, opts.Region, opts.AccessKeyID, opts.SecretAccessKey, opts.QueueName)
94106
if err != nil {
95107
return checkers.NewChecker(checkers.UNKNOWN, err.Error())
96108
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package checkawssqsqueuesize
2+
3+
import (
4+
"syscall"
5+
)
6+
7+
func init() {
8+
defaultSignal = syscall.SIGTERM
9+
}

0 commit comments

Comments
 (0)