-
Notifications
You must be signed in to change notification settings - Fork 256
Expand file tree
/
Copy pathIncludeEvaluator.cs
More file actions
124 lines (103 loc) · 6.35 KB
/
IncludeEvaluator.cs
File metadata and controls
124 lines (103 loc) · 6.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
using Microsoft.EntityFrameworkCore.Query;
using System.Collections.Concurrent;
using System.Diagnostics;
using System.Reflection;
namespace Ardalis.Specification.EntityFrameworkCore;
/// <summary>
/// Evaluator that applies a specification's <c>Include</c> / <c>ThenInclude</c> expressions to a query
/// by invoking the matching <see cref="EntityFrameworkQueryableExtensions"/> methods.
/// </summary>
/// <remarks>
/// The include expressions are held as untyped <see cref="LambdaExpression"/>s, so the generic EF Core
/// methods are closed over the concrete types at runtime and compiled into delegates. Compiled delegates
/// are cached in a <see cref="ConcurrentDictionary{TKey, TValue}"/> keyed by
/// (entity type, property type, previous property type).
/// </remarks>
public class IncludeEvaluator : IEvaluator
{
    // Open generic Include<TEntity, TProperty>(IQueryable<TEntity>, Expression<...>).
    private static readonly MethodInfo _includeMethodInfo = typeof(EntityFrameworkQueryableExtensions)
        .GetTypeInfo().GetDeclaredMethods(nameof(EntityFrameworkQueryableExtensions.Include))
        .Single(mi => mi.IsPublic && mi.GetGenericArguments().Length == 2
            && mi.GetParameters()[0].ParameterType.GetGenericTypeDefinition() == typeof(IQueryable<>)
            && mi.GetParameters()[1].ParameterType.GetGenericTypeDefinition() == typeof(Expression<>));

    // Open generic ThenInclude<TEntity, TPreviousProperty, TProperty> whose source's second type argument
    // is the bare generic parameter, i.e. the overload used after a reference navigation.
    private static readonly MethodInfo _thenIncludeAfterReferenceMethodInfo
        = typeof(EntityFrameworkQueryableExtensions)
            .GetTypeInfo().GetDeclaredMethods(nameof(EntityFrameworkQueryableExtensions.ThenInclude))
            .Single(mi => mi.IsPublic && mi.GetGenericArguments().Length == 3
                && mi.GetParameters()[0].ParameterType.GenericTypeArguments[1].IsGenericParameter
                && mi.GetParameters()[0].ParameterType.GetGenericTypeDefinition() == typeof(IIncludableQueryable<,>)
                && mi.GetParameters()[1].ParameterType.GetGenericTypeDefinition() == typeof(Expression<>));

    // Open generic ThenInclude<TEntity, TPreviousProperty, TProperty> whose source's second type argument
    // is a constructed type (not a bare generic parameter), i.e. the overload used after a collection navigation.
    private static readonly MethodInfo _thenIncludeAfterEnumerableMethodInfo
        = typeof(EntityFrameworkQueryableExtensions)
            .GetTypeInfo().GetDeclaredMethods(nameof(EntityFrameworkQueryableExtensions.ThenInclude))
            .Single(mi => mi.IsPublic && mi.GetGenericArguments().Length == 3
                && !mi.GetParameters()[0].ParameterType.GenericTypeArguments[1].IsGenericParameter
                && mi.GetParameters()[0].ParameterType.GetGenericTypeDefinition() == typeof(IIncludableQueryable<,>)
                && mi.GetParameters()[1].ParameterType.GetGenericTypeDefinition() == typeof(Expression<>));

    // PreviousReturnType is null for Include entries and non-null for ThenInclude entries; that difference
    // is the only thing keeping the two kinds of delegate from sharing a slot in _cache.
    private readonly record struct CacheKey(Type EntityType, Type PropertyType, Type? PreviousReturnType);

    private static readonly ConcurrentDictionary<CacheKey, Func<IQueryable, LambdaExpression, IQueryable>> _cache = new();

    private IncludeEvaluator() { }

    /// <summary>The shared evaluator instance.</summary>
    // readonly: a public mutable static field (CA2211) would let any caller replace or null out the singleton.
    public static readonly IncludeEvaluator Instance = new();

    /// <inheritdoc/>
    public bool IsCriteriaEvaluator => false;

    /// <inheritdoc/>
    public IQueryable<T> GetQuery<T>(IQueryable<T> query, ISpecification<T> specification) where T : class
    {
        // Fast path for Specification<T>: skip enumeration when there are no includes, or exactly one.
        if (specification is Specification<T> spec)
        {
            if (spec.OneOrManyIncludeExpressions.IsEmpty) return query;

            if (spec.OneOrManyIncludeExpressions.SingleOrDefault is { } singleInclude)
            {
                return ApplyInclude(query, singleInclude.LambdaExpression);
            }
        }

        foreach (var includeExpression in specification.IncludeExpressions)
        {
            var lambdaExpr = includeExpression.LambdaExpression;

            if (includeExpression.Type == IncludeTypeEnum.Include)
            {
                query = ApplyInclude(query, lambdaExpr);
            }
            else if (includeExpression.Type == IncludeTypeEnum.ThenInclude)
            {
                query = ApplyThenInclude(query, lambdaExpr, includeExpression.PreviousPropertyType);
            }
        }

        return query;
    }

