Skip to content

Commit 208438f

Browse files
committed
CSHARP-4744: Improve optimization of Count with predicate in Group.
1 parent a88dc2e commit 208438f

File tree

5 files changed

+113
-24
lines changed

5 files changed

+113
-24
lines changed

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/CountMethodToAggregationExpressionTranslator.cs

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,17 @@ public static AggregationExpression Translate(TranslationContext context, Method
6565
}
6666

6767
var predicateLambda = (LambdaExpression)arguments[1];
68-
var sourceItemSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer);
69-
var predicateTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, predicateLambda, sourceItemSerializer, asRoot: false);
70-
var filteredSourceAst = AstExpression.Filter(
71-
input: sourceTranslation.Ast,
72-
cond: predicateTranslation.Ast,
73-
@as: predicateLambda.Parameters[0].Name);
74-
ast = AstExpression.Size(filteredSourceAst);
68+
var predicateParameter = predicateLambda.Parameters[0];
69+
var predicateParameterSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer);
70+
var predicateSymbol = context.CreateSymbol(predicateParameter, predicateParameterSerializer);
71+
var predicateContext = context.WithSymbol(predicateSymbol);
72+
var predicateTranslation = ExpressionToAggregationExpressionTranslator.Translate(predicateContext, predicateLambda.Body);
73+
74+
ast = AstExpression.Sum(
75+
AstExpression.Map(
76+
input: sourceTranslation.Ast,
77+
@as: predicateSymbol.Var,
78+
@in: AstExpression.Cond(predicateTranslation.Ast, 1, 0)));
7579
}
7680
else
7781
{

tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4048Tests.cs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -686,8 +686,8 @@ public void IGrouping_Count_with_predicate_of_root_should_work()
686686
var stages = Translate(collection, queryable);
687687
var expectedStages = new[]
688688
{
689-
"{ $group : { _id : '$_id', _elements : { $push : '$$ROOT' } } }", // MQL could be optimized further
690-
"{ $project : { _id : '$_id', Result : { $size : { $filter : { input : '$_elements', as : 'e', cond : { $eq : ['$$e.X', 1] } } } } } }",
689+
"{ $group : { _id : '$_id', __agg0 : { $sum : { $cond : { if : { $eq : ['$X', 1] }, then : 1, else : 0 } } } } }",
690+
"{ $project : { _id : '$_id', Result : '$__agg0' } }",
691691
"{ $sort : { _id : 1 } }"
692692
};
693693
AssertStages(stages, expectedStages);
@@ -711,8 +711,8 @@ public void IGrouping_Count_with_predicate_of_scalar_should_work()
711711
var stages = Translate(collection, queryable);
712712
var expectedStages = new[]
713713
{
714-
"{ $group : { _id : '$_id', _elements : { $push : '$X' } } }", // MQL could be optimized further
715-
"{ $project : { _id : '$_id', Result : { $size : { $filter : { input : '$_elements', as : 'e', cond : { $eq : ['$$e', 1] } } } } } }",
714+
"{ $group : { _id : '$_id', __agg0 : { $sum : { $cond : { if : { $eq : ['$X', 1] }, then : 1, else : 0 } } } } }",
715+
"{ $project : { _id : '$_id', Result : '$__agg0' } }",
716716
"{ $sort : { _id : 1 } }"
717717
};
718718
AssertStages(stages, expectedStages);
@@ -1376,8 +1376,8 @@ public void IGrouping_LongCount_with_predicate_of_root_should_work()
13761376
var stages = Translate(collection, queryable);
13771377
var expectedStages = new[]
13781378
{
1379-
"{ $group : { _id : '$_id', _elements : { $push : '$$ROOT' } } }", // MQL could be optimized further
1380-
"{ $project : { _id : '$_id', Result : { $size : { $filter : { input : '$_elements', as : 'e', cond : { $eq : ['$$e.X', 1] } } } } } }",
1379+
"{ $group : { _id : '$_id', __agg0 : { $sum : { $cond : { if : { $eq : ['$X', 1] }, then : 1, else : 0 } } } } }",
1380+
"{ $project : { _id : '$_id', Result : '$__agg0' } }",
13811381
"{ $sort : { _id : 1 } }"
13821382
};
13831383
AssertStages(stages, expectedStages);
@@ -1401,8 +1401,8 @@ public void IGrouping_LongCount_with_predicate_of_scalar_should_work()
14011401
var stages = Translate(collection, queryable);
14021402
var expectedStages = new[]
14031403
{
1404-
"{ $group : { _id : '$_id', _elements : { $push : '$X' } } }", // MQL could be optimized further
1405-
"{ $project : { _id : '$_id', Result : { $size : { $filter : { input : '$_elements', as : 'e', cond : { $eq : ['$$e', 1] } } } } } }",
1404+
"{ $group : { _id : '$_id', __agg0 : { $sum : { $cond : { if : { $eq : ['$X', 1] }, then : 1, else : 0 } } } } }",
1405+
"{ $project : { _id : '$_id', Result : '$__agg0' } }",
14061406
"{ $sort : { _id : 1 } }"
14071407
};
14081408
AssertStages(stages, expectedStages);

tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4063Tests.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ public void GroupBy_with_bool_should_work()
3333
var stages = Translate(collection, queryable);
3434
var expectedStages = new[]
3535
{
36-
"{ $group : { _id : '$_id', _elements : { $push : '$$ROOT' } } }",
37-
"{ $project : { Value : { $size : { $filter : { input : '$_elements', as : 'x', cond : '$$x.Bool' } } }, _id : 0 } }"
36+
"{ $group : { _id : '$_id', __agg0 : { $sum : { $cond : { if : '$Bool', then : 1, else : 0 } } } } }",
37+
"{ $project : { Value : '$__agg0', _id : 0 } }"
3838
};
3939
AssertStages(stages, expectedStages);
4040
}
@@ -52,8 +52,8 @@ public void GroupBy_with_nullable_bool_should_work()
5252
var stages = Translate(collection, queryable);
5353
var expectedStages = new[]
5454
{
55-
"{ $group : { _id : '$_id', _elements : { $push : '$$ROOT' } } }",
56-
"{ $project : { Value : { $size : { $filter : { input : '$_elements', as : 'x', cond : { $and : [{ $ne : ['$$x.NullableBool', null] }, '$$x.NullableBool'] } } } }, _id : 0 } }"
55+
"{ $group : { _id : '$_id', __agg0 : { $sum : { $cond : { if : { $and : [{ $ne : ['$NullableBool', null] }, '$NullableBool'] }, then : 1, else : 0 } } } } }",
56+
"{ $project : { Value : '$__agg0', _id : 0 } }"
5757
};
5858
AssertStages(stages, expectedStages);
5959
}
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System.Linq;
17+
using FluentAssertions;
18+
using MongoDB.Bson;
19+
using MongoDB.Bson.Serialization.Attributes;
20+
using MongoDB.Driver.Linq;
21+
using MongoDB.TestHelpers.XunitExtensions;
22+
using Xunit;
23+
24+
namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira
25+
{
26+
public class CSharp4744Tests : Linq3IntegrationTest
27+
{
28+
[Theory]
29+
[ParameterAttributeData]
30+
public void ReplaceOne(
31+
[Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider)
32+
{
33+
var collection = GetCollection(linqProvider);
34+
35+
var queryable = collection.AsQueryable()
36+
.GroupBy(x => x.FooName, (x, y) => new Summary()
37+
{
38+
FooName = x,
39+
Count = y.Count(x => x.State == State.Running)
40+
});
41+
42+
var stages = Translate(collection, queryable);
43+
if (linqProvider == LinqProvider.V2)
44+
{
45+
AssertStages(
46+
stages,
47+
"{ $group: { _id : '$FooName', Count : { $sum : { $cond : [{ $eq : ['$State', 1] }, 1, 0] } } } }"); // note: 1 instead of "Running" is an error
48+
}
49+
else
50+
{
51+
AssertStages(
52+
stages,
53+
"{ $group: { _id : '$FooName', __agg0 : { $sum : { $cond : { if : { $eq : ['$State', 'Running'] }, then : 1, else : 0 } } } } }",
54+
"{ $project : { FooName : '$_id', Count : '$__agg0', _id : 0 } }");
55+
}
56+
}
57+
58+
private IMongoCollection<Foo> GetCollection(LinqProvider linqProvider)
59+
{
60+
var collection = GetCollection<Foo>("test", linqProvider);
61+
CreateCollection(collection);
62+
return collection;
63+
}
64+
65+
public enum State
66+
{
67+
Started,
68+
Running,
69+
Complete
70+
}
71+
72+
public class Foo
73+
{
74+
public string FooName;
75+
[BsonRepresentation(BsonType.String)]
76+
public State State;
77+
}
78+
79+
public class Summary
80+
{
81+
public string FooName;
82+
public int Count;
83+
}
84+
}
85+
}

tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,8 @@ public void Should_translate_count_with_a_predicate()
160160

161161
AssertStages(
162162
result.Stages,
163-
"{ $group : { _id : '$A', _elements : { $push : '$$ROOT' } } }",
164-
"{ $project : { Result : { $size : { $filter : { input : '$_elements', as : 'x', cond : { $ne : ['$$x.A', 'Awesome' ] } } } }, _id : 0 } }");
163+
"{ $group : { _id : '$A', __agg0 : { $sum : { $cond : { if : { $ne : ['$A', 'Awesome'] }, then : 1, else : 0 } } } } }",
164+
"{ $project : { Result : '$__agg0', _id : 0 } }");
165165

