Skip to content

Commit 5094301

Browse files
committed
Merge branch 'master' of github.com:graphql-dotnet/authorization into develop
# Conflicts: # src/Directory.Build.props # src/GraphQL.Authorization/AuthorizationValidationRule.cs
2 parents 02ed319 + d0b411f commit 5094301

File tree

5 files changed

+103
-20
lines changed

5 files changed

+103
-20
lines changed

src/BasicSample/Program.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,14 @@ type Query {
4444
// remove claims to see the failure
4545
var authorizedUser = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim("role", "Admin") }));
4646

47-
string json = await schema.ExecuteAsync(_ =>
47+
string json = await schema.ExecuteAsync(options =>
4848
{
49-
_.Query = "{ viewer { id name } }";
50-
_.ValidationRules = serviceProvider
49+
options.Query = "{ viewer { id name } }";
50+
options.ValidationRules = serviceProvider
5151
.GetServices<IValidationRule>()
5252
.Concat(DocumentValidator.CoreRules);
53-
_.RequestServices = serviceProvider;
54-
_.UserContext = new GraphQLUserContext { User = authorizedUser };
53+
options.RequestServices = serviceProvider;
54+
options.UserContext = new GraphQLUserContext { User = authorizedUser };
5555
});
5656

5757
Console.WriteLine(json);

src/GraphQL.Authorization.ApiTests/GraphQL.Authorization.approved.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ namespace GraphQL.Authorization
6363
public ClaimAuthorizationRequirement(string claimType, System.Collections.Generic.IEnumerable<string> allowedValues) { }
6464
public ClaimAuthorizationRequirement(string claimType, params string[] allowedValues) { }
6565
public ClaimAuthorizationRequirement(string claimType, System.Collections.Generic.IEnumerable<string> allowedValues, System.Collections.Generic.IEnumerable<string> displayValues) { }
66+
public System.Collections.Generic.IEnumerable<string> AllowedValues { get; }
67+
public string ClaimType { get; }
68+
public System.Collections.Generic.IEnumerable<string> DisplayValues { get; }
6669
public System.Threading.Tasks.Task Authorize(GraphQL.Authorization.AuthorizationContext context) { }
6770
}
6871
public interface IAuthorizationEvaluator

src/GraphQL.Authorization.Tests/AuthorizationValidationRuleTests.cs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,48 @@ public void Issue61()
223223
});
224224
}
225225

226+
[Fact]
227+
public void passes_with_claim_on_variable_type()
228+
{
229+
Settings.AddPolicy("FieldPolicy", builder => builder.RequireClaim("admin"));
230+
231+
ShouldPassRule(config =>
232+
{
233+
config.Query = @"query Author($input: AuthorInputType!) { author(input: $input) }";
234+
config.Schema = TypedSchema();
235+
config.Inputs = new Inputs(new Dictionary<string, object>()
236+
{
237+
{
238+
"input",
239+
new Dictionary<string,object>{ { "name","Quinn" } }
240+
}
241+
});
242+
config.User = CreatePrincipal(claims: new Dictionary<string, string>
243+
{
244+
{ "Admin", "true" }
245+
});
246+
});
247+
}
248+
249+
[Fact]
250+
public void fails_on_missing_claim_on_variable_type()
251+
{
252+
Settings.AddPolicy("FieldPolicy", builder => builder.RequireClaim("admin"));
253+
254+
ShouldFailRule(config =>
255+
{
256+
config.Query = @"query Author($input: AuthorInputType!) { author(input: $input) }";
257+
config.Schema = TypedSchema();
258+
config.Inputs = new Inputs(new Dictionary<string, object>()
259+
{
260+
{
261+
"input",
262+
new Dictionary<string,object>{ { "name","Quinn" } }
263+
}
264+
});
265+
});
266+
}
267+
226268
[Fact]
227269
public void passes_with_policy_on_connection_type()
228270
{

src/GraphQL.Authorization/AuthorizationValidationRule.cs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using System.Collections.Generic;
12
using System.Linq;
23
using System.Threading.Tasks;
34
using GraphQL.Language.AST;
@@ -121,6 +122,30 @@ public ValueTask<INodeVisitor> ValidateAsync(ValidationContext context)
121122
CheckAuth(fieldAst, fieldDef, userContext, context, operationType);
122123
// check returned graph type
123124
CheckAuth(fieldAst, fieldDef.ResolvedType.GetNamedType(), userContext, context, operationType);
125+
}),
126+
127+
new MatchingNodeVisitor<VariableReference>((variableRef, context) =>
128+
{
129+
if (!(context.TypeInfo.GetArgument().ResolvedType.GetNamedType() is IComplexGraphType variableType))
130+
return;
131+
132+
CheckAuth(variableRef, variableType, userContext, context, operationType);
133+
134+
// Check each supplied field in the variable that exists in the variable type.
135+
// If some supplied field does not exist in the variable type then some other
136+
// validation rule should check that but here we should just ignore that
137+
// "unknown" field.
138+
if (context.Inputs.TryGetValue(variableRef.Name, out object input) &&
139+
input is Dictionary<string, object> fieldsValues)
140+
{
141+
foreach (var field in variableType.Fields)
142+
{
143+
if (fieldsValues.ContainsKey(field.Name))
144+
{
145+
CheckAuth(variableRef, field, userContext, context, operationType);
146+
}
147+
}
148+
}
124149
})
125150
));
126151
}

