Skip to content

Commit b458c71

Browse files
authored
Merge pull request #72 from rhodon-jargon/interface-support
Fix interface property lookup in generic method
2 parents d6b1cfb + f806c1c commit b458c71

File tree

5 files changed

+74
-35
lines changed

5 files changed

+74
-35
lines changed

src/EntityFrameworkCore.Projectables/Extensions/TypeExtensions.cs

Lines changed: 25 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -54,23 +54,8 @@ private static bool CanHaveOverridingMethod(this Type derivedType, MethodInfo me
5454
return true;
5555
}
5656

57-
private static int? GetOverridingMethodIndex(this MethodInfo methodInfo, MethodInfo[]? allDerivedMethods)
58-
{
59-
if (allDerivedMethods is { Length: > 0 })
60-
{
61-
var baseDefinition = methodInfo.GetBaseDefinition();
62-
for (var i = 0; i < allDerivedMethods.Length; i++)
63-
{
64-
var derivedMethodInfo = allDerivedMethods[i];
65-
if (derivedMethodInfo.GetBaseDefinition() == baseDefinition)
66-
{
67-
return i;
68-
}
69-
}
70-
}
71-
72-
return null;
73-
}
57+
private static bool IsOverridingMethodOf(this MethodInfo methodInfo, MethodInfo baseDefinition)
58+
=> methodInfo.GetBaseDefinition() == baseDefinition;
7459

7560
public static MethodInfo GetOverridingMethod(this Type derivedType, MethodInfo methodInfo)
7661
{
@@ -81,31 +66,38 @@ public static MethodInfo GetOverridingMethod(this Type derivedType, MethodInfo m
8166

8267
var derivedMethods = derivedType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);
8368

84-
return methodInfo.GetOverridingMethodIndex(derivedMethods) is { } i
85-
? derivedMethods[i]
86-
// No derived methods were found. Return the original methodInfo
87-
: methodInfo;
69+
MethodInfo? overridingMethod = null;
70+
if (derivedMethods is { Length: > 0 })
71+
{
72+
var baseDefinition = methodInfo.GetBaseDefinition();
73+
overridingMethod = derivedMethods.FirstOrDefault(derivedMethodInfo
74+
=> derivedMethodInfo.IsOverridingMethodOf(baseDefinition));
75+
}
76+
77+
return overridingMethod ?? methodInfo; // If no derived methods were found, return the original methodInfo
8878
}
8979

9080
public static PropertyInfo GetOverridingProperty(this Type derivedType, PropertyInfo propertyInfo)
9181
{
92-
var accessor = propertyInfo.GetAccessors(true)[0];
93-
94-
if (!derivedType.CanHaveOverridingMethod(accessor))
82+
var accessor = propertyInfo.GetAccessors(true).FirstOrDefault(derivedType.CanHaveOverridingMethod);
83+
if (accessor is null)
9584
{
9685
return propertyInfo;
9786
}
87+
88+
var isGetAccessor = propertyInfo.GetMethod == accessor;
9889

9990
var derivedProperties = derivedType.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);
100-
var derivedPropertyMethods = derivedProperties
101-
.Select((Func<PropertyInfo, MethodInfo?>)
102-
(propertyInfo.GetMethod == accessor ? p => p.GetMethod : p => p.SetMethod))
103-
.OfType<MethodInfo>().ToArray();
104-
105-
return accessor.GetOverridingMethodIndex(derivedPropertyMethods) is { } i
106-
? derivedProperties[i]
107-
// No derived methods were found. Return the original methodInfo
108-
: propertyInfo;
91+
92+
PropertyInfo? overridingProperty = null;
93+
if (derivedProperties is { Length: > 0 })
94+
{
95+
var baseDefinition = accessor.GetBaseDefinition();
96+
overridingProperty = derivedProperties.FirstOrDefault(p
97+
=> (isGetAccessor ? p.GetMethod : p.SetMethod)?.IsOverridingMethodOf(baseDefinition) == true);
98+
}
99+
100+
return overridingProperty ?? propertyInfo; // If no derived methods were found, return the original methodInfo
109101
}
110102

111103
public static MethodInfo GetImplementingMethod(this Type derivedType, MethodInfo methodInfo)