166166
result.Value.Result.Should().Be(1);
167167
}
@@ -182,12 +182,12 @@ public void Should_translate_where_with_a_predicate_and_count()
182182
[Fact]
183183
public void Should_translate_where_select_and_count_with_predicates()
184184
{
185-
var result = Group(x => x.A, g => new { Result = g.Select(x => new { A = x.A }).Count(x => x.A != "Awesome") });
185+
var result = Group(x => x.A, g => new { Result = g.Select(x => new { B = x.A }).Count(x => x.B != "Awesome") });
186186

187187
AssertStages(
188188
result.Stages,
189-
"{ $group : { _id : '$A', __agg0 : { $push : { A : '$A' } } } }",
190-
"{ $project : { Result : { $size : { $filter : { input : '$__agg0', as : 'x', cond : { $ne : ['$$x.A', 'Awesome'] } } } }, _id : 0 } }");
189+
"{ $group : { _id : '$A', __agg0 : { $push : { B : '$A' } } } }",
190+
"{ $project : { Result : { $sum : { $map : { input : '$__agg0', as : 'x', in : { $cond : { if : { $ne : ['$$x.B', 'Awesome'] }, then : 1, else : 0 } } } } }, _id : 0 } }");
191191

192192
result.Value.Result.Should().Be(1);
193193
}

0 commit comments

Comments
 (0)