|
| 1 | +package auth |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "fmt" |
| 6 | + "net/http" |
| 7 | + "strings" |
| 8 | + "time" |
| 9 | + |
| 10 | + "github.com/aws/aws-sdk-go-v2/aws" |
| 11 | + "github.com/aws/aws-sdk-go-v2/aws/signer/v4" |
| 12 | +) |
| 13 | + |
| 14 | +const ( |
| 15 | + signingID = "rds-db" |
| 16 | + emptyPayloadHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" |
| 17 | +) |
| 18 | + |
| 19 | +// BuildAuthTokenOptions is the optional set of configuration properties for BuildAuthToken |
| 20 | +type BuildAuthTokenOptions struct{} |
| 21 | + |
| 22 | +// BuildAuthToken will return an authorization token used as the password for a DB |
| 23 | +// connection. |
| 24 | +// |
| 25 | +// * endpoint - Endpoint consists of the port needed to connect to the DB. <host>:<port> |
| 26 | +// * region - Region is the location of where the DB is |
| 27 | +// * dbUser - User account within the database to sign in with |
| 28 | +// * creds - Credentials to be signed with |
| 29 | +// |
| 30 | +// The following example shows how to use BuildAuthToken to create an authentication |
| 31 | +// token for connecting to a MySQL database in RDS. |
| 32 | +// |
| 33 | +// authToken, err := BuildAuthToken(dbEndpoint, awsRegion, dbUser, awsCreds) |
| 34 | +// |
| 35 | +// // Create the MySQL DNS string for the DB connection |
| 36 | +// // user:password@protocol(endpoint)/dbname?<params> |
| 37 | +// connectStr = fmt.Sprintf("%s:%s@tcp(%s)/%s?allowCleartextPasswords=true&tls=rds", |
| 38 | +// dbUser, authToken, dbEndpoint, dbName, |
| 39 | +// ) |
| 40 | +// |
| 41 | +// // Use db to perform SQL operations on database |
| 42 | +// db, err := sql.Open("mysql", connectStr) |
| 43 | +// |
| 44 | +// See http://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.html |
| 45 | +// for more information on using IAM database authentication with RDS. |
| 46 | +func BuildAuthToken(ctx context.Context, endpoint, region, dbUser string, creds aws.CredentialsProvider, optFns ...func(options *BuildAuthTokenOptions)) (string, error) { |
| 47 | + o := BuildAuthTokenOptions{} |
| 48 | + |
| 49 | + for _, fn := range optFns { |
| 50 | + fn(&o) |
| 51 | + } |
| 52 | + |
| 53 | + if creds == nil { |
| 54 | + return "", fmt.Errorf("credetials provider must not ne nil") |
| 55 | + } |
| 56 | + |
| 57 | + // the scheme is arbitrary and is only needed because validation of the URL requires one. |
| 58 | + if !(strings.HasPrefix(endpoint, "http://") || strings.HasPrefix(endpoint, "https://")) { |
| 59 | + endpoint = "https://" + endpoint |
| 60 | + } |
| 61 | + |
| 62 | + req, err := http.NewRequest("GET", endpoint, nil) |
| 63 | + if err != nil { |
| 64 | + return "", err |
| 65 | + } |
| 66 | + values := req.URL.Query() |
| 67 | + values.Set("Action", "connect") |
| 68 | + values.Set("DBUser", dbUser) |
| 69 | + req.URL.RawQuery = values.Encode() |
| 70 | + |
| 71 | + signer := v4.NewSigner() |
| 72 | + |
| 73 | + credentials, err := creds.Retrieve(ctx) |
| 74 | + if err != nil { |
| 75 | + return "", err |
| 76 | + } |
| 77 | + |
| 78 | + // Expire Time: 15 minute |
| 79 | + query := req.URL.Query() |
| 80 | + query.Set("X-Amz-Expires", "900") |
| 81 | + req.URL.RawQuery = query.Encode() |
| 82 | + |
| 83 | + signedURI, _, err := signer.PresignHTTP(ctx, credentials, req, emptyPayloadHash, signingID, region, time.Now().UTC()) |
| 84 | + if err != nil { |
| 85 | + return "", err |
| 86 | + } |
| 87 | + |
| 88 | + url := signedURI |
| 89 | + if strings.HasPrefix(url, "http://") { |
| 90 | + url = url[len("http://"):] |
| 91 | + } else if strings.HasPrefix(url, "https://") { |
| 92 | + url = url[len("https://"):] |
| 93 | + } |
| 94 | + |
| 95 | + return url, nil |
| 96 | +} |
0 commit comments