Skip to content

Commit 8c16de1

Browse files
authored
Optimize remote authenticate query (#729)
* Optimize the standardization for remote authentication query of Alipay and Wechat.
1 parent cc1b639 commit 8c16de1

File tree

2 files changed

+72
-18
lines changed

2 files changed

+72
-18
lines changed

src/AspNet.Security.OAuth.Alipay/AlipayAuthenticationHandler.cs

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
using Microsoft.AspNetCore.WebUtilities;
1616
using Microsoft.Extensions.Logging;
1717
using Microsoft.Extensions.Options;
18+
using Microsoft.Extensions.Primitives;
1819

1920
namespace AspNet.Security.OAuth.Alipay;
2021

@@ -32,17 +33,13 @@ public AlipayAuthenticationHandler(
3233
{
3334
}
3435

36+
private const string AuthCode = "auth_code";
37+
3538
protected override Task<HandleRequestResult> HandleRemoteAuthenticateAsync()
3639
{
37-
var query = Request.Query;
38-
if (query.TryGetValue("auth_code", out var authCode))
40+
if (TryStandardizeRemoteAuthenticateQuery(Request.Query, out var queryString))
3941
{
40-
// The base `HandleRemoteAuthenticateAsync` requires that `Request.Query` must contain the key called `code`,
41-
// which is actually the same as `auth_code` by Alipay's design, but `Request.Query` does not have `Add` operation.
42-
// So here is a trick to get the private `Store` dictionary of `QueryCollection`.
43-
var queryStore = query.ToDictionary(c => c.Key, c => c.Value, StringComparer.OrdinalIgnoreCase);
44-
queryStore["code"] = authCode;
45-
Request.QueryString = QueryString.Create(queryStore);
42+
Request.QueryString = queryString;
4643
}
4744

4845
return base.HandleRemoteAuthenticateAsync();
@@ -236,6 +233,38 @@ protected override string BuildChallengeUrl([NotNull] AuthenticationProperties p
236233
return QueryHelpers.AddQueryString(Options.AuthorizationEndpoint, parameters);
237234
}
238235

236+
private static bool TryStandardizeRemoteAuthenticateQuery(IQueryCollection query, out QueryString queryString)
237+
{
238+
if (!query.TryGetValue(AuthCode, out var authCode))
239+
{
240+
queryString = default;
241+
return false;
242+
}
243+
244+
// Before: mydomain/signin-alipay?auth_code=xxx&state=xxx&...
245+
// After: mydomain/signin-alipay?code=xxx&state=xxx&...
246+
var queryParams = new List<KeyValuePair<string, StringValues>>(query.Count)
247+
{
248+
new("code", authCode)
249+
};
250+
foreach (var item in query)
251+
{
252+
switch (item.Key)
253+
{
254+
case "code":
255+
case AuthCode: // No need in fact, skip it
256+
break;
257+
258+
default:
259+
queryParams.Add(item);
260+
break;
261+
}
262+
}
263+
264+
queryString = QueryString.Create(queryParams);
265+
return true;
266+
}
267+
239268
private static partial class Log
240269
{
241270
internal static async Task UserProfileErrorAsync(ILogger logger, HttpResponseMessage response, CancellationToken cancellationToken)

src/AspNet.Security.OAuth.Weixin/WeixinAuthenticationHandler.cs

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
using Microsoft.AspNetCore.WebUtilities;
1515
using Microsoft.Extensions.Logging;
1616
using Microsoft.Extensions.Options;
17+
using Microsoft.Extensions.Primitives;
1718

1819
namespace AspNet.Security.OAuth.Weixin;
1920

@@ -33,17 +34,9 @@ public WeixinAuthenticationHandler(
3334

3435
protected override async Task<HandleRequestResult> HandleRemoteAuthenticateAsync()
3536
{
36-
if (!IsWeixinAuthorizationEndpointInUse())
37+
if (!IsWeixinAuthorizationEndpointInUse() && TryStandardizeRemoteAuthenticateQuery(Request.Query, out var queryString))
3738
{
38-
if (Request.Query.TryGetValue(OauthState, out var stateValue))
39-
{
40-
var query = Request.Query.ToDictionary(c => c.Key, c => c.Value, StringComparer.OrdinalIgnoreCase);
41-
if (query.TryGetValue(State, out _))
42-
{
43-
query[State] = stateValue;
44-
Request.QueryString = QueryString.Create(query);
45-
}
46-
}
39+
Request.QueryString = queryString;
4740
}
4841

4942
return await base.HandleRemoteAuthenticateAsync();
@@ -198,6 +191,38 @@ private bool IsWeixinAuthorizationEndpointInUse()
198191
return string.Equals(Options.AuthorizationEndpoint, WeixinAuthenticationDefaults.AuthorizationEndpoint, StringComparison.OrdinalIgnoreCase);
199192
}
200193

194+
private static bool TryStandardizeRemoteAuthenticateQuery(IQueryCollection query, out QueryString queryString)
195+
{
196+
if (!query.TryGetValue(OauthState, out var actualState))
197+
{
198+
queryString = default;
199+
return false;
200+
}
201+
202+
// Before: mydomain/signin-weixin?code=xxx&state=_oauthstate&_oauthstate=<actual state>&...
203+
// After: mydomain/signin-weixin?code=xxx&state=<actual state>&...
204+
var queryParams = new List<KeyValuePair<string, StringValues>>(query.Count - 1);
205+
foreach (var item in query)
206+
{
207+
switch (item.Key)
208+
{
209+
case OauthState: // No need in fact, skip it
210+
break;
211+
212+
case State:
213+
queryParams.Add(new(State, actualState));
214+
break;
215+
216+
default:
217+
queryParams.Add(item);
218+
break;
219+
}
220+
}
221+
222+
queryString = QueryString.Create(queryParams);
223+
return true;
224+
}
225+
201226
private static partial class Log
202227
{
203228
internal static async Task ExchangeCodeErrorAsync(ILogger logger, HttpResponseMessage response, CancellationToken cancellationToken)

0 commit comments

Comments
 (0)