Skip to content

Commit 22204b5

Browse files
authored
CSHARP-5675: Where possible, return null for average over the empty set (#1755)
* CSHARP-5675: Where possible, return null for average over the empty set Fixes [CSHARP-5675: Average over empty collection of nullable types throws](https://jira.mongodb.org/browse/CSHARP-5675) For the terminating operator, the fix is to use the SingleOrDefault finalizer when the type is nullable. * Review feedback.
1 parent 457fdd0 commit 22204b5

File tree

4 files changed

+243
-11
lines changed

4 files changed

+243
-11
lines changed

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/AverageMethodToExecutableQueryTranslator.cs

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ internal static class AverageMethodToExecutableQueryTranslator<TOutput>
3636
// private static fields
3737
private static readonly MethodInfo[] __averageMethods;
3838
private static readonly MethodInfo[] __averageWithSelectorMethods;
39-
private static readonly IExecutableQueryFinalizer<TOutput, TOutput> __finalizer = new SingleFinalizer<TOutput>();
39+
private static readonly IExecutableQueryFinalizer<TOutput, TOutput> __singleFinalizer = new SingleFinalizer<TOutput>();
40+
private static readonly IExecutableQueryFinalizer<TOutput, TOutput> __singleOrDefaultFinalizer = new SingleOrDefaultFinalizer<TOutput>();
4041

4142
// static constructor
4243
static AverageMethodToExecutableQueryTranslator()
@@ -138,11 +139,11 @@ public static ExecutableQuery<TDocument, TOutput> Translate<TDocument>(MongoQuer
138139

139140
IBsonSerializer outputValueSerializer = expression.GetResultType() switch
140141
{
141-
Type t when t == typeof(int) => new Int32Serializer(),
142-
Type t when t == typeof(long) => new Int64Serializer(),
143-
Type t when t == typeof(float) => new SingleSerializer(),
144-
Type t when t == typeof(double) => new DoubleSerializer(),
145-
Type t when t == typeof(decimal) => new DecimalSerializer(),
142+
Type t when t == typeof(int) => Int32Serializer.Instance,
143+
Type t when t == typeof(long) => Int64Serializer.Instance,
144+
Type t when t == typeof(float) => SingleSerializer.Instance,
145+
Type t when t == typeof(double) => DoubleSerializer.Instance,
146+
Type t when t == typeof(decimal) => DecimalSerializer.Instance,
146147
Type { IsConstructedGenericType: true } t when t.GetGenericTypeDefinition() == typeof(Nullable<>) => (IBsonSerializer)Activator.CreateInstance(typeof(NullableSerializer<>).MakeGenericType(t.GenericTypeArguments[0])),
147148
_ => throw new ExpressionNotSupportedException(expression)
148149
};
@@ -155,10 +156,14 @@ public static ExecutableQuery<TDocument, TOutput> Translate<TDocument>(MongoQuer
155156
AstStage.Project(AstProject.ExcludeId()),
156157
outputWrappedValueSerializer);
157158

159+
var returnType = expression.Type;
160+
158161
return ExecutableQuery.Create(
159162
provider,
160163
pipeline,
161-
__finalizer);
164+
returnType.IsNullable() // Note: numeric types are never reference types
165+
? __singleOrDefaultFinalizer
166+
: __singleFinalizer);
162167
}
163168

164169
throw new ExpressionNotSupportedException(expression);

tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AverageMethodToAggregationExpressionTranslatorTests.cs

Lines changed: 118 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
* limitations under the License.
1414
*/
1515

16+
using System;
1617
using System.Collections.Generic;
1718
using System.Linq;
1819
using FluentAssertions;
@@ -391,6 +392,112 @@ public void Average_with_nullable_longs_selector_should_work(
391392
results.Should().Equal(null, null, 4.0);
392393
}
393394

395+
[Theory]
396+
[ParameterAttributeData]
397+
public void Average_over_empty_set_of_nullable_values_should_work(
398+
[Values(false, true)] bool withNestedAsQueryable)
399+
{
400+
var collection = Fixture.Collection;
401+
402+
var queryable = withNestedAsQueryable ?
403+
collection.AsQueryable().Select(x => x.EmptyNullableDecimals.AsQueryable().Average()) :
404+
collection.AsQueryable().Select(x => x.EmptyNullableDecimals.Average());
405+
406+
var stages = Translate(collection, queryable);
407+
AssertStages(stages, "{ $project : { _v : { $avg : '$EmptyNullableDecimals' }, _id : 0 } }");
408+
409+
var results = queryable.ToList();
410+
results.Should().Equal(null, null, null);
411+
}
412+
413+
[Theory]
414+
[ParameterAttributeData]
415+
public void Average_with_selector_over_empty_set_of_nullable_values_should_work(
416+
[Values(false, true)] bool withNestedAsQueryable)
417+
{
418+
var collection = Fixture.Collection;
419+
420+
var queryable = withNestedAsQueryable ?
421+
collection.AsQueryable().Select(x => x.EmptyNullableDecimals.AsQueryable().Average(x => x * 2.0M)) :
422+
collection.AsQueryable().Select(x => x.EmptyNullableDecimals.Average(x => x * 2.0M));
423+
424+
var stages = Translate(collection, queryable);
425+
AssertStages(stages, "{ $project : { _v : { $avg : { $map : { input : '$EmptyNullableDecimals', as : 'x', in : { $multiply : ['$$x', NumberDecimal(2)] } } } }, _id : 0 } }");
426+
427+
var results = queryable.ToList();
428+
results.Should().Equal(null, null, null);
429+
}
430+
431+
[Theory]
432+
[ParameterAttributeData]
433+
public void Average_over_empty_set_of_non_nullable_values_should_throw(
434+
[Values(false, true)] bool withNestedAsQueryable)
435+
{
436+
var collection = Fixture.Collection;
437+
438+
var queryable = withNestedAsQueryable ?
439+
collection.AsQueryable().Select(x => x.EmptyDecimals.AsQueryable().Average()) :
440+
collection.AsQueryable().Select(x => x.EmptyDecimals.Average());
441+
442+
var stages = Translate(collection, queryable);
443+
AssertStages(stages, "{ $project : { _v : { $avg : '$EmptyDecimals' }, _id : 0 } }");
444+
445+
Assert.Throws<FormatException>(() => queryable.ToList());
446+
}
447+
448+
[Theory]
449+
[ParameterAttributeData]
450+
public void Average_with_selector_over_empty_set_of_non_nullable_values_should_throw(
451+
[Values(false, true)] bool withNestedAsQueryable)
452+
{
453+
var collection = Fixture.Collection;
454+
455+
var queryable = withNestedAsQueryable ?
456+
collection.AsQueryable().Select(x => x.EmptyDecimals.AsQueryable().Average(x => x * 2.0M)) :
457+
collection.AsQueryable().Select(x => x.EmptyDecimals.Average(x => x * 2.0M));
458+
459+
var stages = Translate(collection, queryable);
460+
AssertStages(stages, "{ $project : { _v : { $avg : { $map : { input : '$EmptyDecimals', as : 'x', in : { $multiply : ['$$x', NumberDecimal(2)] } } } }, _id : 0 } }");
461+
462+
Assert.Throws<FormatException>(() => queryable.ToList());
463+
}
464+
465+
[Theory]
466+
[ParameterAttributeData]
467+
public void Average_over_empty_set_of_non_nullable_values_cast_to_nullable_should_work(
468+
[Values(false, true)] bool withNestedAsQueryable)
469+
{
470+
var collection = Fixture.Collection;
471+
472+
var queryable = withNestedAsQueryable ?
473+
collection.AsQueryable().Select(x => x.EmptyDecimals.Select(e => (decimal?)e).AsQueryable().Average()) :
474+
collection.AsQueryable().Select(x => x.EmptyDecimals.Select(e => (decimal?)e).Average());
475+
476+
var stages = Translate(collection, queryable);
477+
AssertStages(stages, "{ $project : { _v : { $avg : { $map : { input : '$EmptyDecimals', as : 'e', in : '$$e' } } }, _id : 0 } }");
478+
479+
var results = queryable.ToList();
480+
results.Should().Equal(null, null, null);
481+
}
482+
483+
[Theory]
484+
[ParameterAttributeData]
485+
public void Average_with_selector_over_empty_set_of_non_nullable_values_cast_to_nullable_should_work(
486+
[Values(false, true)] bool withNestedAsQueryable)
487+
{
488+
var collection = Fixture.Collection;
489+
490+
var queryable = withNestedAsQueryable ?
491+
collection.AsQueryable().Select(x => x.EmptyDecimals.Select(e => (decimal?)e).AsQueryable().Average(x => x * 2.0M)) :
492+
collection.AsQueryable().Select(x => x.EmptyDecimals.Select(e => (decimal?)e).Average(x => x * 2.0M));
493+
494+
var stages = Translate(collection, queryable);
495+
AssertStages(stages, "{ $project : { _v : { $avg : { $map : { input : { $map : { input : '$EmptyDecimals', as : 'e', in : '$$e' } }, as : 'x', in : { $multiply : ['$$x', { '$numberDecimal' : '2.0' }] } } } }, _id : 0 } }");
496+
497+
var results = queryable.ToList();
498+
results.Should().Equal(null, null, null);
499+
}
500+
394501
public class C
395502
{
396503
public int Id { get; set; }
@@ -404,6 +511,8 @@ public class C
404511
public float?[] NullableFloats { get; set; }
405512
public int?[] NullableInts { get; set; }
406513
public long?[] NullableLongs { get; set; }
514+
[BsonRepresentation(BsonType.Decimal128)] public decimal[] EmptyDecimals { get; set; }
515+
[BsonRepresentation(BsonType.Decimal128)] public decimal?[] EmptyNullableDecimals { get; set; }
407516
}
408517

409518
public sealed class ClassFixture : MongoCollectionFixture<C>
@@ -422,7 +531,9 @@ public sealed class ClassFixture : MongoCollectionFixture<C>
422531
NullableDoubles = new double?[0] { },
423532
NullableFloats = new float?[0] { },
424533
NullableInts = new int?[0] { },
425-
NullableLongs = new long?[0] { }
534+
NullableLongs = new long?[0] { },
535+
EmptyDecimals = [],
536+
EmptyNullableDecimals = []
426537
},
427538
new C
428539
{
@@ -436,7 +547,9 @@ public sealed class ClassFixture : MongoCollectionFixture<C>
436547
NullableDoubles = new double?[] { null },
437548
NullableFloats = new float?[] { null },
438549
NullableInts = new int?[] { null },
439-
NullableLongs = new long?[] { null }
550+
NullableLongs = new long?[] { null },
551+
EmptyDecimals = [],
552+
EmptyNullableDecimals = []
440553
},
441554
new C
442555
{
@@ -450,7 +563,9 @@ public sealed class ClassFixture : MongoCollectionFixture<C>
450563
NullableDoubles = new double?[] { null, 1.0, 2.0, 3.0 },
451564
NullableFloats = new float?[] { null, 1.0F, 2.0F, 3.0F },
452565
NullableInts = new int?[] { null, 1, 2, 3 },
453-
NullableLongs = new long?[] { null, 1L, 2L, 3L }
566+
NullableLongs = new long?[] { null, 1L, 2L, 3L },
567+
EmptyDecimals = [],
568+
EmptyNullableDecimals = []
454569
}
455570
];
456571
}

tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/IntegrationTestBase.cs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,11 @@ private void InsertSecond()
253253
O = new List<long> { 100, 200, 300 },
254254
P = 1.1,
255255
U = -1.234565723762724332233489m,
256-
Z = 10
256+
Z = 10,
257+
NullableW = 8,
258+
NullableX = 9,
259+
NullableY = 10,
260+
NullableZ = 11
257261
};
258262
__collection.InsertOne(root);
259263
}
@@ -333,6 +337,14 @@ public class Root : IRoot
333337
public int Y { get; set; }
334338

335339
public decimal Z { get; set; }
340+
341+
public double? NullableW { get; set; }
342+
343+
public long? NullableX { get; set; }
344+
345+
public int? NullableY { get; set; }
346+
347+
public decimal? NullableZ { get; set; }
336348
}
337349

338350
public class RootDescended : Root

tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/MongoQueryableTests.cs

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,106 @@ public async Task AverageAsync_with_selector()
135135
result.Should().Be(61);
136136
}
137137

138+
[Fact]
139+
public void Average_on_empty_set()
140+
{
141+
Action action = () => CreateQuery().Where(x => x.A == "__dummy__").Select(x => x.W).Average();
142+
143+
action.ShouldThrow<InvalidOperationException>().WithMessage("Sequence contains no elements");
144+
}
145+
146+
[Fact]
147+
public void Average_on_empty_set_with_selector()
148+
{
149+
Action action = () => CreateQuery().Where(x => x.A == "__dummy__").Average(x => x.X);
150+
151+
action.ShouldThrow<InvalidOperationException>().WithMessage("Sequence contains no elements");
152+
}
153+
154+
[Fact]
155+
public void AverageAsync_on_empty_set()
156+
{
157+
var subject = CreateQuery().Where(x => x.A == "__dummy__").Select(x => x.Y).AverageAsync();
158+
159+
subject.Awaiting(async q => await q)
160+
.ShouldThrow<InvalidOperationException>()
161+
.WithMessage("Sequence contains no elements");
162+
}
163+
164+
[Fact]
165+
public void AverageAsync_on_empty_set_with_selector()
166+
{
167+
var subject = CreateQuery().Where(x => x.A == "__dummy__").AverageAsync(x => x.Z);
168+
169+
subject.Awaiting(async q => await q)
170+
.ShouldThrow<InvalidOperationException>()
171+
.WithMessage("Sequence contains no elements");
172+
}
173+
174+
[Fact]
175+
public void Average_on_nullable_empty_set()
176+
{
177+
var result = CreateQuery().Where(x => x.A == "__dummy__").Select(x => x.NullableW).Average();
178+
179+
result.Should().Be(null);
180+
}
181+
182+
[Fact]
183+
public void Average_on_nullable_empty_set_with_selector()
184+
{
185+
var result = CreateQuery().Where(x => x.A == "__dummy__").Average(x => x.NullableX);
186+
187+
result.Should().Be(null);
188+
}
189+
190+
[Fact]
191+
public async Task AverageAsync_on_nullable_empty_set()
192+
{
193+
var result = await CreateQuery().Where(x => x.A == "__dummy__").Select(x => x.NullableY).AverageAsync();
194+
195+
result.Should().Be(null);
196+
}
197+
198+
[Fact]
199+
public async Task AverageAsync_on_nullable_empty_set_with_selector()
200+
{
201+
var result = await CreateQuery().Where(x => x.A == "__dummy__").AverageAsync(x => x.NullableZ);
202+
203+
result.Should().Be(null);
204+
}
205+
206+
[Fact]
207+
public void Average_on_empty_set_cast_to_nullable()
208+
{
209+
var result = CreateQuery().Where(x => x.A == "__dummy__").Select(x => (double?)x.W).Average();
210+
211+
result.Should().Be(null);
212+
}
213+
214+
[Fact]
215+
public void Average_on_empty_set_cast_to_nullable_with_selector()
216+
{
217+
var result = CreateQuery().Where(x => x.A == "__dummy__").Average(x => (long?)x.X);
218+
219+
result.Should().Be(null);
220+
}
221+
222+
[Fact]
223+
public async Task AverageAsync_on_empty_set_cast_to_nullable()
224+
{
225+
var result = await CreateQuery().Where(x => x.A == "__dummy__").Select(x => (int?)x.Y).AverageAsync();
226+
227+
result.Should().Be(null);
228+
}
229+
230+
[Fact]
231+
public async Task AverageAsync_on_empty_set_cast_to_nullable_with_selector()
232+
{
233+
var result = await CreateQuery().Where(x => x.A == "__dummy__").AverageAsync(x => (decimal?)x.Z);
234+
235+
result.Should().Be(null);
236+
}
237+
138238
[Fact]
139239
public void GroupBy_combined_with_a_previous_embedded_pipeline()
140240
{

0 commit comments

Comments
 (0)