Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public static void RegisterInContainer<TDbContext>(
IServiceCollection services,
ResolveDbContext<TDbContext>? resolveDbContext = null,
IModel? model = null,
ResolveFilters? resolveFilters = null,
ResolveFilters<TDbContext>? resolveFilters = null,
bool disableTracking = false,
bool disableAsync = false)
```
Expand Down Expand Up @@ -92,9 +92,10 @@ A delegate that resolves the [Filters](filters.md).
```cs
namespace GraphQL.EntityFramework;

public delegate Filters? ResolveFilters(object userContext);
public delegate Filters<TDbContext>? ResolveFilters<TDbContext>(object userContext)
where TDbContext : DbContext;
```
<sup><a href='/src/GraphQL.EntityFramework/Filters/ResolveFilters.cs#L1-L3' title='Snippet source file'>snippet source</a> | <a href='#snippet-ResolveFilters.cs' title='Start of snippet'>anchor</a></sup>
<sup><a href='/src/GraphQL.EntityFramework/Filters/ResolveFilters.cs#L1-L4' title='Snippet source file'>snippet source</a> | <a href='#snippet-ResolveFilters.cs' title='Start of snippet'>anchor</a></sup>
<!-- endSnippet -->

It has access to the current GraphQL user context.
Expand All @@ -121,7 +122,7 @@ public static void RegisterInContainer<TDbContext>(
IServiceCollection services,
ResolveDbContext<TDbContext>? resolveDbContext = null,
IModel? model = null,
ResolveFilters? resolveFilters = null,
ResolveFilters<TDbContext>? resolveFilters = null,
bool disableTracking = false,
bool disableAsync = false)
```
Expand Down
13 changes: 7 additions & 6 deletions docs/filters.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,16 @@ Notes:
<!-- snippet: FiltersSignature -->
<a id='snippet-FiltersSignature'></a>
```cs
public class Filters
public class Filters<TDbContext>
where TDbContext : DbContext
{
public delegate bool Filter<in TEntity>(object userContext, ClaimsPrincipal? userPrincipal, TEntity input)
public delegate bool Filter<in TEntity>(object userContext, TDbContext data, ClaimsPrincipal? userPrincipal, TEntity input)
where TEntity : class;

public delegate Task<bool> AsyncFilter<in TEntity>(object userContext, ClaimsPrincipal? userPrincipal, TEntity input)
public delegate Task<bool> AsyncFilter<in TEntity>(object userContext, TDbContext data, ClaimsPrincipal? userPrincipal, TEntity input)
where TEntity : class;
```
<sup><a href='/src/GraphQL.EntityFramework/Filters/Filters.cs#L3-L13' title='Snippet source file'>snippet source</a> | <a href='#snippet-FiltersSignature' title='Start of snippet'>anchor</a></sup>
<sup><a href='/src/GraphQL.EntityFramework/Filters/Filters.cs#L3-L14' title='Snippet source file'>snippet source</a> | <a href='#snippet-FiltersSignature' title='Start of snippet'>anchor</a></sup>
<!-- endSnippet -->


