Skip to content

Commit a322f26

Browse files
committed
CSHARP-4863: Refactor AggregateMethodToAggregationExpressionTranslator to produce simpler translations when possible.
1 parent 1bd939d commit a322f26

File tree

4 files changed

+92
-17
lines changed

4 files changed

+92
-17
lines changed

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AggregateMethodToAggregationExpressionTranslator.cs

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,16 +53,21 @@ public static AggregationExpression Translate(TranslationContext context, Method
5353
var funcContext = context.WithSymbols(accumulatorSymbol, itemSymbol);
5454
var funcTranslation = ExpressionToAggregationExpressionTranslator.Translate(funcContext, funcLambda.Body);
5555

56-
var sourceVar = AstExpression.Var("source");
56+
var (sourceVarBinding, sourceAst) = AstExpression.UseVarIfNotSimple("source", sourceTranslation.Ast);
57+
var seedVar = AstExpression.Var("seed");
58+
var restVar = AstExpression.Var("rest");
5759
var ast = AstExpression.Let(
58-
var: AstExpression.VarBinding(sourceVar, sourceTranslation.Ast),
59-
@in: AstExpression.Cond(
60-
@if: AstExpression.Lte(AstExpression.Size(sourceVar), 1),
61-
@then: AstExpression.ArrayElemAt(sourceVar, 0),
62-
@else: AstExpression.Reduce(
63-
input: AstExpression.Slice(sourceVar, 1, int.MaxValue),
64-
initialValue: AstExpression.ArrayElemAt(sourceVar, 0),
65-
@in: funcTranslation.Ast)));
60+
var: sourceVarBinding,
61+
@in: AstExpression.Let(
62+
var1: AstExpression.VarBinding(seedVar, AstExpression.ArrayElemAt(sourceAst, 0)),
63+
var2: AstExpression.VarBinding(restVar, AstExpression.Slice(sourceAst, 1, int.MaxValue)),
64+
@in: AstExpression.Cond(
65+
@if: AstExpression.Eq(AstExpression.Size(restVar), 0),
66+
@then: seedVar,
67+
@else: AstExpression.Reduce(
68+
input: restVar,
69+
initialValue: seedVar,
70+
@in: funcTranslation.Ast))));
6671

6772
return new AggregationExpression(expression, ast, itemSerializer);
6873
}

tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4048Tests.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ public void IGrouping_Aggregate_with_func_of_root_should_return_expected_result(
137137
var expectedStages = new[]
138138
{
139139
"{ $group : { _id : '$_id', _elements : { $push: '$$ROOT' } } }",
140-
"{ $project : { _id : '$_id', Result : { $let : { vars : { source : '$_elements' }, in : { $cond : { if : { $lte : [{ $size : '$$source' }, 1] }, then : { $arrayElemAt : ['$$source', 0] }, else : { $reduce : { input : { $slice : ['$$source', 1, 2147483647] }, initialValue : { $arrayElemAt : ['$$source', 0] }, in : '$$value' } } } } } } } }",
140+
"{ $project : { _id : '$_id', Result : { $let : { vars : { seed : { $arrayElemAt : ['$_elements', 0] }, rest : { $slice : ['$_elements', 1, 2147483647] } }, in : { $cond : { if : { $eq : [{ $size : '$$rest' }, 0] }, then : '$$seed', else : { $reduce : { input : '$$rest', initialValue : '$$seed', in : '$$value' } } } } } } } }",
141141
"{ $sort : { _id : 1 } }"
142142
};
143143
AssertStages(stages, expectedStages);
@@ -162,7 +162,7 @@ public void IGrouping_Aggregate_with_func_of_scalar_should_return_expected_resul
162162
var expectedStages = new[]
163163
{
164164
"{ $group : { _id : '$_id', _elements : { $push: '$X' } } }",
165-
"{ $project : { _id : '$_id', Result : { $let : { vars : { source : '$_elements' }, in : { $cond : { if : { $lte : [{ $size : '$$source' }, 1] }, then : { $arrayElemAt : ['$$source', 0] }, else : { $reduce : { input : { $slice : ['$$source', 1, 2147483647] }, initialValue : { $arrayElemAt : ['$$source', 0] }, in : '$$value' } } } } } } } }",
165+
"{ $project : { _id : '$_id', Result : { $let : { vars : { seed : { $arrayElemAt : ['$_elements', 0] }, rest : { $slice : ['$_elements', 1, 2147483647] } }, in : { $cond : { if : { $eq : [{ $size : '$$rest' }, 0] }, then : '$$seed', else : { $reduce : { input : '$$rest', initialValue : '$$seed', in : '$$value' } } } } } } } }",
166166
"{ $sort : { _id : 1 } }"
167167
};
168168
AssertStages(stages, expectedStages);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System.Linq;
17+
using FluentAssertions;
18+
using MongoDB.Bson;
19+
using MongoDB.Driver.Linq;
20+
using MongoDB.TestHelpers.XunitExtensions;
21+
using Xunit;
22+
23+
namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators
24+
{
25+
public class AggregateMethodToAggregationExpressionTranslatorTests : Linq3IntegrationTest
26+
{
27+
[Theory]
28+
[ParameterAttributeData]
29+
public void Aggregate_with_func_should_work(
30+
[Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider)
31+
{
32+
var collection = CreateCollection(linqProvider);
33+
34+
var queryable = collection.AsQueryable()
35+
.Select(x => x.A.Aggregate((x, y) => x * y));
36+
37+
var stages = Translate(collection, queryable);
38+
var results = queryable.ToList();
39+
40+
if (linqProvider == LinqProvider.V2)
41+
{
42+
AssertStages(stages, "{ $project : { __fld0 : { $reduce : { input : '$A', initialValue : 0, in : { $multiply : ['$$value', '$$this'] } } }, _id : 0 } }");
43+
results.Should().Equal(0, 0, 0, 0); // LINQ2 results are wrong
44+
}
45+
else
46+
{
47+
AssertStages(stages, "{ $project : { _v : { $let : { vars : { seed : { $arrayElemAt : ['$A', 0] }, rest : { $slice : ['$A', 1, 2147483647] } }, in : { $cond : { if : { $eq : [{ $size : '$$rest' }, 0] }, then : '$$seed', else : { $reduce : { input : '$$rest', initialValue : '$$seed', in : { $multiply : ['$$value', '$$this'] } } } } } } }, _id : 0 } }");
48+
results.Should().Equal(0, 1, 2, 6); // C# throws exception on empty sequence but MQL returns 0
49+
}
50+
}
51+
52+
private IMongoCollection<C> CreateCollection(LinqProvider linqProvider)
53+
{
54+
var collection = GetCollection<C>("test", linqProvider);
55+
CreateCollection(
56+
GetCollection<BsonDocument>("test"),
57+
BsonDocument.Parse("{ _id : 0, A : [] }"),
58+
BsonDocument.Parse("{ _id : 1, A : [1] }"),
59+
BsonDocument.Parse("{ _id : 2, A : [1, 2] }"),
60+
BsonDocument.Parse("{ _id : 3, A : [1, 2, 3] }"));
61+
return collection;
62+
}
63+
64+
private class C
65+
{
66+
public int Id { get; set; }
67+
public int[] A { get; set; }
68+
}
69+
}
70+
}

tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateProjectTranslatorTests.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,20 +1003,20 @@ public void Should_translate_reduce()
10031003
{
10041004
$let :
10051005
{
1006-
vars : { source : '$M' },
1006+
vars : { seed : { $arrayElemAt : ['$M', 0] }, rest : { $slice : ['$M', 1, 2147483647] } },
10071007
in :
10081008
{
10091009
$cond:
10101010
{
1011-
if : { $lte : [ { $size : '$$source' }, 1 ] },
1012-
then: { $arrayElemAt : [ '$$source', 0 ] },
1011+
if : { $eq : [{ $size : '$$rest' }, 0] },
1012+
then: '$$seed',
10131013
else :
10141014
{
10151015
$reduce :
10161016
{
1017-
input : { $slice : [ '$$source', 1, 2147483647 ] },
1018-
initialValue : { $arrayElemAt : [ '$$source', 0 ] },
1019-
in : { $add : [ '$$value', '$$this' ] }
1017+
input : '$$rest',
1018+
initialValue : '$$seed',
1019+
in : { $add : ['$$value', '$$this'] }
10201020
}
10211021
}
10221022
}

0 commit comments

Comments
 (0)