Skip to content

Commit 8fb4450

Browse files
CSHARP-1974: Fix Count with a predicate in aggregate group.
1 parent 08f0112 commit 8fb4450

File tree

2 files changed

+66
-7
lines changed

2 files changed

+66
-7
lines changed

src/MongoDB.Driver/Linq/Processors/AccumulatorBinder.cs

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ private bool TryGetAccumulatorTypeAndArgument(PipelineExpression node, out Accum
134134
if (resultOperator is CountResultOperator)
135135
{
136136
accumulatorType = AccumulatorType.Sum;
137-
argument = Expression.Constant(1);
137+
argument = GetCountAccumulatorArgument(node.Source);
138138
return true;
139139
}
140140
if (resultOperator is FirstResultOperator)
@@ -198,6 +198,29 @@ private bool TryGetAccumulatorTypeAndArgument(PipelineExpression node, out Accum
198198
return false;
199199
}
200200

201+
private Expression GetCountAccumulatorArgument(Expression node)
202+
{
203+
var where = node as WhereExpression;
204+
if (where != null)
205+
{
206+
return Expression.IfThenElse(where.Predicate, Expression.Constant(1), Expression.Constant(0));
207+
}
208+
209+
var document = node as DocumentExpression;
210+
if (document != null)
211+
{
212+
return Expression.Constant(1);
213+
}
214+
215+
var select = node as SelectExpression;
216+
if (select != null)
217+
{
218+
return Expression.Constant(1);
219+
}
220+
221+
throw new NotSupportedException();
222+
}
223+
201224
private Expression GetAccumulatorArgument(Expression node)
202225
{
203226
// we are looking for a Map
@@ -210,4 +233,4 @@ private Expression GetAccumulatorArgument(Expression node)
210233
throw new NotSupportedException();
211234
}
212235
}
213-
}
236+
}

tests/MongoDB.Driver.Tests/Linq/Translators/AggregateGroupTranslatorTests.cs

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,9 @@
1717
using System.Collections.Generic;
1818
using System.Linq;
1919
using System.Linq.Expressions;
20-
using System.Threading.Tasks;
2120
using FluentAssertions;
2221
using MongoDB.Bson;
2322
using MongoDB.Bson.Serialization;
24-
using MongoDB.Bson.TestHelpers.XunitExtensions;
25-
using MongoDB.Driver;
26-
using MongoDB.Driver.Core;
2723
using MongoDB.Driver.Core.TestHelpers.XunitExtensions;
2824
using MongoDB.Driver.Linq;
2925
using MongoDB.Driver.Linq.Translators;
@@ -125,6 +121,46 @@ public void Should_translate_count()
125121
result.Value.Result.Should().Be(1);
126122
}
127123

124+
[Fact]
125+
public void Should_translate_count_with_a_predicate()
126+
{
127+
var result = Group(x => x.A, g => new { Result = g.Count(x => x.A != "Awesome") });
128+
129+
result.Projection.Should().Be("{ \"_id\" : \"$A\", \"Result\" : { \"$sum\" : { \"$cond\" : [{ \"$ne\" : [\"$A\", \"Awesome\"] }, 1, 0] } } }");
130+
131+
result.Value.Result.Should().Be(1);
132+
}
133+
134+
[Fact]
135+
public void Should_translate_where_with_a_predicate_and_count()
136+
{
137+
var result = Group(x => x.A, g => new { Result = g.Where(x => x.A != "Awesome").Count() });
138+
139+
result.Projection.Should().Be("{ \"_id\" : \"$A\", \"Result\" : { \"$sum\" : { \"$cond\" : [{ \"$ne\" : [\"$A\", \"Awesome\"] }, 1, 0] } } }");
140+
141+
result.Value.Result.Should().Be(1);
142+
}
143+
144+
[Fact]
145+
public void Should_translate_where_select_and_count_with_predicates()
146+
{
147+
var result = Group(x => x.A, g => new { Result = g.Select(x => new { A = x.A }).Count(x => x.A != "Awesome") });
148+
149+
result.Projection.Should().Be("{ \"_id\" : \"$A\", \"Result\" : { \"$sum\" : { \"$cond\" : [{ \"$ne\" : [\"$A\", \"Awesome\"] }, 1, 0] } } }");
150+
151+
result.Value.Result.Should().Be(1);
152+
}
153+
154+
[Fact]
155+
public void Should_translate_where_select_with_predicate_and_count()
156+
{
157+
var result = Group(x => x.A, g => new { Result = g.Select(x => new { A = x.A }).Count() });
158+
159+
result.Projection.Should().Be("{ \"_id\" : \"$A\", \"Result\" : { \"$sum\" : 1 } }");
160+
161+
result.Value.Result.Should().Be(1);
162+
}
163+
128164
[Fact]
129165
public void Should_translate_long_count()
130166
{
@@ -394,4 +430,4 @@ private class ProjectedResult<T>
394430
public T Value;
395431
}
396432
}
397-
}
433+
}

0 commit comments

Comments
 (0)