src/EntityFrameworkCore.Projectables/Infrastructure/Internal/ProjectionOptionsExtension.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ static object CreateTargetInstance(IServiceProvider services, ServiceDescriptor
5454
var targetDescriptor = services.FirstOrDefault(x => x.ServiceType == typeof(IQueryCompiler));
5555
if (targetDescriptor is null)
5656
{
57-
throw new InvalidOperationException("No QueryProvider is configured yet. Please make sure to configure a database provider first"); ;
57+
throw new InvalidOperationException("No QueryProvider is configured yet. Please make sure to configure a database provider first");
5858
}
5959

6060
var decoratorObjectFactory = ActivatorUtilities.CreateFactory(typeof(CustomQueryCompiler), new[] { targetDescriptor.ServiceType });
@@ -70,7 +70,7 @@ static object CreateTargetInstance(IServiceProvider services, ServiceDescriptor
7070
var targetDescriptor = services.FirstOrDefault(x => x.ServiceType == typeof(IQueryTranslationPreprocessorFactory));
7171
if (targetDescriptor is null)
7272
{
73-
throw new InvalidOperationException("No QueryTranslationPreprocessorFactory is configured yet. Please make sure to configure a database provider first"); ;
73+
throw new InvalidOperationException("No QueryTranslationPreprocessorFactory is configured yet. Please make sure to configure a database provider first");
7474
}
7575

7676
var decoratorObjectFactory = ActivatorUtilities.CreateFactory(typeof(CustomQueryTranslationPreprocessorFactory), new[] { targetDescriptor.ServiceType });
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
SELECT 4
2+
FROM [Concrete] AS [c]
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
SELECT [c].[Id]
2+
FROM [BaseProvider] AS [b]
3+
INNER JOIN [Concrete] AS [c] ON [b].[Id] = [c].[BaseProviderId]

tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.cs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,20 @@ namespace EntityFrameworkCore.Projectables.FunctionalTests
2020
[UsesVerify]
2121
public class InheritedModelTests
2222
{
23+
public interface IBaseProvider<TBase>
24+
{
25+
ICollection<TBase> Bases { get; set; }
26+
}
27+
28+
public class BaseProvider : IBaseProvider<Concrete>
29+
{
30+
public int Id { get; set; }
31+
public ICollection<Concrete> Bases { get; set; }
32+
}
33+
2334
public interface IBase
2435
{
36+
int Id { get; }
2537
int ComputedProperty { get; }
2638
int ComputedMethod();
2739
}
@@ -117,6 +129,26 @@ public Task ProjectOverImplementedMethod()
117129

118130
return Verifier.Verify(query.ToQueryString());
119131
}
132+
133+
[Fact]
134+
public Task ProjectOverProvider()
135+
{
136+
using var dbContext = new SampleDbContext<BaseProvider>();
137+
138+
var query = dbContext.Set<BaseProvider>().AllBases<BaseProvider, Concrete>();
139+
140+
return Verifier.Verify(query.ToQueryString());
141+
}
142+
143+
[Fact]
144+
public Task ProjectOverExtensionMethod()
145+
{
146+
using var dbContext = new SampleDbContext<Concrete>();
147+
148+
var query = dbContext.Set<Concrete>().Select(c => c.ComputedPropertyPlusMethod());
149+
150+
return Verifier.Verify(query.ToQueryString());
151+
}
120152
}
121153

122154
public static class ModelExtensions
@@ -128,5 +160,15 @@ public static IQueryable<int> SelectComputedProperty<TConcrete>(this IQueryable<
128160
public static IQueryable<int> SelectComputedMethod<TConcrete>(this IQueryable<TConcrete> concretes)
129161
where TConcrete : InheritedModelTests.IBase
130162
=> concretes.Select(x => x.ComputedMethod());
163+
164+
public static IQueryable<int> AllBases<TProvider, TBase>(this IQueryable<TProvider> concretes)
165+
where TProvider : InheritedModelTests.IBaseProvider<TBase>
166+
where TBase : InheritedModelTests.IBase
167+
=> concretes.SelectMany(x => x.Bases).Select(x => x.Id);
168+
169+
[Projectable]
170+
public static int ComputedPropertyPlusMethod<TConcrete>(this TConcrete concrete)
171+
where TConcrete : InheritedModelTests.IBase
172+
=> concrete.ComputedProperty + concrete.ComputedMethod();
131173
}
132174
}

0 commit comments

Comments
 (0)