Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 78 additions & 6 deletions sdk/src/Services/DSQL/Custom/Util/DSQLAuthTokenGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ public static class DSQLAuthTokenGenerator
private const string XAmzExpires = "X-Amz-Expires";
private const string XAmzSecurityToken = "X-Amz-Security-Token";
private static readonly TimeSpan FifteenMinutes = TimeSpan.FromMinutes(15);
private static readonly TimeSpan MaxExpiresIn = TimeSpan.FromDays(7);

/// <summary>
/// AWS4PreSignedUrlSigner is built around operation request objects.
Expand Down Expand Up @@ -116,7 +117,24 @@ public static string GenerateDbConnectAuthToken(AWSCredentials credentials, Regi
throw new ArgumentNullException("credentials");

var immutableCredentials = credentials.GetCredentials();
return GenerateAuthToken(immutableCredentials, region, hostname, DBConnectActionValue);
return GenerateAuthToken(immutableCredentials, region, hostname, DBConnectActionValue, FifteenMinutes);
}

/// <summary>
/// Generate a token for IAM authentication to a DSQL database cluster for the DbConnect action.
/// </summary>
/// <param name="credentials">The credentials for the token.</param>
/// <param name="region">The region of the DSQL database.</param>
/// <param name="hostname">Hostname of the DSQL database.</param>
/// <param name="expiresIn">The token expiry duration. If not specified on other overloads, defaults to 15 minutes.</param>
/// <returns></returns>
public static string GenerateDbConnectAuthToken(AWSCredentials credentials, RegionEndpoint region, string hostname, TimeSpan expiresIn)
{
if (credentials == null)
throw new ArgumentNullException("credentials");

var immutableCredentials = credentials.GetCredentials();
return GenerateAuthToken(immutableCredentials, region, hostname, DBConnectActionValue, expiresIn);
}

/// <summary>
Expand Down Expand Up @@ -182,7 +200,24 @@ public static async System.Threading.Tasks.Task<string> GenerateDbConnectAuthTok
throw new ArgumentNullException("credentials");

var immutableCredentials = await credentials.GetCredentialsAsync().ConfigureAwait(false);
return GenerateAuthToken(immutableCredentials, region, hostname, DBConnectActionValue);
return GenerateAuthToken(immutableCredentials, region, hostname, DBConnectActionValue, FifteenMinutes);
}

/// <summary>
/// Generate a token for IAM authentication to a DSQL database cluster for the DbConnect action.
/// </summary>
/// <param name="credentials">The credentials for the token.</param>
/// <param name="region">The region of the DSQL database.</param>
/// <param name="hostname">Hostname of the DSQL database.</param>
/// <param name="expiresIn">The token expiry duration. If not specified on other overloads, defaults to 15 minutes.</param>
/// <returns></returns>
public static async System.Threading.Tasks.Task<string> GenerateDbConnectAuthTokenAsync(AWSCredentials credentials, RegionEndpoint region, string hostname, TimeSpan expiresIn)
{
if (credentials == null)
throw new ArgumentNullException("credentials");

var immutableCredentials = await credentials.GetCredentialsAsync().ConfigureAwait(false);
return GenerateAuthToken(immutableCredentials, region, hostname, DBConnectActionValue, expiresIn);
}

/// <summary>
Expand Down Expand Up @@ -248,7 +283,24 @@ public static string GenerateDbConnectAdminAuthToken(AWSCredentials credentials,
throw new ArgumentNullException("credentials");

var immutableCredentials = credentials.GetCredentials();
return GenerateAuthToken(immutableCredentials, region, hostname, DBConnectAdminActionValue);
return GenerateAuthToken(immutableCredentials, region, hostname, DBConnectAdminActionValue, FifteenMinutes);
}

/// <summary>
/// Generate a token for IAM authentication to a DSQL database cluster for the DbConnectAdmin action.
/// </summary>
/// <param name="credentials">The credentials for the token.</param>
/// <param name="region">The region of the DSQL database.</param>
/// <param name="hostname">Hostname of the DSQL database.</param>
/// <param name="expiresIn">The token expiry duration. If not specified on other overloads, defaults to 15 minutes.</param>
/// <returns></returns>
public static string GenerateDbConnectAdminAuthToken(AWSCredentials credentials, RegionEndpoint region, string hostname, TimeSpan expiresIn)
{
if (credentials == null)
throw new ArgumentNullException("credentials");

var immutableCredentials = credentials.GetCredentials();
return GenerateAuthToken(immutableCredentials, region, hostname, DBConnectAdminActionValue, expiresIn);
}

