Skip to content

Commit d8a64df

Browse files
committed
CSHARP-5675: Where possible, return null for average over the empty set
Fixes [CSHARP-5675: Average over empty collection of nullable types throws](https://jira.mongodb.org/browse/CSHARP-5675) For the terminating operator, the fix is to use the SingleOrDefault finalizer when the type is nullable.
1 parent c96b803 commit d8a64df

File tree

3 files changed

+125
-8
lines changed

3 files changed

+125
-8
lines changed

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/AverageMethodToExecutableQueryTranslator.cs

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ internal static class AverageMethodToExecutableQueryTranslator<TOutput>
3636
// private static fields
3737
private static readonly MethodInfo[] __averageMethods;
3838
private static readonly MethodInfo[] __averageWithSelectorMethods;
39-
private static readonly IExecutableQueryFinalizer<TOutput, TOutput> __finalizer = new SingleFinalizer<TOutput>();
39+
private static readonly IExecutableQueryFinalizer<TOutput, TOutput> __singleFinalizer = new SingleFinalizer<TOutput>();
40+
private static readonly IExecutableQueryFinalizer<TOutput, TOutput> __singleOrDefaultFinalizer = new SingleOrDefaultFinalizer<TOutput>();
4041

4142
// static constructor
4243
static AverageMethodToExecutableQueryTranslator()
@@ -138,11 +139,11 @@ public static ExecutableQuery<TDocument, TOutput> Translate<TDocument>(MongoQuer
138139

139140
IBsonSerializer outputValueSerializer = expression.GetResultType() switch
140141
{
141-
Type t when t == typeof(int) => new Int32Serializer(),
142-
Type t when t == typeof(long) => new Int64Serializer(),
143-
Type t when t == typeof(float) => new SingleSerializer(),
144-
Type t when t == typeof(double) => new DoubleSerializer(),
145-
Type t when t == typeof(decimal) => new DecimalSerializer(),
142+
Type t when t == typeof(int) => Int32Serializer.Instance,
143+
Type t when t == typeof(long) => Int64Serializer.Instance,
144+
Type t when t == typeof(float) => SingleSerializer.Instance,
145+
Type t when t == typeof(double) => DoubleSerializer.Instance,
146+
Type t when t == typeof(decimal) => DecimalSerializer.Instance,
146147
Type { IsConstructedGenericType: true } t when t.GetGenericTypeDefinition() == typeof(Nullable<>) => (IBsonSerializer)Activator.CreateInstance(typeof(NullableSerializer<>).MakeGenericType(t.GenericTypeArguments[0])),
147148
_ => throw new ExpressionNotSupportedException(expression)
148149
};
@@ -155,10 +156,14 @@ public static ExecutableQuery<TDocument, TOutput> Translate<TDocument>(MongoQuer
155156
AstStage.Project(AstProject.ExcludeId()),
156157
outputWrappedValueSerializer);
157158

159+
var returnType = expression.Type;
160+
158161
return ExecutableQuery.Create(
159162
provider,
160163
pipeline,
161-
__finalizer);
164+
!returnType.IsValueType || returnType.IsNullable()
165+
? __singleOrDefaultFinalizer
166+
: __singleFinalizer);
162167
}
163168

164169
throw new ExpressionNotSupportedException(expression);

tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/IntegrationTestBase.cs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,11 @@ private void InsertSecond()
253253
O = new List<long> { 100, 200, 300 },
254254
P = 1.1,
255255
U = -1.234565723762724332233489m,
256-
Z = 10
256+
Z = 10,
257+
NullableW = 8,
258+
NullableX = 9,
259+
NullableY = 10,
260+
NullableZ = 11
257261
};
258262
__collection.InsertOne(root);
259263
}
@@ -333,6 +337,14 @@ public class Root : IRoot
333337
public int Y { get; set; }
334338

335339
public decimal Z { get; set; }
340+
341+
public double? NullableW { get; set; }
342+
343+
public long? NullableX { get; set; }
344+
345+
public int? NullableY { get; set; }
346+
347+
public decimal? NullableZ { get; set; }
336348
}
337349

