Skip to content

Commit 20201da

Browse files
authored
Return null when the type is nullable for Cosmos Max/Min/Average (#35173)
* Return null when the type is nullable for Cosmos Max/Min/Average Fixes #35094 This was a regression resulting from the major Cosmos query refactoring that happened in EF9. In EF8, the functions Min, Max, and Average would return null if the return type was nullable or was cast to a nullable when the collection is empty. In EF9, this started throwing, which is correct for non-nullable types, but a regression for nullable types. * Added notes
1 parent 1319ed4 commit 20201da

File tree

4 files changed

+203
-226
lines changed

4 files changed

+203
-226
lines changed

src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs

Lines changed: 34 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -443,25 +443,7 @@ private ShapedQueryExpression CreateShapedQueryExpression(SelectExpression selec
443443
/// doing so can result in application failures when updating to a new Entity Framework Core release.
444444
/// </summary>
445445
protected override ShapedQueryExpression? TranslateAverage(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
446-
{
447-
var selectExpression = (SelectExpression)source.QueryExpression;
448-
if (selectExpression.IsDistinct
449-
|| selectExpression.Limit != null
450-
|| selectExpression.Offset != null)
451-
{
452-
return null;
453-
}
454-
455-
if (selector != null)
456-
{
457-
source = TranslateSelect(source, selector);
458-
}
459-
460-
var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember());
461-
projection = _sqlExpressionFactory.Function("AVG", new[] { projection }, resultType, _typeMappingSource.FindMapping(resultType));
462-
463-
return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType);
464-
}
446+
=> TranslateAggregate(source, selector, resultType, "AVG");
465447

466448
/// <summary>
467449
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
@@ -841,26 +823,7 @@ protected override ShapedQueryExpression TranslateCast(ShapedQueryExpression sou
841823
/// doing so can result in application failures when updating to a new Entity Framework Core release.
842824
/// </summary>
843825
protected override ShapedQueryExpression? TranslateMax(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
844-
{
845-
var selectExpression = (SelectExpression)source.QueryExpression;
846-
if (selectExpression.IsDistinct
847-
|| selectExpression.Limit != null
848-
|| selectExpression.Offset != null)
849-
{
850-
return null;
851-
}
852-
853-
if (selector != null)
854-
{
855-
source = TranslateSelect(source, selector);
856-
}
857-
858-
var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember());
859-
860-
projection = _sqlExpressionFactory.Function("MAX", new[] { projection }, resultType, projection.TypeMapping);
861-
862-
return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType);
863-
}
826+
=> TranslateAggregate(source, selector, resultType, "MAX");
864827

865828
/// <summary>
866829
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
@@ -869,26 +832,7 @@ protected override ShapedQueryExpression TranslateCast(ShapedQueryExpression sou
869832
/// doing so can result in application failures when updating to a new Entity Framework Core release.
870833
/// </summary>
871834
protected override ShapedQueryExpression? TranslateMin(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
872-
{
873-
var selectExpression = (SelectExpression)source.QueryExpression;
874-
if (selectExpression.IsDistinct
875-
|| selectExpression.Limit != null
876-
|| selectExpression.Offset != null)
877-
{
878-
return null;
879-
}
880-
881-
if (selector != null)
882-
{
883-
source = TranslateSelect(source, selector);
884-
}
885-
886-
var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember());
887-
888-
projection = _sqlExpressionFactory.Function("MIN", new[] { projection }, resultType, projection.TypeMapping);
889-
890-
return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType);
891-
}
835+
=> TranslateAggregate(source, selector, resultType, "MIN");
892836

893837
/// <summary>
894838
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
@@ -1241,7 +1185,7 @@ protected override ShapedQueryExpression TranslateSelect(ShapedQueryExpression s
12411185

12421186
projection = _sqlExpressionFactory.Function("SUM", new[] { projection }, serverOutputType, projection.TypeMapping);
12431187

1244-
return AggregateResultShaper(source, projection, throwOnNullResult: false, resultType);
1188+
return AggregateResultShaper(source, projection, resultType);
12451189
}
12461190

12471191
/// <summary>
@@ -1515,6 +1459,35 @@ protected override ShapedQueryExpression TranslateSelect(ShapedQueryExpression s
15151459

15161460
#endregion Queryable collection support
15171461

1462+
private ShapedQueryExpression? TranslateAggregate(ShapedQueryExpression source, LambdaExpression? selector, Type resultType, string functionName)
1463+
{
1464+
var selectExpression = (SelectExpression)source.QueryExpression;
1465+
if (selectExpression.IsDistinct
1466+
|| selectExpression.Limit != null
1467+
|| selectExpression.Offset != null)
1468+
{
1469+
return null;
1470+
}
1471+
1472+
if (selector != null)
1473+
{
1474+
source = TranslateSelect(source, selector);
1475+
}
1476+
1477+
if (!_subquery && resultType.IsNullableType())
1478+
{
1479+
// For nullable types, we want to return null from Max, Min, and Average, rather than throwing. See Issue #35094.
1480+
// Note that relational databases typically return null, which propagates. Cosmos will instead return no elements,
1481+
// and hence for Cosmos only we need to change no elements into null.
1482+
source = source.UpdateResultCardinality(ResultCardinality.SingleOrDefault);
1483+
}
1484+
1485+
var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember());
1486+
projection = _sqlExpressionFactory.Function(functionName, [projection], resultType, _typeMappingSource.FindMapping(resultType));
1487+
1488+
return AggregateResultShaper(source, projection, resultType);
1489+
}
1490+
15181491
private bool TryApplyPredicate(ShapedQueryExpression source, LambdaExpression predicate)
15191492
{
15201493
var select = (SelectExpression)source.QueryExpression;
@@ -1695,7 +1668,6 @@ private Expression RemapLambdaBody(ShapedQueryExpression shapedQueryExpression,
16951668
private static ShapedQueryExpression AggregateResultShaper(
16961669
ShapedQueryExpression source,
16971670
Expression projection,
1698-
bool throwOnNullResult,
16991671
Type resultType)
17001672
{
17011673
var selectExpression = (SelectExpression)source.QueryExpression;
@@ -1706,29 +1678,7 @@ private static ShapedQueryExpression AggregateResultShaper(
17061678
var nullableResultType = resultType.MakeNullable();
17071679
Expression shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), nullableResultType);
17081680

1709-
if (throwOnNullResult)
1710-
{
1711-
var resultVariable = Expression.Variable(nullableResultType, "result");
1712-
var returnValueForNull = resultType.IsNullableType()
1713-
? (Expression)Expression.Constant(null, resultType)
1714-
: Expression.Throw(
1715-
Expression.New(
1716-
typeof(InvalidOperationException).GetConstructors()
1717-
.Single(ci => ci.GetParameters().Length == 1),
1718-
Expression.Constant(CoreStrings.SequenceContainsNoElements)),
1719-
resultType);
1720-
1721-
shaper = Expression.Block(
1722-
new[] { resultVariable },
1723-
Expression.Assign(resultVariable, shaper),
1724-
Expression.Condition(
1725-
Expression.Equal(resultVariable, Expression.Default(nullableResultType)),
1726-
returnValueForNull,
1727-
resultType != resultVariable.Type
1728-
? Expression.Convert(resultVariable, resultType)
1729-
: resultVariable));
1730-
}
1731-
else if (resultType != shaper.Type)
1681+
if (resultType != shaper.Type)
17321682
{
17331683
shaper = Expression.Convert(shaper, resultType);
17341684
}

src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs

Lines changed: 10 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -518,7 +518,7 @@ private static ShapedQueryExpression CreateShapedQueryExpression(IEntityType ent
518518
ShapedQueryExpression source,
519519
LambdaExpression? selector,
520520
Type resultType)
521-
=> TranslateAggregateWithSelector(source, selector, QueryableMethods.GetAverageWithoutSelector, throwWhenEmpty: true, resultType);
521+
=> TranslateAggregateWithSelector(source, selector, QueryableMethods.GetAverageWithoutSelector, resultType);
522522

523523
/// <inheritdoc />
524524
protected override ShapedQueryExpression TranslateCast(ShapedQueryExpression source, Type resultType)
@@ -971,7 +971,7 @@ private SqlExpression CreateJoinPredicate(Expression outerKey, Expression innerK
971971
}
972972

973973
return TranslateAggregateWithSelector(
974-
source, selector, t => QueryableMethods.MaxWithoutSelector.MakeGenericMethod(t), throwWhenEmpty: true, resultType);
974+
source, selector, t => QueryableMethods.MaxWithoutSelector.MakeGenericMethod(t), resultType);
975975
}
976976

977977
/// <inheritdoc />
@@ -990,7 +990,7 @@ private SqlExpression CreateJoinPredicate(Expression outerKey, Expression innerK
990990
}
991991

992992
return TranslateAggregateWithSelector(
993-
source, selector, t => QueryableMethods.MinWithoutSelector.MakeGenericMethod(t), throwWhenEmpty: true, resultType);
993+
source, selector, t => QueryableMethods.MinWithoutSelector.MakeGenericMethod(t), resultType);
994994
}
995995

