diff --git a/docs/configuration.md b/docs/configuration.md index ef7847e60..b20865339 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -21,7 +21,7 @@ public static void RegisterInContainer( IServiceCollection services, ResolveDbContext? resolveDbContext = null, IModel? model = null, - ResolveFilters? resolveFilters = null, + ResolveFilters? resolveFilters = null, bool disableTracking = false, bool disableAsync = false) ``` @@ -92,9 +92,10 @@ A delegate that resolves the [Filters](filters.md). ```cs namespace GraphQL.EntityFramework; -public delegate Filters? ResolveFilters(object userContext); +public delegate Filters? ResolveFilters(object userContext) + where TDbContext : DbContext; ``` -snippet source | anchor +snippet source | anchor It has access to the current GraphQL user context. @@ -121,7 +122,7 @@ public static void RegisterInContainer( IServiceCollection services, ResolveDbContext? resolveDbContext = null, IModel? model = null, - ResolveFilters? resolveFilters = null, + ResolveFilters? resolveFilters = null, bool disableTracking = false, bool disableAsync = false) ``` diff --git a/docs/filters.md b/docs/filters.md index 21dffad3e..435e86902 100644 --- a/docs/filters.md +++ b/docs/filters.md @@ -26,15 +26,16 @@ Notes: ```cs -public class Filters +public class Filters + where TDbContext : DbContext { - public delegate bool Filter(object userContext, ClaimsPrincipal? userPrincipal, TEntity input) + public delegate bool Filter(object userContext, TDbContext data, ClaimsPrincipal? userPrincipal, TEntity input) where TEntity : class; - public delegate Task AsyncFilter(object userContext, ClaimsPrincipal? userPrincipal, TEntity input) + public delegate Task AsyncFilter(object userContext, TDbContext data, ClaimsPrincipal? userPrincipal, TEntity input) where TEntity : class; ``` -snippet source | anchor +snippet source | anchor @@ -51,9 +52,9 @@ public class MyEntity snippet source | anchor ```cs -var filters = new Filters(); +var filters = new Filters(); filters.Add( - (userContext, userPrincipal, item) => item.Property != "Ignore"); + (userContext, date, userPrincipal, item) => item.Property != "Ignore"); EfGraphQLConventions.RegisterInContainer( services, resolveFilters: _ => filters); diff --git a/src/GraphQL.EntityFramework/ConnectionConverter.cs b/src/GraphQL.EntityFramework/ConnectionConverter.cs index 9765acd6b..80302220a 100644 --- a/src/GraphQL.EntityFramework/ConnectionConverter.cs +++ b/src/GraphQL.EntityFramework/ConnectionConverter.cs @@ -68,7 +68,7 @@ static Connection Range( return Build(skip, take, count, page); } - public static Task> ApplyConnectionContext( + public static Task> ApplyConnectionContext( this IQueryable queryable, int? first, string afterString, @@ -76,23 +76,27 @@ public static Task> ApplyConnectionContext( string beforeString, IResolveFieldContext context, Cancel cancel, - Filters filters) + Filters? filters, + TDbContext data) where TItem : class + where TDbContext : DbContext { Parse(afterString, beforeString, out var after, out var before); - return ApplyConnectionContext(queryable, first, after, last, before, context, filters, cancel); + return ApplyConnectionContext(queryable, first, after, last, before, context, filters, cancel, data); } - public static async Task> ApplyConnectionContext( + public static async Task> ApplyConnectionContext( IQueryable queryable, int? first, int? after, int? last, int? before, IResolveFieldContext context, - Filters filters, - Cancel cancel = default) + Filters? filters, + Cancel cancel, + TDbContext data) where TItem : class + where TDbContext : DbContext { if (queryable is not IOrderedQueryable) { @@ -102,22 +106,24 @@ public static async Task> ApplyConnectionContext> First( + static Task> First( IQueryable queryable, int first, int? after, int? before, int count, IResolveFieldContext context, - Filters filters, - Cancel cancel) + Filters? filters, + Cancel cancel, + TDbContext data) where TItem : class + where TDbContext : DbContext { int skip; if (before is null) @@ -129,19 +135,21 @@ static Task> First( skip = Math.Max(before.Value - first, 0); } - return Range(queryable, skip, first, count, context, filters, cancel); + return Range(queryable, skip, first, count, context, filters, cancel, data); } - static Task> Last( + static Task> Last( IQueryable queryable, int last, int? after, int? before, int count, IResolveFieldContext context, - Filters filters, - Cancel cancel) + Filters? filters, + Cancel cancel, + TDbContext data) where TItem : class + where TDbContext : DbContext { int skip; if (after is null) @@ -155,23 +163,28 @@ static Task> Last( skip = after.Value + 1; } - return Range(queryable, skip, take: last, count, context, filters, cancel); + return Range(queryable, skip, take: last, count, context, filters, cancel, data); } - static async Task> Range( + static async Task> Range( IQueryable queryable, int skip, int take, int count, IResolveFieldContext context, - Filters filters, - Cancel cancel) + Filters? filters, + Cancel cancel, + TDbContext data) where TItem : class + where TDbContext : DbContext { var page = queryable.Skip(skip).Take(take); QueryLogger.Write(page); IEnumerable result = await page.ToListAsync(cancel); - result = await filters.ApplyFilter(result, context.UserContext, context.User); + if (filters != null) + { + result = await filters.ApplyFilter(result, context.UserContext, data, context.User); + } cancel.ThrowIfCancellationRequested(); return Build(skip, take, count, result); diff --git a/src/GraphQL.EntityFramework/EfGraphQLConventions.cs b/src/GraphQL.EntityFramework/EfGraphQLConventions.cs index a42beb3de..215d2e0b9 100644 --- a/src/GraphQL.EntityFramework/EfGraphQLConventions.cs +++ b/src/GraphQL.EntityFramework/EfGraphQLConventions.cs @@ -17,7 +17,7 @@ public static void RegisterInContainer( IServiceCollection services, ResolveDbContext? resolveDbContext = null, IModel? model = null, - ResolveFilters? resolveFilters = null, + ResolveFilters? resolveFilters = null, bool disableTracking = false, bool disableAsync = false) @@ -37,14 +37,14 @@ public static void RegisterInContainer( static EfGraphQLService Build( ResolveDbContext? dbContextResolver, IModel? model, - ResolveFilters? filters, + ResolveFilters? filters, IServiceProvider provider, bool disableTracking, bool disableAsync) where TDbContext : DbContext { model ??= ResolveModel(provider); - filters ??= provider.GetService(); + filters ??= provider.GetService>(); dbContextResolver ??= _ => DbContextFromProvider(provider); return new( diff --git a/src/GraphQL.EntityFramework/Filters/Filters.cs b/src/GraphQL.EntityFramework/Filters/Filters.cs index c24e9f084..7a29a8c9f 100644 --- a/src/GraphQL.EntityFramework/Filters/Filters.cs +++ b/src/GraphQL.EntityFramework/Filters/Filters.cs @@ -2,12 +2,13 @@ #region FiltersSignature -public class Filters +public class Filters + where TDbContext : DbContext { - public delegate bool Filter(object userContext, ClaimsPrincipal? userPrincipal, TEntity input) + public delegate bool Filter(object userContext, TDbContext data, ClaimsPrincipal? userPrincipal, TEntity input) where TEntity : class; - public delegate Task AsyncFilter(object userContext, ClaimsPrincipal? userPrincipal, TEntity input) + public delegate Task AsyncFilter(object userContext, TDbContext data, ClaimsPrincipal? userPrincipal, TEntity input) where TEntity : class; #endregion @@ -15,11 +16,11 @@ public delegate Task AsyncFilter(object userContext, ClaimsPri public void Add(Filter filter) where TEntity : class => funcs[typeof(TEntity)] = - (userContext, userPrincipal, item) => + (userContext, data, userPrincipal, item) => { try { - return Task.FromResult(filter(userContext, userPrincipal, (TEntity) item)); + return Task.FromResult(filter(userContext, data, userPrincipal, (TEntity) item)); } catch (Exception exception) { @@ -30,11 +31,11 @@ public void Add(Filter filter) public void Add(AsyncFilter filter) where TEntity : class => funcs[typeof(TEntity)] = - async (userContext, userPrincipal, item) => + async (userContext, data, userPrincipal, item) => { try { - return await filter(userContext, userPrincipal, (TEntity) item); + return await filter(userContext, data, userPrincipal, (TEntity) item); } catch (Exception exception) { @@ -42,11 +43,11 @@ public void Add(AsyncFilter filter) } }; - delegate Task Filter(object userContext, ClaimsPrincipal? userPrincipal, object input); + delegate Task Filter(object userContext, TDbContext data, ClaimsPrincipal? userPrincipal, object input); Dictionary funcs = []; - internal virtual async Task> ApplyFilter(IEnumerable result, object userContext, ClaimsPrincipal? userPrincipal) + internal virtual async Task> ApplyFilter(IEnumerable result, object userContext, TDbContext data, ClaimsPrincipal? userPrincipal) where TEntity : class { if (funcs.Count == 0) @@ -63,7 +64,7 @@ internal virtual async Task> ApplyFilter(IEnumerab var list = new List(); foreach (var item in result) { - if (await ShouldInclude(userContext, userPrincipal, item, filters)) + if (await ShouldInclude(userContext, data, userPrincipal, item, filters)) { list.Add(item); } @@ -72,12 +73,12 @@ internal virtual async Task> ApplyFilter(IEnumerab return list; } - static async Task ShouldInclude(object userContext, ClaimsPrincipal? userPrincipal, TEntity item, List> filters) + static async Task ShouldInclude(object userContext, TDbContext data, ClaimsPrincipal? userPrincipal, TEntity item, List> filters) where TEntity : class { foreach (var func in filters) { - if (!await func(userContext, userPrincipal, item)) + if (!await func(userContext, data, userPrincipal, item)) { return false; } @@ -86,7 +87,7 @@ static async Task ShouldInclude(object userContext, ClaimsPrincip return true; } - internal virtual async Task ShouldInclude(object userContext, ClaimsPrincipal? userPrincipal, TEntity? item) + internal virtual async Task ShouldInclude(object userContext, TDbContext data, ClaimsPrincipal? userPrincipal, TEntity? item) where TEntity : class { if (item is null) @@ -101,7 +102,7 @@ internal virtual async Task ShouldInclude(object userContext, Cla foreach (var func in FindFilters()) { - if (!await func(userContext, userPrincipal, item)) + if (!await func(userContext, data, userPrincipal, item)) { return false; } @@ -116,7 +117,7 @@ IEnumerable> FindFilters() var type = typeof(TEntity); foreach (var pair in funcs.Where(_ => _.Key.IsAssignableFrom(type))) { - yield return (context, user, item) => pair.Value(context, user, item); + yield return (context, data, user, item) => pair.Value(context, data, user, item); } } } \ No newline at end of file diff --git a/src/GraphQL.EntityFramework/Filters/NullFilters.cs b/src/GraphQL.EntityFramework/Filters/NullFilters.cs deleted file mode 100644 index db034f4f5..000000000 --- a/src/GraphQL.EntityFramework/Filters/NullFilters.cs +++ /dev/null @@ -1,14 +0,0 @@ -namespace GraphQL.EntityFramework; - -public class NullFilters : - Filters -{ - public static NullFilters Instance = new(); - - internal override Task> ApplyFilter(IEnumerable result, object userContext, ClaimsPrincipal? userPrincipal) => - Task.FromResult(result); - - internal override Task ShouldInclude(object userContext, ClaimsPrincipal? userPrincipal, TEntity? item) - where TEntity : class => - Task.FromResult(true); -} \ No newline at end of file diff --git a/src/GraphQL.EntityFramework/Filters/ResolveFilters.cs b/src/GraphQL.EntityFramework/Filters/ResolveFilters.cs index d5c7b7344..a3b487cee 100644 --- a/src/GraphQL.EntityFramework/Filters/ResolveFilters.cs +++ b/src/GraphQL.EntityFramework/Filters/ResolveFilters.cs @@ -1,3 +1,4 @@ namespace GraphQL.EntityFramework; -public delegate Filters? ResolveFilters(object userContext); \ No newline at end of file +public delegate Filters? ResolveFilters(object userContext) + where TDbContext : DbContext; \ No newline at end of file diff --git a/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService.cs b/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService.cs index d38088adb..82927cffe 100644 --- a/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService.cs +++ b/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService.cs @@ -4,7 +4,7 @@ public partial class EfGraphQLService : IEfGraphQLService where TDbContext : DbContext { - ResolveFilters? resolveFilters; + ResolveFilters? resolveFilters; bool disableTracking; bool disableAsync; ResolveDbContext resolveDbContext; @@ -14,7 +14,7 @@ public partial class EfGraphQLService : public EfGraphQLService( IModel model, ResolveDbContext resolveDbContext, - ResolveFilters? resolveFilters = null, + ResolveFilters? resolveFilters = null, bool disableTracking = false, bool disableAsync = false) { @@ -68,11 +68,8 @@ ResolveEfFieldContext BuildContext( public TDbContext ResolveDbContext(IResolveFieldContext context) => resolveDbContext(context.UserContext); - Filters ResolveFilter(IResolveFieldContext context) - { - var filter = resolveFilters?.Invoke(context.UserContext); - return filter ?? NullFilters.Instance; - } + Filters? ResolveFilter(IResolveFieldContext context) => + resolveFilters?.Invoke(context.UserContext); public IQueryable AddIncludes(IQueryable query, IResolveFieldContext context) where TItem : class => diff --git a/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_First.cs b/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_First.cs index 4bbc5322d..40f580ed5 100644 --- a/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_First.cs +++ b/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_First.cs @@ -195,7 +195,8 @@ FieldType BuildFirstField( if (first is not null) { - if (await efFieldContext.Filters.ShouldInclude(context.UserContext, context.User, first)) + if (efFieldContext.Filters == null || + await efFieldContext.Filters.ShouldInclude(context.UserContext, efFieldContext.DbContext, context.User, first)) { if (mutate is not null) { diff --git a/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_Navigation.cs b/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_Navigation.cs index bd8bd6696..5cf0daf1a 100644 --- a/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_Navigation.cs +++ b/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_Navigation.cs @@ -46,7 +46,8 @@ public FieldBuilder AddNavigationField( exception); } - if (await fieldContext.Filters.ShouldInclude(context.UserContext, context.User, result)) + if (fieldContext.Filters == null || + await fieldContext.Filters.ShouldInclude(context.UserContext, fieldContext.DbContext, context.User, result)) { return result; } diff --git a/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_NavigationConnection.cs b/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_NavigationConnection.cs index 1dc9aa7d1..135ef051f 100644 --- a/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_NavigationConnection.cs +++ b/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_NavigationConnection.cs @@ -99,7 +99,11 @@ ConnectionBuilder AddEnumerableConnection( } enumerable = enumerable.ApplyGraphQlArguments(hasId, context, omitQueryArguments); - enumerable = await efFieldContext.Filters.ApplyFilter(enumerable, context.UserContext, context.User); + if (efFieldContext.Filters != null) + { + enumerable = await efFieldContext.Filters.ApplyFilter(enumerable, context.UserContext, efFieldContext.DbContext, context.User); + } + var page = enumerable.ToList(); return ConnectionConverter.ApplyConnectionContext( diff --git a/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_NavigationList.cs b/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_NavigationList.cs index 3dbed079b..3856980e1 100644 --- a/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_NavigationList.cs +++ b/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_NavigationList.cs @@ -25,20 +25,24 @@ public FieldBuilder AddNavigationListField( if (resolve is not null) { - field.Resolver = new FuncFieldResolver>( - async context => + field.Resolver = new FuncFieldResolver>(async context => + { + var fieldContext = BuildContext(context); + var result = resolve(fieldContext); + + if (result is IQueryable) { - var fieldContext = BuildContext(context); - var result = resolve(fieldContext); + throw new("This API expects the resolver to return a IEnumerable, not an IQueryable. Instead use AddQueryField."); + } - if (result is IQueryable) - { - throw new("This API expects the resolver to return a IEnumerable, not an IQueryable. Instead use AddQueryField."); - } + result = result.ApplyGraphQlArguments(hasId, context, omitQueryArguments); + if (fieldContext.Filters == null) + { + return result; + } - result = result.ApplyGraphQlArguments(hasId, context, omitQueryArguments); - return await fieldContext.Filters.ApplyFilter(result, context.UserContext, context.User); - }); + return await fieldContext.Filters.ApplyFilter(result, context.UserContext, fieldContext.DbContext, context.User); + }); } graph.AddField(field); diff --git a/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_Queryable.cs b/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_Queryable.cs index 7e9f1cb07..b67ded24e 100644 --- a/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_Queryable.cs +++ b/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_Queryable.cs @@ -144,7 +144,12 @@ FieldType BuildQueryField( exception); } - return await fieldContext.Filters.ApplyFilter(list, context.UserContext, context.User); + if (fieldContext.Filters == null) + { + return list; + } + + return await fieldContext.Filters.ApplyFilter(list, context.UserContext, fieldContext.DbContext, context.User); }); } diff --git a/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_QueryableConnection.cs b/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_QueryableConnection.cs index f5299e296..c4ba3e4a9 100644 --- a/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_QueryableConnection.cs +++ b/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_QueryableConnection.cs @@ -156,7 +156,8 @@ ConnectionBuilder AddQueryableConnection( context.Before!, context, context.CancellationToken, - efFieldContext.Filters); + efFieldContext.Filters, + efFieldContext.DbContext); } catch (TaskCanceledException) { diff --git a/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_Single.cs b/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_Single.cs index 4e6c9a266..50b894d62 100644 --- a/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_Single.cs +++ b/src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_Single.cs @@ -196,7 +196,8 @@ FieldType BuildSingleField( if (single is not null) { - if (await efFieldContext.Filters.ShouldInclude(context.UserContext, context.User, single)) + if (efFieldContext.Filters == null || + await efFieldContext.Filters.ShouldInclude(context.UserContext, efFieldContext.DbContext, context.User, single)) { if (mutate is not null) { diff --git a/src/GraphQL.EntityFramework/GraphApi/ResolveEfFieldContext.cs b/src/GraphQL.EntityFramework/GraphApi/ResolveEfFieldContext.cs index f16ba29e2..59000944d 100644 --- a/src/GraphQL.EntityFramework/GraphApi/ResolveEfFieldContext.cs +++ b/src/GraphQL.EntityFramework/GraphApi/ResolveEfFieldContext.cs @@ -5,5 +5,5 @@ public class ResolveEfFieldContext : where TDbContext : DbContext { public TDbContext DbContext { get; set; } = null!; - public Filters Filters { get; set; } = null!; + public Filters? Filters { get; set; } } \ No newline at end of file diff --git a/src/Snippets/GlobalFilterSnippets.cs b/src/Snippets/GlobalFilterSnippets.cs index 0eb5b1466..9c182513d 100644 --- a/src/Snippets/GlobalFilterSnippets.cs +++ b/src/Snippets/GlobalFilterSnippets.cs @@ -15,9 +15,9 @@ public static void Add(ServiceCollection services) { #region add-filter - var filters = new Filters(); + var filters = new Filters(); filters.Add( - (userContext, userPrincipal, item) => item.Property != "Ignore"); + (userContext, date, userPrincipal, item) => item.Property != "Ignore"); EfGraphQLConventions.RegisterInContainer( services, resolveFilters: _ => filters); diff --git a/src/Tests/ConnectionConverter/ConnectionConverterTests.cs b/src/Tests/ConnectionConverter/ConnectionConverterTests.cs index 586e760c3..28c2a2473 100644 --- a/src/Tests/ConnectionConverter/ConnectionConverterTests.cs +++ b/src/Tests/ConnectionConverter/ConnectionConverterTests.cs @@ -42,7 +42,7 @@ public async Task Queryable(int? first, int? after, int? last, int? before) var fieldContext = new ResolveFieldContext(); await using var database = await sqlInstance.Build(databaseSuffix: $"{first.GetValueOrDefault(0)}{after.GetValueOrDefault(0)}{last.GetValueOrDefault(0)}{before.GetValueOrDefault(0)}"); var entities = database.Context.Entities; - var connection = await ConnectionConverter.ApplyConnectionContext(entities.OrderBy(x=>x.Property), first, after, last, before, fieldContext, new()); + var connection = await ConnectionConverter.ApplyConnectionContext(entities.OrderBy(x=>x.Property), first, after, last, before, fieldContext, new(), Cancel.None,database.Context); await Verify(connection.Items!.OrderBy(_ => _!.Property)) .UseParameters(first, after, last, before); } diff --git a/src/Tests/GlobalFiltersTests.cs b/src/Tests/GlobalFiltersTests.cs index db7ae45f8..56b87074a 100644 --- a/src/Tests/GlobalFiltersTests.cs +++ b/src/Tests/GlobalFiltersTests.cs @@ -1,30 +1,30 @@ -using Filters = GraphQL.EntityFramework.Filters; - -public class GlobalFiltersTests +public class GlobalFiltersTests { [Fact] public async Task Simple() { - var filters= new Filters(); - filters.Add((_, _, target) => target.Property != "Ignore"); - Assert.True(await filters.ShouldInclude(new(), null, new Target())); - Assert.False(await filters.ShouldInclude(new(), null, null)); - Assert.True(await filters.ShouldInclude(new(), null, new Target {Property = "Include"})); - Assert.False(await filters.ShouldInclude(new(), null, new Target {Property = "Ignore"})); - - filters.Add((_, _, target) => target.Property != "Ignore"); - Assert.True(await filters.ShouldInclude(new(), null, new ChildTarget())); - Assert.True(await filters.ShouldInclude(new(), null, new ChildTarget {Property = "Include"})); - Assert.False(await filters.ShouldInclude(new(), null, new ChildTarget {Property = "Ignore"})); - - filters.Add((_, _, target) => target.Property != "Ignore"); - Assert.True(await filters.ShouldInclude(new(), null, new ImplementationTarget())); - Assert.True(await filters.ShouldInclude(new(), null, new ImplementationTarget { Property = "Include"})); - Assert.False(await filters.ShouldInclude(new(), null, new ImplementationTarget { Property = "Ignore" })); - - Assert.True(await filters.ShouldInclude(new(), null, new NonTarget { Property = "Foo" })); + var filters = new Filters(); + filters.Add((_, _, _, target) => target.Property != "Ignore"); + Assert.True(await filters.ShouldInclude(new(), new(), null, new Target())); + Assert.False(await filters.ShouldInclude(new(), new(), null, null)); + Assert.True(await filters.ShouldInclude(new(), new(), null, new Target {Property = "Include"})); + Assert.False(await filters.ShouldInclude(new(), new(), null, new Target {Property = "Ignore"})); + + filters.Add((_, _, _, target) => target.Property != "Ignore"); + Assert.True(await filters.ShouldInclude(new(), new(), null, new ChildTarget())); + Assert.True(await filters.ShouldInclude(new(), new(), null, new ChildTarget {Property = "Include"})); + Assert.False(await filters.ShouldInclude(new(), new(), null, new ChildTarget {Property = "Ignore"})); + + filters.Add((_, _, _, target) => target.Property != "Ignore"); + Assert.True(await filters.ShouldInclude(new(), new(), null, new ImplementationTarget())); + Assert.True(await filters.ShouldInclude(new(), new(), null, new ImplementationTarget { Property = "Include"})); + Assert.False(await filters.ShouldInclude(new(), new(), null, new ImplementationTarget { Property = "Ignore" })); + + Assert.True(await filters.ShouldInclude(new(), new(), null, new NonTarget { Property = "Foo" })); } + public class MyContext : DbContext; + public class NonTarget { public string? Property { get; set; } diff --git a/src/Tests/IntegrationTests/IntegrationTests.cs b/src/Tests/IntegrationTests/IntegrationTests.cs index c6858e723..99e755dd8 100644 --- a/src/Tests/IntegrationTests/IntegrationTests.cs +++ b/src/Tests/IntegrationTests/IntegrationTests.cs @@ -1,5 +1,3 @@ -using Filters = GraphQL.EntityFramework.Filters; - public partial class IntegrationTests { static SqlInstance sqlInstance; @@ -3008,7 +3006,7 @@ static async Task RunQuery( SqlDatabase database, string query, Inputs? inputs, - Filters? filters, + Filters? filters, bool disableTracking, object[] entities, bool disableAsync = false, diff --git a/src/Tests/IntegrationTests/IntegrationTests_filtered.cs b/src/Tests/IntegrationTests/IntegrationTests_filtered.cs index 0bd9f26c3..1717bc7c9 100644 --- a/src/Tests/IntegrationTests/IntegrationTests_filtered.cs +++ b/src/Tests/IntegrationTests/IntegrationTests_filtered.cs @@ -1,6 +1,4 @@ -using Filters = GraphQL.EntityFramework.Filters; - -public partial class IntegrationTests +public partial class IntegrationTests { [Fact] public async Task Child_filtered() @@ -39,11 +37,11 @@ public async Task Child_filtered() await RunQuery(database, query, null, BuildFilters(), false, [entity1, entity2, entity3]); } - static Filters BuildFilters() + static Filters BuildFilters() { - var filters = new Filters(); - filters.Add((_, _, item) => item.Property != "Ignore"); - filters.Add((_, _, item) => item.Property != "Ignore"); + var filters = new Filters(); + filters.Add((_, _, _, item) => item.Property != "Ignore"); + filters.Add((_, _, _, item) => item.Property != "Ignore"); return filters; } diff --git a/src/Tests/IntegrationTests/QueryExecutor.cs b/src/Tests/IntegrationTests/QueryExecutor.cs index 3c79e8447..6fbcc782b 100644 --- a/src/Tests/IntegrationTests/QueryExecutor.cs +++ b/src/Tests/IntegrationTests/QueryExecutor.cs @@ -5,7 +5,7 @@ public static async Task ExecuteQuery( ServiceCollection services, TDbContext data, Inputs? inputs, - Filters? filters, + Filters? filters, bool disableTracking, bool disableAsync) where TDbContext : DbContext