src/GraphQL.Authorization/Requirements/ClaimAuthorizationRequirement.cs

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,6 @@ namespace GraphQL.Authorization
1111
/// </summary>
1212
public class ClaimAuthorizationRequirement : IAuthorizationRequirement
1313
{
14-
private readonly string _claimType;
15-
private readonly IEnumerable<string> _displayValues;
16-
private readonly IEnumerable<string> _allowedValues;
17-
1814
/// <summary>
1915
/// Creates a new instance of <see cref="ClaimAuthorizationRequirement"/> with
2016
/// the specified claim type.
@@ -53,41 +49,58 @@ public ClaimAuthorizationRequirement(string claimType, params string[] allowedVa
5349
/// </summary>
5450
public ClaimAuthorizationRequirement(string claimType, IEnumerable<string> allowedValues, IEnumerable<string> displayValues)
5551
{
56-
_claimType = claimType ?? throw new ArgumentNullException(nameof(claimType));
57-
_allowedValues = allowedValues ?? Enumerable.Empty<string>();
58-
_displayValues = displayValues;
52+
ClaimType = claimType ?? throw new ArgumentNullException(nameof(claimType));
53+
AllowedValues = allowedValues ?? Enumerable.Empty<string>();
54+
DisplayValues = displayValues;
5955
}
6056

57+
/// <summary>
58+
/// Claim type that claims principal from <see cref="AuthorizationContext"/> should have.
59+
/// </summary>
60+
public string ClaimType { get; }
61+
62+
/// <summary>
63+
/// List of claim values, which, if present, the claim must match.
64+
/// </summary>
65+
public IEnumerable<string> AllowedValues { get; }
66+
67+
/// <summary>
68+
/// Specifies the set of displayed claim values that will be used
69+
/// to generate an error message if the requirement is not met.
70+
/// If null then values from <see cref="AllowedValues"/> are used.
71+
/// </summary>
72+
public IEnumerable<string> DisplayValues { get; }
73+
6174
/// <inheritdoc />
6275
public Task Authorize(AuthorizationContext context)
6376
{
6477
bool found = false;
6578

6679
if (context.User != null)
6780
{
68-
if (_allowedValues == null || !_allowedValues.Any())
81+
if (AllowedValues == null || !AllowedValues.Any())
6982
{
7083
found = context.User.Claims.Any(
71-
claim => string.Equals(claim.Type, _claimType, StringComparison.OrdinalIgnoreCase));
84+
claim => string.Equals(claim.Type, ClaimType, StringComparison.OrdinalIgnoreCase));
7285
}
7386
else
7487
{
7588
found = context.User.Claims.Any(
76-
claim => string.Equals(claim.Type, _claimType, StringComparison.OrdinalIgnoreCase)
77-
&& _allowedValues.Contains(claim.Value, StringComparer.Ordinal));
89+
claim => string.Equals(claim.Type, ClaimType, StringComparison.OrdinalIgnoreCase)
90+
&& AllowedValues.Contains(claim.Value, StringComparer.Ordinal));
7891
}
7992
}
8093

8194
if (!found)
8295
{
83-
if (_allowedValues != null && _allowedValues.Any())
96+
if (AllowedValues != null && AllowedValues.Any())
8497
{
85-
string values = string.Join(", ", _displayValues ?? _allowedValues);
86-
context.ReportError($"Required claim '{_claimType}' with any value of '{values}' is not present.");
98+
string values = string.Join(", ", DisplayValues ?? AllowedValues);
99+
context.ReportError($"Required claim '{ClaimType}' with any value of '{values}' is not present.");
87100
}
88101
else
89102
{
90-
context.ReportError($"Required claim '{_claimType}' is not present.");
103+
context.ReportError($"Required claim '{ClaimType}' is not present.");
91104
}
92105
}
93106

0 commit comments

Comments
 (0)