11package mongoaws
22
33import (
4+ "bytes"
45 "context"
6+ "crypto/rand"
7+ "encoding/base64"
58 "errors"
9+ "fmt"
610 "net/http"
11+ "strings"
12+ "time"
713
14+ "github.com/aws/aws-sdk-go-v2/aws"
15+ awsv4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
16+ "github.com/aws/aws-sdk-go-v2/config"
17+ "go.mongodb.org/mongo-driver/v2/bson"
18+ "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
819 "go.mongodb.org/mongo-driver/v2/x/mongo/driver"
920 "go.mongodb.org/mongo-driver/v2/x/mongo/driver/auth"
1021)
1122
12- const sourceExternal = "$external"
23+ const (
24+ sourceExternal = "$external"
25+ responceNonceLength = 64
26+ amzDateFormat = "20060102T150405Z"
27+ awsSessionToken = "AWS_SESSION_TOKEN"
28+ defaultRegion = "us-east-1"
29+ maxHostLength = 255
30+ )
1331
1432// Authenticator is an authenticator that uses the AWS SDK rather than the
1533// lightweight AWS package used internally by the driver.
16- type Authenticator struct {}
34+ type Authenticator struct {
35+ userCred * auth.Cred // MongoDB TLS credentials with AWS keys
36+ awsCfg aws.Config // AWS SDK config
37+ signer * awsv4.Signer // SigV4 signer
38+ }
1739
1840var _ driver.Authenticator = (* Authenticator )(nil )
1941
2042// NewAuthenticator creates a new AWS SDK authenticator. It loads the AWS
2143// SDK config (honoring AWS_STS_REGIONAL_ENDPOINTS & AWS_REGION) and returns an
2244// Authenticator that uses it.
23- func NewAuthenticator (* auth.Cred , * http.Client ) (driver.Authenticator , error ) {
24- return & Authenticator {}, nil
45+ func NewAuthenticator (cred * auth.Cred , _ * http.Client ) (driver.Authenticator , error ) {
46+ // Load AWS SDK config from environment variables / credentials files.
47+ awsCfg , err := config .LoadDefaultConfig (context .Background (), config .WithRegion ("us-east-1" ))
48+ if err != nil {
49+ return nil , fmt .Errorf ("failed to load AWS SDK config: %w" , err )
50+ }
51+
52+ return & Authenticator {
53+ userCred : cred ,
54+ awsCfg : awsCfg ,
55+ signer : awsv4 .NewSigner (),
56+ }, nil
2557}
2658
2759var _ auth.AuthenticatorFactory = NewAuthenticator
@@ -30,7 +62,11 @@ var _ auth.AuthenticatorFactory = NewAuthenticator
3062// uses the AWS SDK for singing.
3163func (a * Authenticator ) Auth (ctx context.Context , cfg * driver.AuthConfig ) error {
3264 // Build a SASL adapter that uses AWS SDK for signing.
33- adapter := & awsSdkSaslClient {}
65+ adapter := & awsSdkSaslClient {
66+ userCred : a .userCred ,
67+ awsCfg : a .awsCfg ,
68+ signer : a .signer ,
69+ }
3470
3571 return auth .ConductSaslConversation (ctx , cfg , sourceExternal , adapter )
3672}
@@ -40,25 +76,142 @@ func (a *Authenticator) Reauth(context.Context, *driver.AuthConfig) error {
4076 return errors .New ("AWS reauthentication not supported" )
4177}
4278
79+ type conversationState uint8
80+
81+ const (
82+ conversationStateStart conversationState = 1 // before sending anything.
83+ conversationStateServerFirst conversationState = 2 // after sending client-first, awaiting server reply.
84+ conversationStateDone conversationState = 3 // after sending client-final.
85+ )
86+
4387// awsSdkSaslClient is a SASL client that uses the AWS SDK for signing.
44- type awsSdkSaslClient struct {}
88+ type awsSdkSaslClient struct {
89+ state conversationState // handshake state machine
90+ nonce []byte // client nonce
91+ userCred * auth.Cred // MongoDB TLS credentials with AWS keys
92+ awsCfg aws.Config // AWS SDK config
93+ signer * awsv4.Signer // SigV4 signer
94+ }
4595
4696var _ auth.SaslClient = (* awsSdkSaslClient )(nil )
4797
4898// Start will create the client-first SASL message.
4999// { p: 110, r: <32-byte nonce>}; per the current Go Driver behavior.
50- func (a * awsSdkSaslClient ) Start () (string , []byte , error ) {
51- return "" , nil , nil
100+ func (client * awsSdkSaslClient ) Start () (string , []byte , error ) {
101+ client .state = conversationStateServerFirst
102+ client .nonce = make ([]byte , 32 )
103+ _ , _ = rand .Read (client .nonce )
104+
105+ idx , msg := bsoncore .AppendDocumentStart (nil )
106+ msg = bsoncore .AppendInt32Element (msg , "p" , 110 )
107+ msg = bsoncore .AppendBinaryElement (msg , "r" , 0x00 , client .nonce )
108+ msg , _ = bsoncore .AppendDocumentEnd (msg , idx )
109+
110+ return auth .MongoDBAWS , msg , nil
111+ }
112+
113+ func getRegion (host string ) (string , error ) {
114+ region := defaultRegion
115+
116+ if len (host ) == 0 {
117+ return "" , errors .New ("invalid STS host: empty" )
118+ }
119+ if len (host ) > maxHostLength {
120+ return "" , errors .New ("invalid STS host: too large" )
121+ }
122+ // The implicit region for sts.amazonaws.com is us-east-1
123+ if host == "sts.amazonaws.com" {
124+ return region , nil
125+ }
126+ if strings .HasPrefix (host , "." ) || strings .HasSuffix (host , "." ) || strings .Contains (host , ".." ) {
127+ return "" , errors .New ("invalid STS host: empty part" )
128+ }
129+
130+ // If the host has multiple parts, the second part is the region
131+ parts := strings .Split (host , "." )
132+ if len (parts ) >= 2 {
133+ region = parts [1 ]
134+ }
135+
136+ return region , nil
52137}
53138
54139// Next handles the server's "server-first" message, then builds and returns the
55140// "client-final" payload containing the SigV4-signed STS GetCallerIdentity
56141// request.
57- func (a * awsSdkSaslClient ) Next (ctx context.Context , challenge []byte ) ([]byte , error ) {
142+ func (client * awsSdkSaslClient ) Next (ctx context.Context , challenge []byte ) ([]byte , error ) {
143+ if client .state != conversationStateServerFirst {
144+ return nil , fmt .Errorf ("invalid state: %v" , client .state )
145+ }
146+ client .state = conversationStateDone
147+
148+ // Unmarhal the server's BSON: { s: <server nonce>, h: "<sts host>"}
149+ var sm struct {
150+ Nonce bson.Binary `bson:"s"`
151+ Host string `bson:"h"`
152+ }
153+
154+ if err := bson .Unmarshal (challenge , & sm ); err != nil {
155+ return nil , err
156+ }
157+
158+ // Check nonce prefix
159+ if sm .Nonce .Subtype != 0x00 {
160+ return nil , errors .New ("server reply contained unexpected binary subtype" )
161+ }
162+
163+ if len (sm .Nonce .Data ) != responceNonceLength {
164+ return nil , fmt .Errorf ("server reply nonce was not %v bytes" , responceNonceLength )
165+ }
166+
167+ if ! bytes .HasPrefix (sm .Nonce .Data , client .nonce ) {
168+ return nil , errors .New ("server nonce did not extend client nonce" )
169+ }
170+
171+ currentTime := time .Now ().UTC ()
172+ body := "Action=GetCallerIdentity&Version=2011-06-15"
173+
174+ // Create http.Request
175+ req , _ := http .NewRequest ("POST" , "/" , strings .NewReader (body ))
176+ req .Header .Set ("Content-Type" , "application/x-www-form-urlencoded" )
177+ req .Header .Set ("Content-Length" , "43" )
178+ req .Host = sm .Host
179+ req .Header .Set ("X-Amz-Date" , currentTime .Format (amzDateFormat ))
180+
181+ // Include session token if present.
182+ if tok := client .userCred .Props [awsSessionToken ]; tok != "" {
183+ req .Header .Set ("X-Amz-Security-Token" , tok )
184+ }
185+
186+ req .Header .Set ("X-MongoDB-Server-Nonce" , base64 .StdEncoding .EncodeToString (sm .Nonce .Data ))
187+ req .Header .Set ("X-MongoDB-GS2-CB-Flag" , "n" )
188+
189+ // Retrieve AWS creds and sign the request using AWS SDK v4.
190+ creds , err := client .awsCfg .Credentials .Retrieve (ctx )
191+ if err != nil {
192+ return nil , fmt .Errorf ("failed to retrieve AWS credentials: %w" , err )
193+ }
194+
195+ // Create signer with credentials
196+ err = client .signer .SignHTTP (ctx , creds , req , body , "sts" , sm .Host , currentTime )
197+ if err != nil {
198+ return nil , fmt .Errorf ("failed to sign request: %w" , err )
199+ }
200+
201+ // create message
202+ // { a: Authorization, d: X-Amz-Date, t: X-Amz-Security-Token }
203+ idx , msg := bsoncore .AppendDocumentStart (nil )
204+ msg = bsoncore .AppendStringElement (msg , "a" , req .Header .Get ("Authorization" ))
205+ msg = bsoncore .AppendStringElement (msg , "d" , req .Header .Get ("X-Amz-Date" ))
206+ if tok := req .Header .Get ("X-Amz-Security-Token" ); tok != "" {
207+ msg = bsoncore .AppendStringElement (msg , "t" , tok )
208+ }
209+ msg , _ = bsoncore .AppendDocumentEnd (msg , idx )
210+
58211 return nil , nil
59212}
60213
61214// complete signals that the SASL conversation is done.
62- func (a * awsSdkSaslClient ) Completed () bool {
63- return false
215+ func (client * awsSdkSaslClient ) Completed () bool {
216+ return client . state == conversationStateDone
64217}
0 commit comments