@@ -28,6 +28,7 @@ package nrawssdk
2828
2929import (
3030 "context"
31+ "encoding/base32"
3132 "fmt"
3233 "net/url"
3334 "strconv"
@@ -40,16 +41,22 @@ import (
4041 "github.com/aws/smithy-go/middleware"
4142 smithymiddle "github.com/aws/smithy-go/middleware"
4243 smithyhttp "github.com/aws/smithy-go/transport/http"
43- "github.com/newrelic/go-agent/v3/internal/awssupport"
4444 "github.com/newrelic/go-agent/v3/newrelic"
4545 "github.com/newrelic/go-agent/v3/newrelic/integrationsupport"
4646)
4747
48+ type credentialsResolver interface {
49+ AWSAccountIdFromAWSAccessKey (creds aws.Credentials ) (string , error )
50+ }
51+
4852type nrMiddleware struct {
49- txn * newrelic.Transaction
50- creds aws.Credentials
53+ txn * newrelic.Transaction
54+ accountID string
55+ resolver credentialsResolver
5156}
5257
58+ type defaultResolver struct {}
59+
5360type contextKey string
5461
5562const (
@@ -77,14 +84,7 @@ func (m nrMiddleware) deserializeMiddleware(stack *smithymiddle.Stack) error {
7784 serviceName := awsmiddle .GetServiceID (ctx )
7885 operation := awsmiddle .GetOperationName (ctx )
7986 region := awsmiddle .GetRegion (ctx )
80-
81- creds := awsmiddle .GetSigningCredentials (ctx )
82- accountID , err := awssupport .AWSAccountIdFromAWSAccessKey (creds )
83- if err != nil {
84- accountID = ""
85- fmt .Println (err .Error ())
86- }
87-
87+ accountID := m .accountID
8888 var segment endable
8989
9090 if serviceName == "dynamodb" || serviceName == "DynamoDB" {
@@ -138,7 +138,9 @@ func (m nrMiddleware) deserializeMiddleware(stack *smithymiddle.Stack) error {
138138 integrationsupport .AddAgentSpanAttribute (txn , newrelic .AttributeAWSElastSearchDomainEndpoint , httpRequest .URL .String ()) // this way I don't have to pull it out of context
139139 }
140140 // Set additional span attributes
141+
141142 integrationsupport .AddAgentSpanAttribute (txn , newrelic .AttributeCloudAccountID , accountID ) // setting account ID here, why do we only do this if it is an SQS service?
143+
142144 integrationsupport .AddAgentSpanAttribute (txn ,
143145 newrelic .AttributeResponseCode , strconv .Itoa (response .StatusCode ))
144146 integrationsupport .AddAgentSpanAttribute (txn ,
@@ -161,7 +163,6 @@ func (m nrMiddleware) serializeMiddleware(stack *middleware.Stack) error {
161163 ctx context.Context , in middleware.InitializeInput , next middleware.InitializeHandler ) (
162164 out middleware.InitializeOutput , metadata middleware.Metadata , err error ) {
163165 serviceName := awsmiddle .GetServiceID (ctx )
164- ctx = awsmiddle .SetSigningCredentials (ctx , m .creds )
165166 switch serviceName {
166167 case "dynamodb" , "DynamoDB" :
167168 ctx = context .WithValue (ctx , dynamodbInputKey , dynamoDBInputFromMiddlewareInput (in ))
@@ -229,8 +230,30 @@ func (m nrMiddleware) serializeMiddleware(stack *middleware.Stack) error {
229230// if err != nil {
230231// log.Fatal(err)
231232// }
232- func AppendMiddlewares (apiOptions * []func (* smithymiddle.Stack ) error , txn * newrelic.Transaction , creds aws.Credentials ) {
233- m := nrMiddleware {txn : txn , creds : creds }
233+ func AppendMiddlewares (apiOptions * []func (* smithymiddle.Stack ) error , txn * newrelic.Transaction ) {
234+ m := nrMiddleware {txn : txn }
235+ * apiOptions = append (* apiOptions , m .deserializeMiddleware )
236+ * apiOptions = append (* apiOptions , m .serializeMiddleware )
237+ }
238+
239+ func NRAppendMiddlewares (apiOptions * []func (* smithymiddle.Stack ) error , ctx context.Context , awsConfig aws.Config ) {
240+ txn := newrelic .FromContext (ctx )
241+
242+ creds , err := awsConfig .Credentials .Retrieve (ctx )
243+ if err != nil {
244+ fmt .Println ("error: Couldn't get AWS Credentials" )
245+ }
246+
247+ cfg , ok := txn .Application ().Config ()
248+ m := nrMiddleware {txn : txn , resolver : & defaultResolver {}}
249+
250+ if ok {
251+ err := m .ResolveAWSCredentials (cfg , creds )
252+ if err != nil {
253+ fmt .Println ("error: Couldn't resolve AWS credentials" )
254+ }
255+ }
256+
234257 * apiOptions = append (* apiOptions , m .deserializeMiddleware )
235258 * apiOptions = append (* apiOptions , m .serializeMiddleware )
236259}
@@ -291,3 +314,55 @@ func dynamoDBInputFromMiddlewareInput(in middleware.InitializeInput) dynamodbInp
291314 return dynamodbInput {}
292315 }
293316}
317+
318+ func (m * nrMiddleware ) ResolveAWSCredentials (cfg newrelic.Config , creds aws.Credentials ) error {
319+
320+ if m .resolver == nil {
321+ m .resolver = & defaultResolver {}
322+ }
323+
324+ accountID , err := m .resolver .AWSAccountIdFromAWSAccessKey (creds )
325+ if err != nil {
326+ return err
327+ }
328+
329+ // Use resolved accountID if:
330+ // 1. No accountID is set in config (cfg.CloudAWS.AccountID is empty), OR
331+ // 2. Resolved accountID is different from config accountID
332+ if cfg .CloudAWS .AccountID == "" || (accountID != "" && accountID != cfg .CloudAWS .AccountID ) {
333+ m .accountID = accountID
334+ return nil
335+ }
336+
337+ // Otherwise use the config accountID
338+ m .accountID = cfg .CloudAWS .AccountID
339+ return nil
340+ }
341+
342+ func (m * defaultResolver ) AWSAccountIdFromAWSAccessKey (creds aws.Credentials ) (string , error ) {
343+ if creds .AccountID != "" {
344+ return creds .AccountID , nil
345+ }
346+ if creds .AccessKeyID == "" {
347+ return "" , fmt .Errorf ("no access key id found" )
348+ }
349+ if len (creds .AccessKeyID ) < 16 {
350+ return "" , fmt .Errorf ("improper access key id format" )
351+ }
352+ trimmedAccessKey := creds .AccessKeyID [4 :]
353+ decoded , err := base32 .StdEncoding .DecodeString (trimmedAccessKey )
354+ if err != nil {
355+ return "" , fmt .Errorf ("error decoding access keys" )
356+ }
357+ var bigEndian uint64
358+ for i := 0 ; i < 6 ; i ++ {
359+ bigEndian = bigEndian << 8 // shift 8 bits left. Most significant byte read in first (decoded[i])
360+ bigEndian |= uint64 (decoded [i ]) // apply OR for current byte
361+ }
362+
363+ mask := uint64 (0x7fffffffff80 )
364+
365+ num := (bigEndian & mask ) >> 7 // apply mask and get rid of last 7 bytes from mask
366+
367+ return fmt .Sprintf ("%d" , num ), nil
368+ }
0 commit comments