    // Applies Include<T, TProperty> for lambdaExpr using a cached compiled delegate.
    private static IQueryable<T> ApplyInclude<T>(IQueryable<T> query, LambdaExpression lambdaExpr)
    {
        var key = new CacheKey(typeof(T), lambdaExpr.ReturnType, null);
        var include = _cache.GetOrAdd(key, CreateIncludeDelegate);
        return (IQueryable<T>)include(query, lambdaExpr);
    }

    // Applies ThenInclude for lambdaExpr, chained after the include whose result type is previousPropertyType.
    private static IQueryable<T> ApplyThenInclude<T>(IQueryable<T> query, LambdaExpression lambdaExpr, Type? previousPropertyType)
    {
        // The Debug.Assert in CreateThenIncludeDelegate is compiled out of release builds. Without this guard
        // a null previous type would build the same key as an Include and GetOrAdd could return that delegate.
        if (previousPropertyType is null)
        {
            throw new InvalidOperationException(
                $"Cannot apply ThenInclude '{lambdaExpr}' for entity '{typeof(T)}': the previous property type is missing.");
        }

        var key = new CacheKey(typeof(T), lambdaExpr.ReturnType, previousPropertyType);
        var thenInclude = _cache.GetOrAdd(key, CreateThenIncludeDelegate);
        return (IQueryable<T>)thenInclude(query, lambdaExpr);
    }

    /// <summary>
    /// Compiles <c>(source, selector) => Include&lt;TEntity, TProperty&gt;((IQueryable&lt;TEntity&gt;)source,
    /// (Expression&lt;Func&lt;TEntity, TProperty&gt;&gt;)selector)</c> for the types in <paramref name="cacheKey"/>.
    /// </summary>
    private static Func<IQueryable, LambdaExpression, IQueryable> CreateIncludeDelegate(CacheKey cacheKey)
    {
        var includeMethod = _includeMethodInfo.MakeGenericMethod(cacheKey.EntityType, cacheKey.PropertyType);
        var sourceParameter = Expression.Parameter(typeof(IQueryable));
        var selectorParameter = Expression.Parameter(typeof(LambdaExpression));

        var call = Expression.Call(
            includeMethod,
            Expression.Convert(sourceParameter, typeof(IQueryable<>).MakeGenericType(cacheKey.EntityType)),
            Expression.Convert(selectorParameter, typeof(Expression<>).MakeGenericType(typeof(Func<,>).MakeGenericType(cacheKey.EntityType, cacheKey.PropertyType))));

        var lambda = Expression.Lambda<Func<IQueryable, LambdaExpression, IQueryable>>(call, sourceParameter, selectorParameter);
        return lambda.Compile();
    }

    /// <summary>
    /// Compiles a delegate that calls the appropriate <c>ThenInclude</c> overload for the types in
    /// <paramref name="cacheKey"/>. The collection overload is chosen when the previous return type is exactly
    /// <c>IEnumerable&lt;X&gt;</c>; otherwise the reference overload is used.
    /// </summary>
    private static Func<IQueryable, LambdaExpression, IQueryable> CreateThenIncludeDelegate(CacheKey cacheKey)
    {
        Debug.Assert(cacheKey.PreviousReturnType is not null);

        var thenIncludeInfo = IsGenericEnumerable(cacheKey.PreviousReturnType, out var previousPropertyType)
            ? _thenIncludeAfterEnumerableMethodInfo
            : _thenIncludeAfterReferenceMethodInfo;

        var thenIncludeMethod = thenIncludeInfo.MakeGenericMethod(cacheKey.EntityType, previousPropertyType, cacheKey.PropertyType);
        var sourceParameter = Expression.Parameter(typeof(IQueryable));
        var selectorParameter = Expression.Parameter(typeof(LambdaExpression));

        var call = Expression.Call(
            thenIncludeMethod,
            // The source must be cast using the exact previous return type (e.g. IEnumerable<X>),
            // not the generic type argument (X) that was passed to MakeGenericMethod.
            Expression.Convert(sourceParameter, typeof(IIncludableQueryable<,>).MakeGenericType(cacheKey.EntityType, cacheKey.PreviousReturnType)),
            Expression.Convert(selectorParameter, typeof(Expression<>).MakeGenericType(typeof(Func<,>).MakeGenericType(previousPropertyType, cacheKey.PropertyType))));

        var lambda = Expression.Lambda<Func<IQueryable, LambdaExpression, IQueryable>>(call, sourceParameter, selectorParameter);
        return lambda.Compile();
    }

    /// <summary>
    /// Returns true when <paramref name="type"/> is exactly the constructed interface <c>IEnumerable&lt;X&gt;</c>
    /// (not a type that merely implements it), setting <paramref name="propertyType"/> to <c>X</c>.
    /// Otherwise returns false and sets <paramref name="propertyType"/> to <paramref name="type"/>.
    /// </summary>
    private static bool IsGenericEnumerable(Type type, out Type propertyType)
    {
        if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(IEnumerable<>))
        {
            propertyType = type.GenericTypeArguments[0];
            return true;
        }

        propertyType = type;
        return false;
    }
}