/// <summary>
Expand Down Expand Up @@ -314,10 +366,27 @@ public static async System.Threading.Tasks.Task<string> GenerateDbConnectAdminAu
throw new ArgumentNullException("credentials");

var immutableCredentials = await credentials.GetCredentialsAsync().ConfigureAwait(false);
return GenerateAuthToken(immutableCredentials, region, hostname, DBConnectAdminActionValue);
return GenerateAuthToken(immutableCredentials, region, hostname, DBConnectAdminActionValue, FifteenMinutes);
}

/// <summary>
/// Generate a token for IAM authentication to a DSQL database cluster for the DbConnectAdmin action.
/// </summary>
/// <param name="credentials">The credentials for the token.</param>
/// <param name="region">The region of the DSQL database.</param>
/// <param name="hostname">Hostname of the DSQL database.</param>
/// <param name="expiresIn">The token expiry duration. If not specified on other overloads, defaults to 15 minutes.</param>
/// <returns></returns>
public static async System.Threading.Tasks.Task<string> GenerateDbConnectAdminAuthTokenAsync(AWSCredentials credentials, RegionEndpoint region, string hostname, TimeSpan expiresIn)
{
if (credentials == null)
throw new ArgumentNullException("credentials");

var immutableCredentials = await credentials.GetCredentialsAsync().ConfigureAwait(false);
return GenerateAuthToken(immutableCredentials, region, hostname, DBConnectAdminActionValue, expiresIn);
}

private static string GenerateAuthToken(ImmutableCredentials immutableCredentials, RegionEndpoint region, string hostname, string actionValue)
private static string GenerateAuthToken(ImmutableCredentials immutableCredentials, RegionEndpoint region, string hostname, string actionValue, TimeSpan expiresIn)
{
if (immutableCredentials == null)
throw new ArgumentNullException("immutableCredentials");
Expand All @@ -329,12 +398,15 @@ private static string GenerateAuthToken(ImmutableCredentials immutableCredential
if (string.IsNullOrEmpty(hostname))
throw new ArgumentException("Hostname must not be null or empty.");

if (expiresIn <= TimeSpan.Zero || expiresIn > MaxExpiresIn)
throw new ArgumentOutOfRangeException("expiresIn", "ExpiresIn must be between 0 (exclusive) and 7 days (inclusive).");

GenerateDSQLAuthTokenRequest authTokenRequest = new GenerateDSQLAuthTokenRequest();
IRequest request = new DefaultRequest(authTokenRequest, DSQLServiceName);

request.UseQueryString = true;
request.HttpMethod = HTTPGet;
request.Parameters.Add(XAmzExpires, FifteenMinutes.TotalSeconds.ToString(CultureInfo.InvariantCulture));
request.Parameters.Add(XAmzExpires, expiresIn.TotalSeconds.ToString(CultureInfo.InvariantCulture));
request.Parameters.Add(ActionKey, actionValue);
request.Endpoint = new UriBuilder(HTTPS, hostname).Uri;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,54 @@ public void GenerateDbConnectAuthTokenEmptyHostname()
}, typeof(ArgumentException));
}

[TestMethod]
[TestCategory("DSQL")]
public void GenerateDbConnectAuthTokenCustomExpiresIn()
{
AssertAuthToken(DSQLAuthTokenGenerator.GenerateDbConnectAuthToken(BasicCredentials,
AWSRegion, DBCluster, TimeSpan.FromSeconds(450)), AccessKey, AWSRegion, DBConnectActionValue, false, 450);
}

[TestMethod]
[TestCategory("DSQL")]
public void GenerateDbConnectAuthTokenZeroExpiresIn()
{
AssertExtensions.ExpectException(() =>
{
DSQLAuthTokenGenerator.GenerateDbConnectAuthToken(BasicCredentials, AWSRegion, DBCluster, TimeSpan.Zero);
}, typeof(ArgumentOutOfRangeException));
}

[TestMethod]
[TestCategory("DSQL")]
public void GenerateDbConnectAuthTokenNegativeExpiresIn()
{
AssertExtensions.ExpectException(() =>
{
DSQLAuthTokenGenerator.GenerateDbConnectAuthToken(BasicCredentials, AWSRegion, DBCluster, TimeSpan.FromSeconds(-1));
}, typeof(ArgumentOutOfRangeException));
}