996996
/// <inheritdoc />
@@ -1241,7 +1241,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
12411241

12421242
/// <inheritdoc />
12431243
protected override ShapedQueryExpression? TranslateSum(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
1244-
=> TranslateAggregateWithSelector(source, selector, QueryableMethods.GetSumWithoutSelector, throwWhenEmpty: false, resultType);
1244+
=> TranslateAggregateWithSelector(source, selector, QueryableMethods.GetSumWithoutSelector, resultType);
12451245

12461246
/// <inheritdoc />
12471247
protected override ShapedQueryExpression? TranslateTake(ShapedQueryExpression source, Expression count)
@@ -1966,7 +1966,6 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape
19661966
ShapedQueryExpression source,
19671967
LambdaExpression? selectorLambda,
19681968
Func<Type, MethodInfo> methodGenerator,
1969-
bool throwWhenEmpty,
19701969
Type resultType)
19711970
{
19721971
var selectExpression = (SelectExpression)source.QueryExpression;
@@ -2012,48 +2011,13 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape
20122011
new Dictionary<ProjectionMember, Expression> { { new ProjectionMember(), translation } });
20132012

20142013
selectExpression.ClearOrdering();
2015-
Expression shaper;
2016-
2017-
if (throwWhenEmpty)
2018-
{
2019-
// Avg/Max/Min case.
2020-
// We always read nullable value
2021-
// If resultType is nullable then we always return null. Only non-null result shows throwing behavior.
2022-
// otherwise, if projection.Type is nullable then server result is passed through DefaultIfEmpty, hence we return default
2023-
// otherwise, server would return null only if it is empty, and we throw
2024-
var nullableResultType = resultType.MakeNullable();
2025-
shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), nullableResultType);
2026-
var resultVariable = Expression.Variable(nullableResultType, "result");
2027-
var returnValueForNull = resultType.IsNullableType()
2028-
? (Expression)Expression.Default(resultType)
2029-
: translation.Type.IsNullableType()
2030-
? Expression.Default(resultType)
2031-
: Expression.Throw(
2032-
Expression.New(
2033-
typeof(InvalidOperationException).GetConstructors()
2034-
.Single(ci => ci.GetParameters().Length == 1),
2035-
Expression.Constant(CoreStrings.SequenceContainsNoElements)),
2036-
resultType);
2037-
2038-
shaper = Expression.Block(
2039-
new[] { resultVariable },
2040-
Expression.Assign(resultVariable, shaper),
2041-
Expression.Condition(
2042-
Expression.Equal(resultVariable, Expression.Default(nullableResultType)),
2043-
returnValueForNull,
2044-
resultType != resultVariable.Type
2045-
? Expression.Convert(resultVariable, resultType)
2046-
: resultVariable));
2047-
}
2048-
else
2049-
{
2050-
// Sum case. Projection is always non-null. We read nullable value.
2051-
shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), translation.Type.MakeNullable());
20522014

