@@ -2,35 +2,71 @@ package aws_encryption_sdk
22
33import (
44 "bytes"
5+ "context"
56 "crypto/aes"
67 "crypto/cipher"
78 "encoding/binary"
89 "errors"
910 "strings"
1011
11- "github.com/aws/aws-sdk-go/aws"
12- "github.com/aws/aws-sdk-go/aws/credentials/stscreds "
13- "github.com/aws/aws-sdk-go/aws/session "
14- "github.com/aws/aws-sdk-go/service/kms "
12+ "github.com/aws/aws-sdk-go-v2 /aws"
13+ "github.com/aws/aws-sdk-go-v2/config "
14+ "github.com/aws/aws-sdk-go-v2/service/kms "
15+ "github.com/aws/aws-sdk-go-v2 /service/sts "
1516 "golang.org/x/crypto/hkdf"
1617)
1718
1819type KmsHelper struct {
19- client * kms.KMS
20+ client * kms.Client
2021}
2122
2223func NewKmsHelper (region string , assumedRole string ) * KmsHelper {
23- k := & KmsHelper {}
24- // Set up AWS KMS session
25- conf := aws .NewConfig ().WithRegion (region )
26- sess := session .Must (session .NewSession (conf ))
24+ ctx := context .Background ()
25+ var cfg aws.Config
26+ var err error
2727 if assumedRole != "" {
28- creds := stscreds .NewCredentials (sess , assumedRole )
29- k .client = kms .New (sess , & aws.Config {Credentials : creds })
28+ // Load default config
29+ cfg , err = config .LoadDefaultConfig (ctx , config .WithRegion (region ))
30+ if err != nil {
31+ panic (err )
32+ }
33+ // Assume role
34+ stsClient := sts .NewFromConfig (cfg )
35+ resp , err := stsClient .AssumeRole (ctx , & sts.AssumeRoleInput {
36+ RoleArn : aws .String (assumedRole ),
37+ RoleSessionName : aws .String ("decrypt-and-start-session" ),
38+ })
39+ if err != nil {
40+ panic (err )
41+ }
42+
43+ // Get a new config with the assumed role credentials
44+ var optFns []func (* config.LoadOptions ) error
45+ optFns = append (optFns , config .WithRegion (region ))
46+ optFns = append (optFns , config .WithCredentialsProvider (
47+ aws .CredentialsProviderFunc (func (ctx context.Context ) (aws.Credentials , error ) {
48+ return aws.Credentials {
49+ AccessKeyID : * resp .Credentials .AccessKeyId ,
50+ SecretAccessKey : * resp .Credentials .SecretAccessKey ,
51+ SessionToken : * resp .Credentials .SessionToken ,
52+ CanExpire : true ,
53+ Expires : * resp .Credentials .Expiration ,
54+ }, nil
55+ }),
56+ ))
57+
58+ newCfg , err := config .LoadDefaultConfig (ctx , optFns ... )
59+ if err != nil {
60+ panic (err )
61+ }
62+ cfg = newCfg
3063 } else {
31- k .client = kms .New (sess )
64+ cfg , err = config .LoadDefaultConfig (ctx , config .WithRegion (region ))
65+ if err != nil {
66+ panic (err )
67+ }
3268 }
33- return k
69+ return & KmsHelper { client : kms . NewFromConfig ( cfg )}
3470}
3571
3672// Decrypt encrypted data keys
@@ -88,17 +124,15 @@ func (k *KmsHelper) buildContentAAD(m *Message, f *Frame) ([]byte, error) {
88124
89125// Decrypt using KMS
90126func (k * KmsHelper ) kmsDecrypt (data []byte , m * Message ) ([]byte , error ) {
91- input := & kms.DecryptInput {
127+ ctx := context .Background ()
128+ in := & kms.DecryptInput {
92129 CiphertextBlob : data ,
93130 }
94- if m != nil {
95- context := make (map [string ]* string )
96- for key , value := range m .EncContext {
97- context [key ] = & value
98- }
99- input .EncryptionContext = context
131+ if m != nil && len (m .EncContext ) > 0 {
132+ in .EncryptionContext = m .EncContext
100133 }
101- result , err := k .client .Decrypt (input )
134+
135+ result , err := k .client .Decrypt (ctx , in )
102136 if err != nil {
103137 return nil , err
104138 }
@@ -114,7 +148,7 @@ func (k *KmsHelper) Decrypt(data []byte) ([]byte, error) {
114148 // Try simple KMS decryption first
115149 if plaintext , err = k .kmsDecrypt (data , nil ); err == nil {
116150 return plaintext , nil
117- } else if strings .HasPrefix (err .Error (), kms . ErrCodeInvalidCiphertextException ) {
151+ } else if strings .Contains (err .Error (), "InvalidCiphertextException" ) {
118152 // Do nothing for an InvalidCiphertextException error
119153 } else {
120154 // Unknown error
0 commit comments