[TestMethod]
[TestCategory("DSQL")]
public void GenerateDbConnectAuthTokenExpiresInExceeds7Days()
{
AssertExtensions.ExpectException(() =>
{
DSQLAuthTokenGenerator.GenerateDbConnectAuthToken(BasicCredentials, AWSRegion, DBCluster, TimeSpan.FromDays(8));
}, typeof(ArgumentOutOfRangeException));
}

#if ASYNC_AWAIT
[TestMethod]
[TestCategory("DSQL")]
public async System.Threading.Tasks.Task GenerateDbConnectAuthTokenCustomExpiresInAsync()
{
AssertAuthToken(await DSQLAuthTokenGenerator.GenerateDbConnectAuthTokenAsync(BasicCredentials,
AWSRegion, DBCluster, TimeSpan.FromSeconds(450)), AccessKey, AWSRegion, DBConnectActionValue, false, 450);
}
#endif

// DbConnectAdmin

#if ASYNC_AWAIT
Expand Down Expand Up @@ -275,12 +323,35 @@ public void GenerateDbConnectAdminAuthTokenEmptyHostname()
}, typeof(ArgumentException));
}

[TestMethod]
[TestCategory("DSQL")]
public void GenerateDbConnectAdminAuthTokenCustomExpiresIn()
{
AssertAuthToken(DSQLAuthTokenGenerator.GenerateDbConnectAdminAuthToken(BasicCredentials,
AWSRegion, DBCluster, TimeSpan.FromSeconds(450)), AccessKey, AWSRegion, DBConnectAdminActionValue, false, 450);
}

#if ASYNC_AWAIT
[TestMethod]
[TestCategory("DSQL")]
public async System.Threading.Tasks.Task GenerateDbConnectAdminAuthTokenCustomExpiresInAsync()
{
AssertAuthToken(await DSQLAuthTokenGenerator.GenerateDbConnectAdminAuthTokenAsync(BasicCredentials,
AWSRegion, DBCluster, TimeSpan.FromSeconds(450)), AccessKey, AWSRegion, DBConnectAdminActionValue, false, 450);
}
#endif

private void AssertAuthToken(string token, string accessKey, RegionEndpoint region, string actionValue)
{
AssertAuthToken(token, accessKey, region, actionValue, false);
AssertAuthToken(token, accessKey, region, actionValue, false, 900);
}

private void AssertAuthToken(string token, string accessKey, RegionEndpoint region, string actionValue, bool hasSessionToken)
{
AssertAuthToken(token, accessKey, region, actionValue, hasSessionToken, 900);
}

private void AssertAuthToken(string token, string accessKey, RegionEndpoint region, string actionValue, bool hasSessionToken, int expectedExpiresInSeconds)
{
// Look for today or yesterday to cover the crazy case where the
// token was generated utc yesterday but we're asserting utc today.
Expand All @@ -289,9 +360,9 @@ private void AssertAuthToken(string token, string accessKey, RegionEndpoint regi

var sessionTokenPart = hasSessionToken ? "X-Amz-Security-Token=" + AWSSDKUtils.UrlEncode(SessionToken, false) + "&" : "";
var regex = Regex.Escape(string.Format(CultureInfo.InvariantCulture,
"{0}/?Action={1}&X-Amz-Expires=900&{2}X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=" +
"{3}%2FTODAYREGEX%2F{4}%2Fdsql%2Faws4_request&X-Amz-Date=TODAYREGEXTTIMEREGEXZ&X-Amz-SignedHeaders=host&X-Amz-Signature=SIGNATUREREGEX",
DBCluster, actionValue, sessionTokenPart, accessKey, region.SystemName));
"{0}/?Action={1}&X-Amz-Expires={2}&{3}X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=" +
"{4}%2FTODAYREGEX%2F{5}%2Fdsql%2Faws4_request&X-Amz-Date=TODAYREGEXTTIMEREGEXZ&X-Amz-SignedHeaders=host&X-Amz-Signature=SIGNATUREREGEX",
DBCluster, actionValue, expectedExpiresInSeconds, sessionTokenPart, accessKey, region.SystemName));
regex = regex.Replace("TIMEREGEX", "[0-9]{6}").Replace("SIGNATUREREGEX", "[0-9a-f]{64}").Replace("TODAYREGEX", todayRegex);

Assert.IsTrue(Regex.IsMatch(token, regex), token + " doesn't match regex " + regex);
Expand Down