Skip to content

Commit c73e359

Browse files
authored
acquisition: set S3 anon creds during tests, prevent deadlock (#3961)
1 parent 8cac689 commit c73e359

File tree

7 files changed

+108
-65
lines changed

7 files changed

+108
-65
lines changed

.golangci.yml

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,12 @@ linters:
197197
block-size: 6
198198

199199
nolintlint:
200-
require-explanation: false # don't require an explanation for nolint directives
201-
require-specific: true # don't require nolint directives to be specific about which linter is being skipped
202-
allow-unused: false # report any unused nolint directives
200+
# don't require an explanation for nolint directives
201+
require-explanation: false
202+
# don't require nolint directives to be specific about which linter is being skipped
203+
require-specific: true
204+
# report any unused nolint directives
205+
allow-unused: false
203206

204207
revive:
205208
severity: error
@@ -443,12 +446,6 @@ linters:
443446
path: pkg/types/utils.go
444447
text: 'argument-limit: .*'
445448

446-
# need some cleanup first: to create db in memory and share the client, not the config
447-
- linters:
448-
- usetesting
449-
path: (.+)_test.go
450-
text: context.Background.*
451-
452449
- linters:
453450
- usetesting
454451
path: pkg/apiserver/(.+)_test.go
@@ -469,18 +466,27 @@ linters:
469466
path: pkg/acquisition/modules/s3/s3.go
470467
text: found a struct that contains a context.Context field
471468

469+
- linters:
470+
- contextcheck
471+
text: Function `Configure->newS3Client` should pass the context parameter
472+
473+
- linters:
474+
- contextcheck
475+
text: Function `ConfigureByDSN->newS3Client` should pass the context parameter
476+
472477
# migrate over time
473478

474479
- linters:
475480
- noctx
481+
path: pkg/parser/enrich_dns.go
476482
text: "net.LookupAddr must not be called"
477483

478484
- linters:
479485
- noctx
486+
path: pkg/exprhelpers/helpers.go
480487
text: "net.LookupHost must not be called"
481488

482489
paths:
483-
- pkg/metabase
484490
- third_party$
485491
- builtin$
486492
- examples$
Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
//go:build !test
2-
31
package kinesisacquisition
42

53
import "github.com/aws/aws-sdk-go-v2/aws"
64

5+
var defaultCredsFunc = func() aws.CredentialsProvider {
6+
return nil
7+
}
8+
79
func defaultCreds() aws.CredentialsProvider {
8-
return nil
10+
return defaultCredsFunc()
911
}
Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
//go:build test
2-
31
package kinesisacquisition
42

53
import "github.com/aws/aws-sdk-go-v2/aws"
64

7-
func defaultCreds() aws.CredentialsProvider {
8-
return aws.AnonymousCredentials{}
5+
//nolint:gochecknoinits
6+
func init() {
7+
defaultCredsFunc = func() aws.CredentialsProvider {
8+
return aws.AnonymousCredentials{}
9+
}
910
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package s3acquisition
2+
3+
import "github.com/aws/aws-sdk-go-v2/aws"
4+
5+
var defaultCredsFunc = func() aws.CredentialsProvider {
6+
return nil
7+
}
8+
9+
func defaultCreds() aws.CredentialsProvider {
10+
return defaultCredsFunc()
11+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package s3acquisition
2+
3+
import "github.com/aws/aws-sdk-go-v2/aws"
4+
5+
//nolint:gochecknoinits
6+
func init() {
7+
defaultCredsFunc = func() aws.CredentialsProvider {
8+
return aws.AnonymousCredentials{}
9+
}
10+
}

pkg/acquisition/modules/s3/s3.go

Lines changed: 53 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,7 @@ const (
118118
SQSFormatSNS = "sns"
119119
)
120120

121-
func (s *S3Source) newS3Client() error {
122-
if s.s3Client != nil {
123-
return nil
124-
}
125-
121+
func (s *S3Source) newS3Client() (*s3.Client, error) {
126122
var loadOpts []func(*config.LoadOptions) error
127123
if s.Config.AwsProfile != nil && *s.Config.AwsProfile != "" {
128124
loadOpts = append(loadOpts, config.WithSharedConfigProfile(*s.Config.AwsProfile))
@@ -134,11 +130,14 @@ func (s *S3Source) newS3Client() error {
134130
}
135131

136132
loadOpts = append(loadOpts, config.WithRegion(region))
137-
loadOpts = append(loadOpts, config.WithCredentialsProvider(aws.AnonymousCredentials{}))
138133

139-
cfg, err := config.LoadDefaultConfig(s.ctx, loadOpts...)
134+
if c := defaultCreds(); c != nil {
135+
loadOpts = append(loadOpts, config.WithCredentialsProvider(c))
136+
}
137+
138+
cfg, err := config.LoadDefaultConfig(context.TODO(), loadOpts...)
140139
if err != nil {
141-
return fmt.Errorf("failed to load aws config: %w", err)
140+
return nil, fmt.Errorf("failed to load aws config: %w", err)
142141
}
143142

144143
var clientOpts []func(*s3.Options)
@@ -148,16 +147,10 @@ func (s *S3Source) newS3Client() error {
148147
})
149148
}
150149

151-
s.s3Client = s3.NewFromConfig(cfg, clientOpts...)
152-
153-
return nil
150+
return s3.NewFromConfig(cfg, clientOpts...), nil
154151
}
155152

156-
func (s *S3Source) newSQSClient() error {
157-
if s.sqsClient != nil {
158-
return nil
159-
}
160-
153+
func (s *S3Source) newSQSClient() (*sqs.Client, error) {
161154
var loadOpts []func(*config.LoadOptions) error
162155
if s.Config.AwsProfile != nil && *s.Config.AwsProfile != "" {
163156
loadOpts = append(loadOpts, config.WithSharedConfigProfile(*s.Config.AwsProfile))
@@ -169,21 +162,22 @@ func (s *S3Source) newSQSClient() error {
169162
}
170163

171164
loadOpts = append(loadOpts, config.WithRegion(region))
172-
loadOpts = append(loadOpts, config.WithCredentialsProvider(aws.AnonymousCredentials{}))
173165

174-
cfg, err := config.LoadDefaultConfig(s.ctx, loadOpts...)
166+
if c := defaultCreds(); c != nil {
167+
loadOpts = append(loadOpts, config.WithCredentialsProvider(c))
168+
}
169+
170+
cfg, err := config.LoadDefaultConfig(context.TODO(), loadOpts...)
175171
if err != nil {
176-
return fmt.Errorf("failed to load aws config: %w", err)
172+
return nil, fmt.Errorf("failed to load aws config: %w", err)
177173
}
178174

179175
var clientOpts []func(*sqs.Options)
180176
if s.Config.AwsEndpoint != "" {
181177
clientOpts = append(clientOpts, func(o *sqs.Options) { o.BaseEndpoint = aws.String(s.Config.AwsEndpoint) })
182178
}
183179

184-
s.sqsClient = sqs.NewFromConfig(cfg, clientOpts...)
185-
186-
return nil
180+
return sqs.NewFromConfig(cfg, clientOpts...), nil
187181
}
188182

189183
func (s *S3Source) readManager() {
@@ -207,7 +201,7 @@ func (s *S3Source) readManager() {
207201

208202
func (s *S3Source) getBucketContent() ([]s3types.Object, error) {
209203
logger := s.logger.WithField("method", "getBucketContent")
210-
logger.Debugf("Getting bucket content for %s", s.Config.BucketName)
204+
logger.Debugf("Getting bucket content")
211205

212206
bucketObjects := make([]s3types.Object, 0)
213207

@@ -274,10 +268,17 @@ func (s *S3Source) listPoll() error {
274268

275269
logger.Debugf("Found new object %s", *bucketObjects[i].Key)
276270

277-
s.readerChan <- S3Object{
271+
obj := S3Object{
278272
Bucket: s.Config.BucketName,
279273
Key: *bucketObjects[i].Key,
280274
}
275+
276+
select {
277+
case s.readerChan <- obj:
278+
case <-s.t.Dying():
279+
logger.Debug("tomb is dying, dropping object send")
280+
return nil
281+
}
281282
}
282283

283284
if newObject {
@@ -391,6 +392,9 @@ func (s *S3Source) sqsPoll() error {
391392
WaitTimeSeconds: 20, // Probably no need to make it configurable ?
392393
})
393394
if err != nil {
395+
if errors.Is(err, context.Canceled) {
396+
return nil
397+
}
394398
logger.Errorf("Error while polling SQS: %s", err)
395399
continue
396400
}
@@ -421,7 +425,13 @@ func (s *S3Source) sqsPoll() error {
421425

422426
logger.Debugf("Received SQS message for object %s/%s", bucket, key)
423427

424-
s.readerChan <- S3Object{Key: key, Bucket: bucket}
428+
// don't block if readManager has quit
429+
select {
430+
case s.readerChan <- S3Object{Key: key, Bucket: bucket}:
431+
case <-s.t.Dying():
432+
logger.Debug("tomb is dying, dropping object send")
433+
return nil
434+
}
425435

426436
_, err = s.sqsClient.DeleteMessage(s.ctx,
427437
&sqs.DeleteMessageInput{
@@ -516,7 +526,13 @@ func (s *S3Source) readFile(bucket string, key string) error {
516526
evt := types.MakeEvent(s.Config.UseTimeMachine, types.LOG, true)
517527
evt.Line = l
518528

519-
s.out <- evt
529+
// don't block in shutdown
530+
select {
531+
case s.out <-evt:
532+
case <-s.t.Dying():
533+
s.logger.Infof("tomb is dying, dropping event for %s/%s", bucket, key)
534+
return nil
535+
}
520536
}
521537
}
522538

@@ -615,16 +631,20 @@ func (s *S3Source) Configure(yamlConfig []byte, logger *log.Entry, metricsLevel
615631
s.logger.Warning("Polling method is set to list. This is not recommended as it will not scale well. Consider using SQS instead.")
616632
}
617633

618-
err = s.newS3Client()
634+
client, err := s.newS3Client()
619635
if err != nil {
620636
return err
621637
}
622638

639+
s.s3Client = client
640+
623641
if s.Config.PollingMethod == PollMethodSQS {
624-
err = s.newSQSClient()
642+
sqsClient, err := s.newSQSClient()
625643
if err != nil {
626644
return err
627645
}
646+
647+
s.sqsClient = sqsClient
628648
}
629649

630650
return nil
@@ -707,11 +727,13 @@ func (s *S3Source) ConfigureByDSN(dsn string, labels map[string]string, logger *
707727
return fmt.Errorf("invalid DSN %s for S3 source", dsn)
708728
}
709729

710-
err := s.newS3Client()
730+
client, err := s.newS3Client()
711731
if err != nil {
712732
return err
713733
}
714734

735+
s.s3Client = client
736+
715737
return nil
716738
}
717739

@@ -768,21 +790,11 @@ func (s *S3Source) StreamingAcquisition(ctx context.Context, out chan types.Even
768790

769791
if s.Config.PollingMethod == PollMethodSQS {
770792
t.Go(func() error {
771-
err := s.sqsPoll()
772-
if err != nil {
773-
return err
774-
}
775-
776-
return nil
793+
return s.sqsPoll()
777794
})
778795
} else {
779796
t.Go(func() error {
780-
err := s.listPoll()
781-
if err != nil {
782-
return err
783-
}
784-
785-
return nil
797+
return s.listPoll()
786798
})
787799
}
788800

pkg/acquisition/modules/s3/s3_test.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -268,9 +268,11 @@ func TestDSNAcquis(t *testing.T) {
268268
linesRead := 0
269269
f := S3Source{}
270270
logger := log.NewEntry(log.New())
271-
f.s3Client = mockS3Client{}
272271
err := f.ConfigureByDSN(test.dsn, map[string]string{"foo": "bar"}, logger, "")
273272
require.NoError(t, err)
273+
274+
f.s3Client = mockS3Client{}
275+
274276
assert.Equal(t, test.expectedBucketName, f.Config.BucketName)
275277
assert.Equal(t, test.expectedPrefix, f.Config.Prefix)
276278

@@ -338,12 +340,10 @@ prefix: foo/
338340
logger := log.NewEntry(log.New())
339341
logger.Logger.SetLevel(log.TraceLevel)
340342

341-
f.s3Client = mockS3Client{}
342-
343343
err := f.Configure([]byte(test.config), logger, metrics.AcquisitionMetricsLevelNone)
344-
if err != nil {
345-
t.Fatalf("unexpected error: %s", err.Error())
346-
}
344+
require.NoError(t, err)
345+
346+
f.s3Client = mockS3Client{}
347347

348348
if f.Config.PollingMethod != PollMethodList {
349349
t.Fatalf("expected list polling, got %s", f.Config.PollingMethod)
@@ -423,10 +423,11 @@ sqs_name: test
423423
linesRead := 0
424424
f := S3Source{}
425425
logger := log.NewEntry(log.New())
426-
f.s3Client = mockS3Client{}
427426
err := f.Configure([]byte(test.config), logger, metrics.AcquisitionMetricsLevelNone)
428427
require.NoError(t, err)
429428

429+
f.s3Client = mockS3Client{}
430+
430431
if f.Config.PollingMethod != PollMethodSQS {
431432
t.Fatalf("expected sqs polling, got %s", f.Config.PollingMethod)
432433
}

0 commit comments

Comments
 (0)