2053-
if (resultType != shaper.Type)
2054-
{
2055-
shaper = Expression.Convert(shaper, resultType);
2056-
}
2015+
// Sum case. Projection is always non-null. We read nullable value.
2016+
Expression shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), translation.Type.MakeNullable());
2017+
2018+
if (resultType != shaper.Type)
2019+
{
2020+
shaper = Expression.Convert(shaper, resultType);
20572021
}
20582022

20592023
return source.UpdateShaperExpression(shaper);

test/EFCore.Cosmos.FunctionalTests/Query/AdHocMiscellaneousQueryCosmosTest.cs

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

4+
using System.ComponentModel.DataAnnotations.Schema;
5+
46
namespace Microsoft.EntityFrameworkCore.Query;
57

68
#nullable disable
@@ -50,6 +52,115 @@ public enum MemberType
5052

5153
#endregion 34911
5254

55+
#region 35094
56+
57+
// TODO: Move these tests to a better location. They require nullable properties with nulls in the database.
58+
59+
[ConditionalFact]
60+
public virtual async Task Min_over_value_type_containing_nulls()
61+
{
62+
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
63+
Assert.Null(await context.Set<Context35094.Product>().MinAsync(p => p.NullableVal));
64+
}
65+
66+
[ConditionalFact]
67+
public virtual async Task Min_over_value_type_containing_all_nulls()
68+
{
69+
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
70+
Assert.Null(await context.Set<Context35094.Product>().Where(e => e.NullableVal == null).MinAsync(p => p.NullableVal));
71+
}
72+
73+
[ConditionalFact]
74+
public virtual async Task Min_over_reference_type_containing_nulls()
75+
{
76+
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
77+
Assert.Null(await context.Set<Context35094.Product>().MinAsync(p => p.NullableRef));
78+
}
79+
80+
[ConditionalFact]
81+
public virtual async Task Min_over_reference_type_containing_all_nulls()
82+
{
83+
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
84+
Assert.Null(await context.Set<Context35094.Product>().Where(e => e.NullableRef == null).MinAsync(p => p.NullableRef));
85+
}
86+
87+
[ConditionalFact]
88+
public virtual async Task Min_over_reference_type_containing_no_data()
89+
{
90+
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
91+
Assert.Null(await context.Set<Context35094.Product>().Where(e => e.Id < 0).MinAsync(p => p.NullableRef));
92+
}
93+
94+
[ConditionalFact]
95+
public virtual async Task Max_over_value_type_containing_nulls()
96+
{
97+
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
98+
Assert.Equal(3.14, await context.Set<Context35094.Product>().MaxAsync(p => p.NullableVal));
99+
}
100+
101+
[ConditionalFact]
102+
public virtual async Task Max_over_value_type_containing_all_nulls()
103+
{
104+
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
105+
Assert.Null(await context.Set<Context35094.Product>().Where(e => e.NullableVal == null).MaxAsync(p => p.NullableVal));
106+
}
107+
108+
[ConditionalFact]
109+
public virtual async Task Max_over_reference_type_containing_nulls()
110+
{
111+
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
112+
Assert.Equal("Value", await context.Set<Context35094.Product>().MaxAsync(p => p.NullableRef));
113+
}
114+
115+
[ConditionalFact]
116+
public virtual async Task Max_over_reference_type_containing_all_nulls()
117+
{
118+
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
119+
Assert.Null(await context.Set<Context35094.Product>().Where(e => e.NullableRef == null).MaxAsync(p => p.NullableRef));
120+
}
121+
122+
[ConditionalFact]
123+
public virtual async Task Max_over_reference_type_containing_no_data()
124+
{
125+
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
126+
Assert.Null(await context.Set<Context35094.Product>().Where(e => e.Id < 0).MaxAsync(p => p.NullableRef));
127+
}
128+
129+
[ConditionalFact]
130+
public virtual async Task Average_over_value_type_containing_nulls()
131+
{
132+
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
133+
Assert.Null(await context.Set<Context35094.Product>().AverageAsync(p => p.NullableVal));
134+
}
135+
136+
[ConditionalFact]
137+
public virtual async Task Average_over_value_type_containing_all_nulls()
138+
{
139+
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
140+
Assert.Null(await context.Set<Context35094.Product>().Where(e => e.NullableVal == null).AverageAsync(p => p.NullableVal));
141+
}
142+
143+
protected class Context35094(DbContextOptions options) : DbContext(options)
144+
{
145+
public DbSet<Product> Products { get; set; }
146+
147+
protected override void OnModelCreating(ModelBuilder modelBuilder)
148+
=> modelBuilder.Entity<Product>().HasData(
149+
new Product { Id = 1, NullableRef = "Value", NullableVal = 3.14 },
150+
new Product { Id = 2, NullableVal = 3.14 },
151+
new Product { Id = 3, NullableRef = "Value" });
152+
153+
public class Product
154+
{
155+
[DatabaseGenerated(DatabaseGeneratedOption.None)]
156+
public int Id { get; set; }
157+
public double? NullableVal { get; set; }
158+
public string NullableRef { get; set; }
159+
}
160+
}
161+
162+
#endregion 35094
163+
53164
protected override string StoreName
54165
=> "AdHocMiscellaneousQueryTests";
55166

0 commit comments

Comments
 (0)