Skip to content

Commit 2a88e41

Browse files
committed
feat(nrawssdk): Convert AccessKeyId to AccountID
1 parent 14bcbb4 commit 2a88e41

File tree

5 files changed

+188
-8
lines changed

5 files changed

+188
-8
lines changed

v3/integrations/nrawssdk-v2/example/main.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,17 @@ func main() {
3939
txn := app.StartTransaction("My sample transaction")
4040

4141
ctx := context.Background()
42+
4243
awsConfig, err := config.LoadDefaultConfig(ctx, func(awsConfig *config.LoadOptions) error {
4344
// Instrument all new AWS clients with New Relic
44-
nrawssdk.AppendMiddlewares(&awsConfig.APIOptions, nil)
45+
4546
return nil
4647
})
48+
creds, err := awsConfig.Credentials.Retrieve(ctx)
49+
if err != nil {
50+
log.Println("Warning couldn't get flags")
51+
}
52+
nrawssdk.AppendMiddlewares(&awsConfig.APIOptions, nil, creds)
4753
if err != nil {
4854
log.Fatal(err)
4955
}

v3/integrations/nrawssdk-v2/nrawssdk.go

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ package nrawssdk
2828

2929
import (
3030
"context"
31+
"fmt"
3132
"net/url"
3233
"strconv"
3334
"strings"
@@ -39,12 +40,14 @@ import (
3940
"github.com/aws/smithy-go/middleware"
4041
smithymiddle "github.com/aws/smithy-go/middleware"
4142
smithyhttp "github.com/aws/smithy-go/transport/http"
43+
"github.com/newrelic/go-agent/v3/internal/awssupport"
4244
"github.com/newrelic/go-agent/v3/newrelic"
4345
"github.com/newrelic/go-agent/v3/newrelic/integrationsupport"
4446
)
4547

4648
type nrMiddleware struct {
47-
txn *newrelic.Transaction
49+
txn *newrelic.Transaction
50+
creds aws.Credentials
4851
}
4952

5053
type contextKey string
@@ -69,13 +72,19 @@ func (m nrMiddleware) deserializeMiddleware(stack *smithymiddle.Stack) error {
6972
}
7073

7174
smithyRequest := in.Request.(*smithyhttp.Request)
72-
7375
// The actual http.Request is inside the smithyhttp.Request
7476
httpRequest := smithyRequest.Request
7577
serviceName := awsmiddle.GetServiceID(ctx)
7678
operation := awsmiddle.GetOperationName(ctx)
7779
region := awsmiddle.GetRegion(ctx)
7880

81+
creds := awsmiddle.GetSigningCredentials(ctx)
82+
accountID, err := awssupport.AWSAccountIdFromAWSAccessKey(creds)
83+
if err != nil {
84+
accountID = ""
85+
fmt.Println(err.Error())
86+
}
87+
7988
var segment endable
8089

8190
if serviceName == "dynamodb" || serviceName == "DynamoDB" {
@@ -129,6 +138,7 @@ func (m nrMiddleware) deserializeMiddleware(stack *smithymiddle.Stack) error {
129138
integrationsupport.AddAgentSpanAttribute(txn, newrelic.AttributeAWSElastSearchDomainEndpoint, httpRequest.URL.String()) // this way I don't have to pull it out of context
130139
}
131140
// Set additional span attributes
141+
integrationsupport.AddAgentSpanAttribute(txn, newrelic.AttributeCloudAccountID, accountID) // setting account ID here, why do we only do this if it is an SQS service?
132142
integrationsupport.AddAgentSpanAttribute(txn,
133143
newrelic.AttributeResponseCode, strconv.Itoa(response.StatusCode))
134144
integrationsupport.AddAgentSpanAttribute(txn,
@@ -150,8 +160,8 @@ func (m nrMiddleware) serializeMiddleware(stack *middleware.Stack) error {
150160
return stack.Initialize.Add(middleware.InitializeMiddlewareFunc("NRSerializeMiddleware", func(
151161
ctx context.Context, in middleware.InitializeInput, next middleware.InitializeHandler) (
152162
out middleware.InitializeOutput, metadata middleware.Metadata, err error) {
153-
154163
serviceName := awsmiddle.GetServiceID(ctx)
164+
ctx = awsmiddle.SetSigningCredentials(ctx, m.creds)
155165
switch serviceName {
156166
case "dynamodb", "DynamoDB":
157167
ctx = context.WithValue(ctx, dynamodbInputKey, dynamoDBInputFromMiddlewareInput(in))
@@ -219,8 +229,8 @@ func (m nrMiddleware) serializeMiddleware(stack *middleware.Stack) error {
219229
// if err != nil {
220230
// log.Fatal(err)
221231
// }
222-
func AppendMiddlewares(apiOptions *[]func(*smithymiddle.Stack) error, txn *newrelic.Transaction) {
223-
m := nrMiddleware{txn: txn}
232+
func AppendMiddlewares(apiOptions *[]func(*smithymiddle.Stack) error, txn *newrelic.Transaction, creds aws.Credentials) {
233+
m := nrMiddleware{txn: txn, creds: creds}
224234
*apiOptions = append(*apiOptions, m.deserializeMiddleware)
225235
*apiOptions = append(*apiOptions, m.serializeMiddleware)
226236
}

v3/integrations/nrawssdk-v2/nrawssdk_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ var fakeCreds = func() interface{} {
6464

6565
func newConfig(ctx context.Context, txn *newrelic.Transaction) aws.Config {
6666
cfg, _ := config.LoadDefaultConfig(ctx, func(o *config.LoadOptions) error {
67-
AppendMiddlewares(&o.APIOptions, txn)
67+
AppendMiddlewares(&o.APIOptions, txn, aws.Credentials{})
6868
return nil
6969
})
7070
cfg.Credentials = fakeCreds.(aws.CredentialsProvider)

v3/internal/awssupport/awssupport.go

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,14 @@ package awssupport
88

99
import (
1010
"context"
11-
"github.com/newrelic/go-agent/v3/newrelic/integrationsupport"
11+
"encoding/base32"
12+
"fmt"
1213
"net/http"
1314
"reflect"
1415

16+
"github.com/aws/aws-sdk-go-v2/aws"
17+
"github.com/newrelic/go-agent/v3/newrelic/integrationsupport"
18+
1519
newrelic "github.com/newrelic/go-agent/v3/newrelic"
1620
)
1721

@@ -114,3 +118,31 @@ func EndSegment(ctx context.Context, resp *http.Response) {
114118
segment.End()
115119
}
116120
}
121+
122+
func AWSAccountIdFromAWSAccessKey(creds aws.Credentials) (string, error) {
123+
if creds.AccountID != "" {
124+
return creds.AccountID, nil
125+
}
126+
if creds.AccessKeyID == "" {
127+
return "", fmt.Errorf("no access key id found")
128+
}
129+
if len(creds.AccessKeyID) < 16 {
130+
return "", fmt.Errorf("improper access key id format")
131+
}
132+
trimmedAccessKey := creds.AccessKeyID[4:]
133+
decoded, err := base32.StdEncoding.DecodeString(trimmedAccessKey)
134+
if err != nil {
135+
return "", fmt.Errorf("error decoding access keys")
136+
}
137+
var bigEndian uint64
138+
for i := 0; i < 6; i++ {
139+
bigEndian = bigEndian << 8 // shift 8 bits left. Most significant byte read in first (decoded[i])
140+
bigEndian |= uint64(decoded[i]) // apply OR for current byte
141+
}
142+
143+
mask := uint64(0x7fffffffff80)
144+
145+
num := (bigEndian & mask) >> 7 // apply mask and get rid of last 7 bytes from mask
146+
147+
return fmt.Sprintf("%d", num), nil
148+
}

v3/internal/awssupport/awssupport_test.go

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import (
1010
"net/http"
1111
"strings"
1212
"testing"
13+
14+
"github.com/aws/aws-sdk-go-v2/aws"
1315
)
1416

1517
func TestGetTableName(t *testing.T) {
@@ -84,3 +86,133 @@ func TestGetRequestID(t *testing.T) {
8486
}
8587
}
8688
}
89+
90+
func TestAWSAccountIdFromAWSAccessKey(t *testing.T) {
91+
tests := []struct {
92+
name string // description of this test case
93+
// Named input parameters for target function.
94+
creds aws.Credentials
95+
want string
96+
wantErr bool
97+
wantErrStr string // error message returned
98+
}{
99+
{
100+
name: "first test",
101+
creds: aws.Credentials{
102+
AccountID: "",
103+
AccessKeyID: "AKIASAWSR23456AWS357",
104+
},
105+
want: "138954266361",
106+
wantErr: false,
107+
},
108+
{
109+
name: "AccountID already exists and access key exists. Should return AccountID immediately",
110+
creds: aws.Credentials{
111+
AccountID: "123451234512",
112+
AccessKeyID: "ASKDHA123457AKJFHAKS",
113+
},
114+
want: "123451234512",
115+
wantErr: false,
116+
},
117+
{
118+
name: "AccountID already exists and access key exists with too short of length. Should return AccountID immediately",
119+
creds: aws.Credentials{
120+
AccountID: "123451234512",
121+
AccessKeyID: "a",
122+
},
123+
want: "123451234512",
124+
wantErr: false,
125+
},
126+
{
127+
name: "AccountID already exists and access key exists with improper format. Should return AccountID immediately",
128+
creds: aws.Credentials{
129+
AccountID: "123451234512",
130+
AccessKeyID: "a a a. ",
131+
},
132+
want: "123451234512",
133+
wantErr: false,
134+
},
135+
{
136+
name: "AccountID already exists and access key does not exist. Should return AccountID immediately",
137+
creds: aws.Credentials{
138+
AccountID: "123451234512",
139+
},
140+
want: "123451234512",
141+
wantErr: false,
142+
},
143+
{
144+
name: "AccountID does not exist and access key does not exist. Should return an error",
145+
creds: aws.Credentials{},
146+
want: "",
147+
wantErr: true,
148+
wantErrStr: "no access key id found",
149+
},
150+
{
151+
name: "AccountID does not exist and access key is in an improper format. Should return an error",
152+
creds: aws.Credentials{
153+
AccessKeyID: "123asdfas",
154+
},
155+
want: "",
156+
wantErr: true,
157+
wantErrStr: "improper access key id format",
158+
},
159+
{
160+
name: "AccountID does not exist and access key is in an improper format with only one character. Should return an error",
161+
creds: aws.Credentials{
162+
AccessKeyID: "a",
163+
},
164+
want: "",
165+
wantErr: true,
166+
wantErrStr: "improper access key id format",
167+
},
168+
{
169+
name: "AccountID does not exist and access key is in an improper format for decoding.",
170+
creds: aws.Credentials{
171+
AccessKeyID: "a a a. ",
172+
},
173+
want: "",
174+
wantErr: true,
175+
wantErrStr: "error decoding access keys",
176+
},
177+
{
178+
name: "AccountID does not exist and access key contains non base32 characters",
179+
creds: aws.Credentials{
180+
AccessKeyID: "AKIA1234567899876541",
181+
},
182+
want: "",
183+
wantErr: true,
184+
wantErrStr: "error decoding access keys",
185+
},
186+
{
187+
name: "AccountID does not exist and access key contains non base32 characters and is too short in length",
188+
creds: aws.Credentials{
189+
AccessKeyID: "AKIA1818",
190+
},
191+
want: "",
192+
wantErr: true,
193+
wantErrStr: "improper access key id format",
194+
},
195+
}
196+
for _, tt := range tests {
197+
t.Run(tt.name, func(t *testing.T) {
198+
got, gotErr := AWSAccountIdFromAWSAccessKey(tt.creds)
199+
if gotErr != nil {
200+
if !tt.wantErr {
201+
t.Errorf("AWSAccountIdFromAWSAccessKey() failed: %v", gotErr)
202+
} else {
203+
if tt.wantErrStr != gotErr.Error() {
204+
t.Errorf("AWSAccountIdFromAWSAccessKey() error = %v, want %v", gotErr.Error(), tt.wantErrStr)
205+
}
206+
}
207+
return
208+
}
209+
if tt.wantErr {
210+
t.Fatal("AWSAccountIdFromAWSAccessKey() succeeded unexpectedly")
211+
}
212+
// TODO: update the condition below to compare got with tt.want.
213+
if tt.want != got {
214+
t.Errorf("AWSAccountIdFromAWSAccessKey() = %v, want %v", got, tt.want)
215+
}
216+
})
217+
}
218+
}

0 commit comments

Comments
 (0)