Expand All @@ -51,9 +52,9 @@ public class MyEntity
<sup><a href='/src/Snippets/GlobalFilterSnippets.cs#L5-L12' title='Snippet source file'>snippet source</a> | <a href='#snippet-add-filter' title='Start of snippet'>anchor</a></sup>
<a id='snippet-add-filter-1'></a>
```cs
var filters = new Filters();
var filters = new Filters<MyDbContext>();
filters.Add<MyEntity>(
(userContext, userPrincipal, item) => item.Property != "Ignore");
(userContext, date, userPrincipal, item) => item.Property != "Ignore");
EfGraphQLConventions.RegisterInContainer<MyDbContext>(
services,
resolveFilters: _ => filters);
Expand Down
53 changes: 33 additions & 20 deletions src/GraphQL.EntityFramework/ConnectionConverter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,31 +68,35 @@ static Connection<T> Range<T>(
return Build(skip, take, count, page);
}

public static Task<Connection<TItem>> ApplyConnectionContext<TSource, TItem>(
public static Task<Connection<TItem>> ApplyConnectionContext<TDbContext, TSource, TItem>(
this IQueryable<TItem> queryable,
int? first,
string afterString,
int? last,
string beforeString,
IResolveFieldContext<TSource> context,
Cancel cancel,
Filters filters)
Filters<TDbContext>? filters,
TDbContext data)
where TItem : class
where TDbContext : DbContext
{
Parse(afterString, beforeString, out var after, out var before);
return ApplyConnectionContext(queryable, first, after, last, before, context, filters, cancel);
return ApplyConnectionContext(queryable, first, after, last, before, context, filters, cancel, data);
}

public static async Task<Connection<TItem>> ApplyConnectionContext<TSource, TItem>(
public static async Task<Connection<TItem>> ApplyConnectionContext<TDbContext, TSource, TItem>(
IQueryable<TItem> queryable,
int? first,
int? after,
int? last,
int? before,
IResolveFieldContext<TSource> context,
Filters filters,
Cancel cancel = default)
Filters<TDbContext>? filters,
Cancel cancel,
TDbContext data)
where TItem : class
where TDbContext : DbContext
{
if (queryable is not IOrderedQueryable<TItem>)
{
Expand All @@ -102,22 +106,24 @@ public static async Task<Connection<TItem>> ApplyConnectionContext<TSource, TIte
cancel.ThrowIfCancellationRequested();
if (last is null)
{
return await First(queryable, first.GetValueOrDefault(0), after, before, count, context, filters, cancel);
return await First(queryable, first.GetValueOrDefault(0), after, before, count, context, filters, cancel, data);
}

return await Last(queryable, last.Value, after, before, count, context, filters, cancel);
return await Last(queryable, last.Value, after, before, count, context, filters, cancel, data);
}

static Task<Connection<TItem>> First<TSource, TItem>(
static Task<Connection<TItem>> First<TDbContext, TSource, TItem>(
IQueryable<TItem> queryable,
int first,
int? after,
int? before,
int count,
IResolveFieldContext<TSource> context,
Filters filters,
Cancel cancel)
Filters<TDbContext>? filters,
Cancel cancel,
TDbContext data)
where TItem : class
where TDbContext : DbContext
{
int skip;
if (before is null)
Expand All @@ -129,19 +135,21 @@ static Task<Connection<TItem>> First<TSource, TItem>(
skip = Math.Max(before.Value - first, 0);
}

return Range(queryable, skip, first, count, context, filters, cancel);
return Range(queryable, skip, first, count, context, filters, cancel, data);
}

static Task<Connection<TItem>> Last<TSource, TItem>(
static Task<Connection<TItem>> Last<TDbContext, TSource, TItem>(
IQueryable<TItem> queryable,
int last,
int? after,
int? before,
int count,
IResolveFieldContext<TSource> context,
Filters filters,
Cancel cancel)
Filters<TDbContext>? filters,
Cancel cancel,
TDbContext data)
where TItem : class
where TDbContext : DbContext
{
int skip;
if (after is null)
Expand All @@ -155,23 +163,28 @@ static Task<Connection<TItem>> Last<TSource, TItem>(
skip = after.Value + 1;
}

return Range(queryable, skip, take: last, count, context, filters, cancel);
return Range(queryable, skip, take: last, count, context, filters, cancel, data);
}

static async Task<Connection<TItem>> Range<TSource, TItem>(
static async Task<Connection<TItem>> Range<TDbContext, TSource, TItem>(
IQueryable<TItem> queryable,
int skip,
int take,
int count,
IResolveFieldContext<TSource> context,
Filters filters,
Cancel cancel)
Filters<TDbContext>? filters,
Cancel cancel,
TDbContext data)
where TItem : class
where TDbContext : DbContext
{
var page = queryable.Skip(skip).Take(take);
QueryLogger.Write(page);
IEnumerable<TItem> result = await page.ToListAsync(cancel);
result = await filters.ApplyFilter(result, context.UserContext, context.User);
if (filters != null)
{
result = await filters.ApplyFilter(result, context.UserContext, data, context.User);
}

cancel.ThrowIfCancellationRequested();
return Build(skip, take, count, result);
Expand Down
6 changes: 3 additions & 3 deletions src/GraphQL.EntityFramework/EfGraphQLConventions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public static void RegisterInContainer<TDbContext>(
IServiceCollection services,
ResolveDbContext<TDbContext>? resolveDbContext = null,
IModel? model = null,
ResolveFilters? resolveFilters = null,
ResolveFilters<TDbContext>? resolveFilters = null,
bool disableTracking = false,
bool disableAsync = false)

Expand All @@ -37,14 +37,14 @@ public static void RegisterInContainer<TDbContext>(
static EfGraphQLService<TDbContext> Build<TDbContext>(
ResolveDbContext<TDbContext>? dbContextResolver,
IModel? model,
ResolveFilters? filters,
ResolveFilters<TDbContext>? filters,
IServiceProvider provider,
bool disableTracking,
bool disableAsync)
where TDbContext : DbContext
{
model ??= ResolveModel<TDbContext>(provider);
filters ??= provider.GetService<ResolveFilters>();
filters ??= provider.GetService<ResolveFilters<TDbContext>>();
dbContextResolver ??= _ => DbContextFromProvider<TDbContext>(provider);

return new(
Expand Down
31 changes: 16 additions & 15 deletions src/GraphQL.EntityFramework/Filters/Filters.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,25 @@

#region FiltersSignature

public class Filters
public class Filters<TDbContext>
where TDbContext : DbContext
{
public delegate bool Filter<in TEntity>(object userContext, ClaimsPrincipal? userPrincipal, TEntity input)
public delegate bool Filter<in TEntity>(object userContext, TDbContext data, ClaimsPrincipal? userPrincipal, TEntity input)
where TEntity : class;

public delegate Task<bool> AsyncFilter<in TEntity>(object userContext, ClaimsPrincipal? userPrincipal, TEntity input)
public delegate Task<bool> AsyncFilter<in TEntity>(object userContext, TDbContext data, ClaimsPrincipal? userPrincipal, TEntity input)
where TEntity : class;

#endregion

public void Add<TEntity>(Filter<TEntity> filter)
where TEntity : class =>
funcs[typeof(TEntity)] =
(userContext, userPrincipal, item) =>
(userContext, data, userPrincipal, item) =>
{
try
{
return Task.FromResult(filter(userContext, userPrincipal, (TEntity) item));
return Task.FromResult(filter(userContext, data, userPrincipal, (TEntity) item));
}
catch (Exception exception)
{
Expand All @@ -30,23 +31,23 @@ public void Add<TEntity>(Filter<TEntity> filter)
public void Add<TEntity>(AsyncFilter<TEntity> filter)
where TEntity : class =>
funcs[typeof(TEntity)] =
async (userContext, userPrincipal, item) =>
async (userContext, data, userPrincipal, item) =>
{
try
{
return await filter(userContext, userPrincipal, (TEntity) item);
return await filter(userContext, data, userPrincipal, (TEntity) item);
}
catch (Exception exception)
{
throw new($"Failed to execute filter. {nameof(TEntity)}: {typeof(TEntity)}.", exception);
}
};

delegate Task<bool> Filter(object userContext, ClaimsPrincipal? userPrincipal, object input);
delegate Task<bool> Filter(object userContext, TDbContext data, ClaimsPrincipal? userPrincipal, object input);

Dictionary<Type, Filter> funcs = [];

internal virtual async Task<IEnumerable<TEntity>> ApplyFilter<TEntity>(IEnumerable<TEntity> result, object userContext, ClaimsPrincipal? userPrincipal)
internal virtual async Task<IEnumerable<TEntity>> ApplyFilter<TEntity>(IEnumerable<TEntity> result, object userContext, TDbContext data, ClaimsPrincipal? userPrincipal)
where TEntity : class
{
if (funcs.Count == 0)
Expand All @@ -63,7 +64,7 @@ internal virtual async Task<IEnumerable<TEntity>> ApplyFilter<TEntity>(IEnumerab
var list = new List<TEntity>();
foreach (var item in result)
{
if (await ShouldInclude(userContext, userPrincipal, item, filters))
if (await ShouldInclude(userContext, data, userPrincipal, item, filters))
{
list.Add(item);
}
Expand All @@ -72,12 +73,12 @@ internal virtual async Task<IEnumerable<TEntity>> ApplyFilter<TEntity>(IEnumerab
return list;
}

static async Task<bool> ShouldInclude<TEntity>(object userContext, ClaimsPrincipal? userPrincipal, TEntity item, List<AsyncFilter<TEntity>> filters)
static async Task<bool> ShouldInclude<TEntity>(object userContext, TDbContext data, ClaimsPrincipal? userPrincipal, TEntity item, List<AsyncFilter<TEntity>> filters)
where TEntity : class
{
foreach (var func in filters)
{
if (!await func(userContext, userPrincipal, item))
if (!await func(userContext, data, userPrincipal, item))
{
return false;
}
Expand All @@ -86,7 +87,7 @@ static async Task<bool> ShouldInclude<TEntity>(object userContext, ClaimsPrincip
return true;
}

internal virtual async Task<bool> ShouldInclude<TEntity>(object userContext, ClaimsPrincipal? userPrincipal, TEntity? item)
internal virtual async Task<bool> ShouldInclude<TEntity>(object userContext, TDbContext data, ClaimsPrincipal? userPrincipal, TEntity? item)
where TEntity : class
{
if (item is null)
Expand All @@ -101,7 +102,7 @@ internal virtual async Task<bool> ShouldInclude<TEntity>(object userContext, Cla

foreach (var func in FindFilters<TEntity>())
{
if (!await func(userContext, userPrincipal, item))
if (!await func(userContext, data, userPrincipal, item))
{
return false;
}
Expand All @@ -116,7 +117,7 @@ IEnumerable<AsyncFilter<TEntity>> FindFilters<TEntity>()
var type = typeof(TEntity);
foreach (var pair in funcs.Where(_ => _.Key.IsAssignableFrom(type)))
{
yield return (context, user, item) => pair.Value(context, user, item);
yield return (context, data, user, item) => pair.Value(context, data, user, item);
}
}
}
14 changes: 0 additions & 14 deletions src/GraphQL.EntityFramework/Filters/NullFilters.cs

This file was deleted.

3 changes: 2 additions & 1 deletion src/GraphQL.EntityFramework/Filters/ResolveFilters.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
namespace GraphQL.EntityFramework;

public delegate Filters? ResolveFilters(object userContext);
public delegate Filters<TDbContext>? ResolveFilters<TDbContext>(object userContext)
where TDbContext : DbContext;
11 changes: 4 additions & 7 deletions src/GraphQL.EntityFramework/GraphApi/EfGraphQLService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ public partial class EfGraphQLService<TDbContext> :
IEfGraphQLService<TDbContext>
where TDbContext : DbContext
{
ResolveFilters? resolveFilters;
ResolveFilters<TDbContext>? resolveFilters;
bool disableTracking;
bool disableAsync;
ResolveDbContext<TDbContext> resolveDbContext;
Expand All @@ -14,7 +14,7 @@ public partial class EfGraphQLService<TDbContext> :
public EfGraphQLService(
IModel model,
ResolveDbContext<TDbContext> resolveDbContext,
ResolveFilters? resolveFilters = null,
ResolveFilters<TDbContext>? resolveFilters = null,
bool disableTracking = false,
bool disableAsync = false)
{
Expand Down Expand Up @@ -68,11 +68,8 @@ ResolveEfFieldContext<TDbContext, TSource> BuildContext<TSource>(
public TDbContext ResolveDbContext(IResolveFieldContext context) =>
resolveDbContext(context.UserContext);

Filters ResolveFilter<TSource>(IResolveFieldContext<TSource> context)
{
var filter = resolveFilters?.Invoke(context.UserContext);
return filter ?? NullFilters.Instance;
}
Filters<TDbContext>? ResolveFilter<TSource>(IResolveFieldContext<TSource> context) =>
resolveFilters?.Invoke(context.UserContext);

public IQueryable<TItem> AddIncludes<TItem>(IQueryable<TItem> query, IResolveFieldContext context)
where TItem : class =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,8 @@ FieldType BuildFirstField<TSource, TReturn>(

if (first is not null)
{
if (await efFieldContext.Filters.ShouldInclude(context.UserContext, context.User, first))
if (efFieldContext.Filters == null ||
await efFieldContext.Filters.ShouldInclude(context.UserContext, efFieldContext.DbContext, context.User, first))
{
if (mutate is not null)
{
Expand Down
Loading