Skip to content

Commit e5982ae

Browse files
authored
Fix batch paging translation errors with inheritance selectors (#8938)
1 parent f425d23 commit e5982ae

File tree

29 files changed

+169
-589
lines changed

29 files changed

+169
-589
lines changed

src/GreenDonut/src/GreenDonut.Data.EntityFramework/Expressions/ExpressionHelpers.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,9 @@ public static (Expression<Func<T, bool>> WhereExpression, int Offset) BuildWhere
122122
/// <param name="forward">
123123
/// Defines how the dataset is sorted.
124124
/// </param>
125+
/// <param name="selector">
126+
/// Optional selector to apply to each item before materialization.
127+
/// </param>
125128
/// <param name="requestedCount">
126129
/// The number of items that are requested.
127130
/// </param>
@@ -141,6 +144,7 @@ public static BatchExpression<TK, TV> BuildBatchExpression<TK, TV>(
141144
ReadOnlySpan<LambdaExpression> orderExpressions,
142145
ReadOnlySpan<string> orderMethods,
143146
bool forward,
147+
Expression<Func<TV, TV>>? selector,
144148
ref int requestedCount)
145149
{
146150
if (keys.Length == 0)
@@ -177,6 +181,17 @@ public static BatchExpression<TK, TV> BuildBatchExpression<TK, TV>(
177181
typedOrderExpression);
178182
}
179183

184+
// apply the selector to each item in the grouping after ordering
185+
if (selector is not null)
186+
{
187+
var selectMethod = typeof(Enumerable)
188+
.GetMethods(BindingFlags.Static | BindingFlags.Public)
189+
.First(m => m.Name == nameof(Enumerable.Select) && m.GetParameters().Length == 2)
190+
.MakeGenericMethod(typeof(TV), typeof(TV));
191+
192+
source = Expression.Call(selectMethod, source, selector);
193+
}
194+
180195
var offset = 0;
181196
var usesRelativeCursors = false;
182197
Cursor? cursor = null;

src/GreenDonut/src/GreenDonut.Data.EntityFramework/Expressions/QueryHelpers.cs

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,21 @@ static List<MemberExpression> GetMemberExpressions(Expression<Func<T, TKey>> key
7474
}
7575
}
7676