338350
public class RootDescended : Root

tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/MongoQueryableTests.cs

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,106 @@ public async Task AverageAsync_with_selector()
135135
result.Should().Be(61);
136136
}
137137

138+
[Fact]
139+
public void Average_on_empty_set()
140+
{
141+
Action action = () => CreateQuery().Where(x => x.A == "__dummy__").Select(x => x.W).Average();
142+
143+
action.ShouldThrow<InvalidOperationException>().WithMessage("Sequence contains no elements");
144+
}
145+
146+
[Fact]
147+
public void Average_on_empty_set_with_selector()
148+
{
149+
Action action = () => CreateQuery().Where(x => x.A == "__dummy__").Average(x => x.X);
150+
151+
action.ShouldThrow<InvalidOperationException>().WithMessage("Sequence contains no elements");
152+
}
153+
154+
[Fact]
155+
public void AverageAsync_on_empty_set()
156+
{
157+
var subject = CreateQuery().Where(x => x.A == "__dummy__").Select(x => x.Y).AverageAsync();
158+
159+
subject.Awaiting(async q => await q)
160+
.ShouldThrow<InvalidOperationException>()
161+
.WithMessage("Sequence contains no elements");
162+
}
163+
164+
[Fact]
165+
public void AverageAsync_on_empty_set_with_selector()
166+
{
167+
var subject = CreateQuery().Where(x => x.A == "__dummy__").AverageAsync(x => x.Z);
168+
169+
subject.Awaiting(async q => await q)
170+
.ShouldThrow<InvalidOperationException>()
171+
.WithMessage("Sequence contains no elements");
172+
}
173+
174+
[Fact]
175+
public void Average_on_nullable_empty_set()
176+
{
177+
var result = CreateQuery().Where(x => x.A == "__dummy__").Select(x => x.NullableW).Average();
178+
179+
result.Should().Be(null);
180+
}
181+
182+
[Fact]
183+
public void Average_on_nullable_empty_set_with_selector()
184+
{
185+
var result = CreateQuery().Where(x => x.A == "__dummy__").Average(x => x.NullableX);
186+
187+
result.Should().Be(null);
188+
}
189+
190+
[Fact]
191+
public async Task AverageAsync_on_nullable_empty_set()
192+
{
193+
var result = await CreateQuery().Where(x => x.A == "__dummy__").Select(x => x.NullableY).AverageAsync();
194+
195+
result.Should().Be(null);
196+
}
197+
198+
[Fact]
199+
public async Task AverageAsync_on_nullable_empty_set_with_selector()
200+
{
201+
var result = await CreateQuery().Where(x => x.A == "__dummy__").AverageAsync(x => x.NullableZ);
202+
203+
result.Should().Be(null);
204+
}
205+
206+
[Fact]
207+
public void Average_on_empty_set_cast_to_nullable()
208+
{
209+
var result = CreateQuery().Where(x => x.A == "__dummy__").Select(x => (double?)x.W).Average();
210+
211+
result.Should().Be(null);
212+
}
213+
214+
[Fact]
215+
public void Average_on_empty_set_cast_to_nullable_with_selector()
216+
{
217+
var result = CreateQuery().Where(x => x.A == "__dummy__").Average(x => (long?)x.X);
218+
219+
result.Should().Be(null);
220+
}
221+
222+
[Fact]
223+
public async Task AverageAsync_on_empty_set_cast_to_nullable()
224+
{
225+
var result = await CreateQuery().Where(x => x.A == "__dummy__").Select(x => (int?)x.Y).AverageAsync();
226+
227+
result.Should().Be(null);
228+
}
229+
230+
[Fact]
231+
public async Task AverageAsync_on_empty_set_cast_to_nullable_with_selector()
232+
{
233+
var result = await CreateQuery().Where(x => x.A == "__dummy__").AverageAsync(x => (decimal?)x.Z);
234+
235+
result.Should().Be(null);
236+
}
237+
138238
[Fact]
139239
public void GroupBy_combined_with_a_previous_embedded_pipeline()
140240
{

0 commit comments

Comments
 (0)