Skip to content

Commit 652065f

Browse files
authored
Pass DbContext to Filters (#1089)
1 parent 9c3c0fb commit 652065f

22 files changed

+136
-123
lines changed

docs/configuration.md

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ public static void RegisterInContainer<TDbContext>(
2121
IServiceCollection services,
2222
ResolveDbContext<TDbContext>? resolveDbContext = null,
2323
IModel? model = null,
24-
ResolveFilters? resolveFilters = null,
24+
ResolveFilters<TDbContext>? resolveFilters = null,
2525
bool disableTracking = false,
2626
bool disableAsync = false)
2727
```
@@ -92,9 +92,10 @@ A delegate that resolves the [Filters](filters.md).
9292
```cs
9393
namespace GraphQL.EntityFramework;
9494

95-
public delegate Filters? ResolveFilters(object userContext);
95+
public delegate Filters<TDbContext>? ResolveFilters<TDbContext>(object userContext)
96+
where TDbContext : DbContext;
9697
```
97-
<sup><a href='/src/GraphQL.EntityFramework/Filters/ResolveFilters.cs#L1-L3' title='Snippet source file'>snippet source</a> | <a href='#snippet-ResolveFilters.cs' title='Start of snippet'>anchor</a></sup>
98+
<sup><a href='/src/GraphQL.EntityFramework/Filters/ResolveFilters.cs#L1-L4' title='Snippet source file'>snippet source</a> | <a href='#snippet-ResolveFilters.cs' title='Start of snippet'>anchor</a></sup>
9899
<!-- endSnippet -->
99100

100101
It has access to the current GraphQL user context.
@@ -121,7 +122,7 @@ public static void RegisterInContainer<TDbContext>(
121122
IServiceCollection services,
122123
ResolveDbContext<TDbContext>? resolveDbContext = null,
123124
IModel? model = null,
124-
ResolveFilters? resolveFilters = null,
125+
ResolveFilters<TDbContext>? resolveFilters = null,
125126
bool disableTracking = false,
126127
bool disableAsync = false)
127128
```

docs/filters.md

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,16 @@ Notes:
2626
<!-- snippet: FiltersSignature -->
2727
<a id='snippet-FiltersSignature'></a>
2828
```cs
29-
public class Filters
29+
public class Filters<TDbContext>
30+
where TDbContext : DbContext
3031
{
31-
public delegate bool Filter<in TEntity>(object userContext, ClaimsPrincipal? userPrincipal, TEntity input)
32+
public delegate bool Filter<in TEntity>(object userContext, TDbContext data, ClaimsPrincipal? userPrincipal, TEntity input)
3233
where TEntity : class;
3334

34-
public delegate Task<bool> AsyncFilter<in TEntity>(object userContext, ClaimsPrincipal? userPrincipal, TEntity input)
35+
public delegate Task<bool> AsyncFilter<in TEntity>(object userContext, TDbContext data, ClaimsPrincipal? userPrincipal, TEntity input)
3536
where TEntity : class;
3637
```
37-
<sup><a href='/src/GraphQL.EntityFramework/Filters/Filters.cs#L3-L13' title='Snippet source file'>snippet source</a> | <a href='#snippet-FiltersSignature' title='Start of snippet'>anchor</a></sup>
38+
<sup><a href='/src/GraphQL.EntityFramework/Filters/Filters.cs#L3-L14' title='Snippet source file'>snippet source</a> | <a href='#snippet-FiltersSignature' title='Start of snippet'>anchor</a></sup>
3839
<!-- endSnippet -->
3940

4041

@@ -51,9 +52,9 @@ public class MyEntity
5152
<sup><a href='/src/Snippets/GlobalFilterSnippets.cs#L5-L12' title='Snippet source file'>snippet source</a> | <a href='#snippet-add-filter' title='Start of snippet'>anchor</a></sup>
5253
<a id='snippet-add-filter-1'></a>
5354
```cs
54-
var filters = new Filters();
55+
var filters = new Filters<MyDbContext>();
5556
filters.Add<MyEntity>(
56-
(userContext, userPrincipal, item) => item.Property != "Ignore");
57+
(userContext, date, userPrincipal, item) => item.Property != "Ignore");
5758
EfGraphQLConventions.RegisterInContainer<MyDbContext>(
5859
services,
5960
resolveFilters: _ => filters);

src/GraphQL.EntityFramework/ConnectionConverter.cs

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -68,31 +68,35 @@ static Connection<T> Range<T>(
6868
return Build(skip, take, count, page);
6969
}
7070

71-
public static Task<Connection<TItem>> ApplyConnectionContext<TSource, TItem>(
71+
public static Task<Connection<TItem>> ApplyConnectionContext<TDbContext, TSource, TItem>(
7272
this IQueryable<TItem> queryable,
7373
int? first,
7474
string afterString,
7575
int? last,
7676
string beforeString,
7777
IResolveFieldContext<TSource> context,
7878
Cancel cancel,
79-
Filters filters)
79+
Filters<TDbContext>? filters,
80+
TDbContext data)
8081
where TItem : class
82+
where TDbContext : DbContext
8183
{
8284
Parse(afterString, beforeString, out var after, out var before);
83-
return ApplyConnectionContext(queryable, first, after, last, before, context, filters, cancel);
85+
return ApplyConnectionContext(queryable, first, after, last, before, context, filters, cancel, data);
8486
}
8587

86-
public static async Task<Connection<TItem>> ApplyConnectionContext<TSource, TItem>(
88+
public static async Task<Connection<TItem>> ApplyConnectionContext<TDbContext, TSource, TItem>(
8789
IQueryable<TItem> queryable,
8890
int? first,
8991
int? after,
9092
int? last,
9193
int? before,
9294
IResolveFieldContext<TSource> context,
93-
Filters filters,
94-
Cancel cancel = default)
95+
Filters<TDbContext>? filters,
96+
Cancel cancel,
97+
TDbContext data)
9598
where TItem : class
99+
where TDbContext : DbContext
96100
{
97101
if (queryable is not IOrderedQueryable<TItem>)
98102
{
@@ -102,22 +106,24 @@ public static async Task<Connection<TItem>> ApplyConnectionContext<TSource, TIte
102106
cancel.ThrowIfCancellationRequested();
103107
if (last is null)
104108
{
105-
return await First(queryable, first.GetValueOrDefault(0), after, before, count, context, filters, cancel);
109+
return await First(queryable, first.GetValueOrDefault(0), after, before, count, context, filters, cancel, data);
106110
}
107111

108-
return await Last(queryable, last.Value, after, before, count, context, filters, cancel);
112+
return await Last(queryable, last.Value, after, before, count, context, filters, cancel, data);
109113
}
110114

111-
static Task<Connection<TItem>> First<TSource, TItem>(
115+
static Task<Connection<TItem>> First<TDbContext, TSource, TItem>(
112116
IQueryable<TItem> queryable,
113117
int first,
114118
int? after,
115119
int? before,
116120
int count,
117121
IResolveFieldContext<TSource> context,
118-
Filters filters,
119-
Cancel cancel)
122+
Filters<TDbContext>? filters,
123+
Cancel cancel,
124+
TDbContext data)
120125
where TItem : class
126+
where TDbContext : DbContext
121127
{
122128
int skip;
123129
if (before is null)
@@ -129,19 +135,21 @@ static Task<Connection<TItem>> First<TSource, TItem>(
129135
skip = Math.Max(before.Value - first, 0);
130136
}
131137

132-
return Range(queryable, skip, first, count, context, filters, cancel);
138+
return Range(queryable, skip, first, count, context, filters, cancel, data);
133139
}
134140

135-
static Task<Connection<TItem>> Last<TSource, TItem>(
141+
static Task<Connection<TItem>> Last<TDbContext, TSource, TItem>(
136142
IQueryable<TItem> queryable,
137143
int last,
138144
int? after,
139145
int? before,
140146
int count,
141147
IResolveFieldContext<TSource> context,
142-
Filters filters,
143-
Cancel cancel)
148+
Filters<TDbContext>? filters,
149+
Cancel cancel,
150+
TDbContext data)
144151
where TItem : class
152+
where TDbContext : DbContext
145153
{
146154
int skip;
147155
if (after is null)
@@ -155,23 +163,28 @@ static Task<Connection<TItem>> Last<TSource, TItem>(
155163
skip = after.Value + 1;
156164
}
157165

158-
return Range(queryable, skip, take: last, count, context, filters, cancel);
166+
return Range(queryable, skip, take: last, count, context, filters, cancel, data);
159167
}
160168

161-
static async Task<Connection<TItem>> Range<TSource, TItem>(
169+
static async Task<Connection<TItem>> Range<TDbContext, TSource, TItem>(
162170
IQueryable<TItem> queryable,
163171
int skip,
164172
int take,
165173
int count,
166174
IResolveFieldContext<TSource> context,
167-
Filters filters,
168-
Cancel cancel)
175+
Filters<TDbContext>? filters,
176+
Cancel cancel,
177+
TDbContext data)
169178
where TItem : class
179+
where TDbContext : DbContext
170180
{
171181
var page = queryable.Skip(skip).Take(take);
172182
QueryLogger.Write(page);
173183
IEnumerable<TItem> result = await page.ToListAsync(cancel);
174-
result = await filters.ApplyFilter(result, context.UserContext, context.User);
184+
if (filters != null)
185+
{
186+
result = await filters.ApplyFilter(result, context.UserContext, data, context.User);
187+
}
175188

176189
cancel.ThrowIfCancellationRequested();
177190
return Build(skip, take, count, result);

src/GraphQL.EntityFramework/EfGraphQLConventions.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ public static void RegisterInContainer<TDbContext>(
1717
IServiceCollection services,
1818
ResolveDbContext<TDbContext>? resolveDbContext = null,
1919
IModel? model = null,
20-
ResolveFilters? resolveFilters = null,
20+
ResolveFilters<TDbContext>? resolveFilters = null,
2121
bool disableTracking = false,
2222
bool disableAsync = false)
2323

@@ -37,14 +37,14 @@ public static void RegisterInContainer<TDbContext>(
3737
static EfGraphQLService<TDbContext> Build<TDbContext>(
3838
ResolveDbContext<TDbContext>? dbContextResolver,
3939
IModel? model,
40-
ResolveFilters? filters,
40+
ResolveFilters<TDbContext>? filters,
4141
IServiceProvider provider,
4242
bool disableTracking,
4343
bool disableAsync)
4444
where TDbContext : DbContext
4545
{
4646
model ??= ResolveModel<TDbContext>(provider);
47-
filters ??= provider.GetService<ResolveFilters>();
47+
filters ??= provider.GetService<ResolveFilters<TDbContext>>();
4848
dbContextResolver ??= _ => DbContextFromProvider<TDbContext>(provider);
4949

5050
return new(

src/GraphQL.EntityFramework/Filters/Filters.cs

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,25 @@
22

33
#region FiltersSignature
44

5-
public class Filters
5+
public class Filters<TDbContext>
6+
where TDbContext : DbContext
67
{
7-
public delegate bool Filter<in TEntity>(object userContext, ClaimsPrincipal? userPrincipal, TEntity input)
8+
public delegate bool Filter<in TEntity>(object userContext, TDbContext data, ClaimsPrincipal? userPrincipal, TEntity input)
89
where TEntity : class;
910

10-
public delegate Task<bool> AsyncFilter<in TEntity>(object userContext, ClaimsPrincipal? userPrincipal, TEntity input)
11+
public delegate Task<bool> AsyncFilter<in TEntity>(object userContext, TDbContext data, ClaimsPrincipal? userPrincipal, TEntity input)
1112
where TEntity : class;
1213

1314
#endregion
1415

1516
public void Add<TEntity>(Filter<TEntity> filter)
1617
where TEntity : class =>
1718
funcs[typeof(TEntity)] =
18-
(userContext, userPrincipal, item) =>
19+
(userContext, data, userPrincipal, item) =>
1920
{
2021
try
2122
{
22-
return Task.FromResult(filter(userContext, userPrincipal, (TEntity) item));
23+
return Task.FromResult(filter(userContext, data, userPrincipal, (TEntity) item));
2324
}
2425
catch (Exception exception)
2526
{
@@ -30,23 +31,23 @@ public void Add<TEntity>(Filter<TEntity> filter)
3031
public void Add<TEntity>(AsyncFilter<TEntity> filter)
3132
where TEntity : class =>
3233
funcs[typeof(TEntity)] =
33-
async (userContext, userPrincipal, item) =>
34+
async (userContext, data, userPrincipal, item) =>
3435
{
3536
try
3637
{
37-
return await filter(userContext, userPrincipal, (TEntity) item);
38+
return await filter(userContext, data, userPrincipal, (TEntity) item);
3839
}
3940
catch (Exception exception)
4041
{
4142
throw new($"Failed to execute filter. {nameof(TEntity)}: {typeof(TEntity)}.", exception);
4243
}
4344
};
4445

45-
delegate Task<bool> Filter(object userContext, ClaimsPrincipal? userPrincipal, object input);
46+
delegate Task<bool> Filter(object userContext, TDbContext data, ClaimsPrincipal? userPrincipal, object input);
4647

4748
Dictionary<Type, Filter> funcs = [];
4849

49-
internal virtual async Task<IEnumerable<TEntity>> ApplyFilter<TEntity>(IEnumerable<TEntity> result, object userContext, ClaimsPrincipal? userPrincipal)
50+
internal virtual async Task<IEnumerable<TEntity>> ApplyFilter<TEntity>(IEnumerable<TEntity> result, object userContext, TDbContext data, ClaimsPrincipal? userPrincipal)
5051
where TEntity : class
5152
{
5253
if (funcs.Count == 0)
@@ -63,7 +64,7 @@ internal virtual async Task<IEnumerable<TEntity>> ApplyFilter<TEntity>(IEnumerab
6364
var list = new List<TEntity>();
6465
foreach (var item in result)
6566
{
66-
if (await ShouldInclude(userContext, userPrincipal, item, filters))
67+
if (await ShouldInclude(userContext, data, userPrincipal, item, filters))
6768
{
6869
list.Add(item);
6970
}
@@ -72,12 +73,12 @@ internal virtual async Task<IEnumerable<TEntity>> ApplyFilter<TEntity>(IEnumerab
7273
return list;
7374
}
7475

75-
static async Task<bool> ShouldInclude<TEntity>(object userContext, ClaimsPrincipal? userPrincipal, TEntity item, List<AsyncFilter<TEntity>> filters)
76+
static async Task<bool> ShouldInclude<TEntity>(object userContext, TDbContext data, ClaimsPrincipal? userPrincipal, TEntity item, List<AsyncFilter<TEntity>> filters)
7677
where TEntity : class
7778
{
7879
foreach (var func in filters)
7980
{
80-
if (!await func(userContext, userPrincipal, item))
81+
if (!await func(userContext, data, userPrincipal, item))
8182
{
8283
return false;
8384
}
@@ -86,7 +87,7 @@ static async Task<bool> ShouldInclude<TEntity>(object userContext, ClaimsPrincip
8687
return true;
8788
}
8889

89-
internal virtual async Task<bool> ShouldInclude<TEntity>(object userContext, ClaimsPrincipal? userPrincipal, TEntity? item)
90+
internal virtual async Task<bool> ShouldInclude<TEntity>(object userContext, TDbContext data, ClaimsPrincipal? userPrincipal, TEntity? item)
9091
where TEntity : class
9192
{
9293
if (item is null)
@@ -101,7 +102,7 @@ internal virtual async Task<bool> ShouldInclude<TEntity>(object userContext, Cla
101102

102103
foreach (var func in FindFilters<TEntity>())
103104
{
104-
if (!await func(userContext, userPrincipal, item))
105+
if (!await func(userContext, data, userPrincipal, item))
105106
{
106107
return false;
107108
}
@@ -116,7 +117,7 @@ IEnumerable<AsyncFilter<TEntity>> FindFilters<TEntity>()
116117
var type = typeof(TEntity);
117118
foreach (var pair in funcs.Where(_ => _.Key.IsAssignableFrom(type)))
118119
{
119-
yield return (context, user, item) => pair.Value(context, user, item);
120+
yield return (context, data, user, item) => pair.Value(context, data, user, item);
120121
}
121122
}
122123
}

src/GraphQL.EntityFramework/Filters/NullFilters.cs

Lines changed: 0 additions & 14 deletions
This file was deleted.
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
namespace GraphQL.EntityFramework;
22

3-
public delegate Filters? ResolveFilters(object userContext);
3+
public delegate Filters<TDbContext>? ResolveFilters<TDbContext>(object userContext)
4+
where TDbContext : DbContext;

src/GraphQL.EntityFramework/GraphApi/EfGraphQLService.cs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ public partial class EfGraphQLService<TDbContext> :
44
IEfGraphQLService<TDbContext>
55
where TDbContext : DbContext
66
{
7-
ResolveFilters? resolveFilters;
7+
ResolveFilters<TDbContext>? resolveFilters;
88
bool disableTracking;
99
bool disableAsync;
1010
ResolveDbContext<TDbContext> resolveDbContext;
@@ -14,7 +14,7 @@ public partial class EfGraphQLService<TDbContext> :
1414
public EfGraphQLService(
1515
IModel model,
1616
ResolveDbContext<TDbContext> resolveDbContext,
17-
ResolveFilters? resolveFilters = null,
17+
ResolveFilters<TDbContext>? resolveFilters = null,
1818
bool disableTracking = false,
1919
bool disableAsync = false)
2020
{
@@ -68,11 +68,8 @@ ResolveEfFieldContext<TDbContext, TSource> BuildContext<TSource>(
6868
public TDbContext ResolveDbContext(IResolveFieldContext context) =>
6969
resolveDbContext(context.UserContext);
7070

71-
Filters ResolveFilter<TSource>(IResolveFieldContext<TSource> context)
72-
{
73-
var filter = resolveFilters?.Invoke(context.UserContext);
74-
return filter ?? NullFilters.Instance;
75-
}
71+
Filters<TDbContext>? ResolveFilter<TSource>(IResolveFieldContext<TSource> context) =>
72+
resolveFilters?.Invoke(context.UserContext);
7673

7774
public IQueryable<TItem> AddIncludes<TItem>(IQueryable<TItem> query, IResolveFieldContext context)
7875
where TItem : class =>

src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_First.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,8 @@ FieldType BuildFirstField<TSource, TReturn>(
195195

196196
if (first is not null)
197197
{
198-
if (await efFieldContext.Filters.ShouldInclude(context.UserContext, context.User, first))
198+
if (efFieldContext.Filters == null ||
199+
await efFieldContext.Filters.ShouldInclude(context.UserContext, efFieldContext.DbContext, context.User, first))
199200
{
200201
if (mutate is not null)
201202
{

0 commit comments

Comments
 (0)