77-
private static Expression<Func<T, T>>? ExtractCurrentSelector<T>(
77+
public static Expression<Func<T, T>>? ExtractCurrentSelector<T>(
7878
IQueryable<T> query)
7979
{
8080
var visitor = new ExtractSelectExpressionVisitor();
8181
visitor.Visit(query.Expression);
8282
return visitor.Selector as Expression<Func<T, T>>;
8383
}
8484

85+
public static IQueryable<T> RemoveSelector<T>(IQueryable<T> query)
86+
{
87+
var visitor = new RemoveSelectorVisitor();
88+
var newExpression = visitor.Visit(query.Expression);
89+
return query.Provider.CreateQuery<T>(newExpression);
90+
}
91+
8592
private static Expression<Func<T, T>> AddPropertiesInSelector<T>(
8693
Expression<Func<T, T>> selector,
8794
List<MemberExpression> properties)
@@ -109,6 +116,51 @@ private static IQueryable<T> ReplaceSelector<T>(
109116
return query.Provider.CreateQuery<T>(newExpression);
110117
}
111118

119+
private sealed class RemoveSelectorVisitor : ExpressionVisitor
120+
{
121+
private const string SelectMethod = "Select";
122+
123+
protected override Expression VisitMethodCall(MethodCallExpression node)
124+
{
125+
if (node.Method.Name != SelectMethod || node.Arguments.Count != 2)
126+
{
127+
return base.VisitMethodCall(node);
128+
}
129+
130+
var lambda = ConvertToLambda(node.Arguments[1]);
131+
if (!lambda.Type.IsGenericType || lambda.Type.GetGenericTypeDefinition() != typeof(Func<,>))
132+
{
133+
return base.VisitMethodCall(node);
134+
}
135+
136+
var genericArgs = lambda.Type.GetGenericArguments();
137+
// remove selectors of type Func<T, T>
138+
if (genericArgs[0] == genericArgs[1])
139+
{
140+
// return the source expression, effectively removing the Select
141+
return Visit(node.Arguments[0]);
142+
}
143+
144+
return base.VisitMethodCall(node);
145+
}
146+
147+
private static LambdaExpression ConvertToLambda(Expression e)
148+
{
149+
while (e.NodeType == ExpressionType.Quote)
150+
{
151+
e = ((UnaryExpression)e).Operand;
152+
}
153+
154+
if (e.NodeType != ExpressionType.MemberAccess)
155+
{
156+
return (LambdaExpression)e;
157+
}
158+
159+
var typeArguments = e.Type.GetGenericArguments()[0].GetGenericArguments();
160+
return Expression.Lambda(e, Expression.Parameter(typeArguments[0]));
161+
}
162+
}
163+
112164
public class AddPropertiesVisitorRewriter : ExpressionVisitor
113165
{
114166
private readonly List<MemberExpression> _propertiesToAdd;
@@ -122,6 +174,38 @@ public AddPropertiesVisitorRewriter(
122174
_parameter = parameter;
123175
}
124176

177+
protected override Expression VisitConditional(ConditionalExpression node)
178+
{
179+
// recursively visit conditional branches to handle type switches in inheritance scenarios
180+
var test = Visit(node.Test);
181+
var ifTrue = Visit(node.IfTrue);
182+
var ifFalse = Visit(node.IfFalse);
183+
184+
if (test != node.Test || ifTrue != node.IfTrue || ifFalse != node.IfFalse)
185+
{
186+
return Expression.Condition(test, ifTrue, ifFalse, node.Type);
187+
}
188+
189+
return node;
190+
}
191+
192+
protected override Expression VisitUnary(UnaryExpression node)
193+
{
194+
// handle convert expressions that wrap member init expressions in inheritance scenarios
195+
if (node.NodeType != ExpressionType.Convert)
196+
{
197+
return base.VisitUnary(node);
198+
}
199+
200+
var operand = Visit(node.Operand);
201+
if (operand != node.Operand)
202+
{
203+
return Expression.Convert(operand, node.Type);
204+
}
205+
206+
return base.VisitUnary(node);
207+
}
208+
125209
protected override Expression VisitMemberInit(MemberInitExpression node)
126210
{
127211
// Get existing bindings (properties in the current selector)

src/GreenDonut/src/GreenDonut.Data.EntityFramework/Extensions/PagingQueryableExtensions.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,18 @@ public static async ValueTask<Dictionary<TKey, Page<TValue>>> ToBatchPageAsync<T
436436
}
437437

438438
source = QueryHelpers.EnsureOrderPropsAreSelected(source);
439+
440+
// extract the selector before ensuring group props are selected,
441+
// as we need to remove it before grouping and re-apply it after
442+
var selector = QueryHelpers.ExtractCurrentSelector(source);
443+
444+
// if we have a selector, remove it before grouping
445+
// we'll re-apply it to the grouped items later
446+
if (selector is not null)
447+
{
448+
source = QueryHelpers.RemoveSelector(source);
449+
}
450+
439451
source = QueryHelpers.EnsureGroupPropsAreSelected(source, keySelector);
440452

441453
// we need to move the ordering into the select expression we are constructing
@@ -458,6 +470,7 @@ public static async ValueTask<Dictionary<TKey, Page<TValue>>> ToBatchPageAsync<T
458470
ordering.OrderExpressions,
459471
ordering.OrderMethods,
460472
forward,
473+
selector,
461474
ref requestedCount);
462475
var map = new Dictionary<TKey, Page<TValue>>();
463476

src/HotChocolate/Data/test/Data.PostgreSQL.Tests/__snapshots__/IntegrationTests.Query_Brands_First_2_And_Products_First_2_NET10_0.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ LIMIT @p
5959

6060
```sql
6161
-- @brandIds={ '11', '13' } (DbType = Object)
62-
SELECT p1."BrandId", p3."Id", p3."Name", p3."BrandId"
62+
SELECT p1."BrandId", p3."Id", p3."Name"
6363
FROM (
6464
SELECT p."BrandId"
6565
FROM "Products" AS p

src/HotChocolate/Data/test/Data.PostgreSQL.Tests/__snapshots__/IntegrationTests.Query_Brands_First_2_And_Products_First_2_NET8_0.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ LIMIT @__p_0
5959

6060
```sql
6161
-- @__brandIds_0={ '11', '13' } (DbType = Object)
62-
SELECT t."BrandId", t0."Id", t0."Name", t0."BrandId"
62+
SELECT t."BrandId", t0."Id", t0."Name"
6363
FROM (
6464
SELECT p."BrandId"
6565
FROM "Products" AS p

src/HotChocolate/Data/test/Data.PostgreSQL.Tests/__snapshots__/IntegrationTests.Query_Brands_First_2_And_Products_First_2_NET9_0.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ LIMIT @__p_0
5959

6060
```sql
6161
-- @__brandIds_0={ '11', '13' } (DbType = Object)
62-
SELECT p1."BrandId", p3."Id", p3."Name", p3."BrandId"
62+
SELECT p1."BrandId", p3."Id", p3."Name"
6363
FROM (
6464
SELECT p."BrandId"
6565
FROM "Products" AS p

src/HotChocolate/Data/test/Data.PostgreSQL.Tests/__snapshots__/IntegrationTests.Query_Brands_First_2_And_Products_First_2_Name_Desc_NET10_0.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ LIMIT @p
5959

6060
```sql
6161
-- @brandIds={ '11', '13' } (DbType = Object)
62-
SELECT p1."BrandId", p3."Id", p3."Name", p3."BrandId"
62+
SELECT p1."BrandId", p3."Id", p3."Name"
6363
FROM (
6464
SELECT p."BrandId"
6565
FROM "Products" AS p

src/HotChocolate/Data/test/Data.PostgreSQL.Tests/__snapshots__/IntegrationTests.Query_Brands_First_2_And_Products_First_2_Name_Desc_NET8_0.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ LIMIT @__p_0
5959

6060
```sql
6161
-- @__brandIds_0={ '11', '13' } (DbType = Object)
62-
SELECT t."BrandId", t0."Id", t0."Name", t0."BrandId"
62+
SELECT t."BrandId", t0."Id", t0."Name"
6363
FROM (
6464
SELECT p."BrandId"
6565
FROM "Products" AS p

src/HotChocolate/Data/test/Data.PostgreSQL.Tests/__snapshots__/IntegrationTests.Query_Brands_First_2_And_Products_First_2_Name_Desc_NET9_0.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ LIMIT @__p_0
5959

6060
```sql
6161
-- @__brandIds_0={ '11', '13' } (DbType = Object)
62-
SELECT p1."BrandId", p3."Id", p3."Name", p3."BrandId"
62+
SELECT p1."BrandId", p3."Id", p3."Name"
6363
FROM (
6464
SELECT p."BrandId"
6565
FROM "Products" AS p

src/HotChocolate/Data/test/Data.PostgreSQL.Tests/__snapshots__/IntegrationTests.Query_Brands_First_2_Products_First_2_ForwardCursors_NET10_0.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ GROUP BY p."BrandId"
6262

6363
```sql
6464
-- @brandIds={ '11' } (DbType = Object)
65-
SELECT p1."BrandId", p3."Name", p3."Id", p3."BrandId"
65+
SELECT p1."BrandId", p3."Name", p3."Id"
6666
FROM (
6767
SELECT p."BrandId"
6868
FROM "Products" AS p

0 commit comments

Comments
 (0)