Skip to content

Commit 0dd40f6

Browse files
authored
CSHARP-4696: Expression not supported: Math.Round (#1124)
1 parent aada8e8 commit 0dd40f6

File tree

6 files changed

+216
-0
lines changed

6 files changed

+216
-0
lines changed

src/MongoDB.Driver.Core/Core/Misc/Feature.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ public class Feature
110110
private static readonly Feature __regexMatch = new Feature("RegexMatch", WireVersion.Server42);
111111
private static readonly Feature __retryableReads = new Feature("RetryableReads", WireVersion.Server36);
112112
private static readonly Feature __retryableWrites = new Feature("RetryableWrites", WireVersion.Server36);
113+
private static readonly Feature __round = new Feature("Round", WireVersion.Server42);
113114
private static readonly Feature __scramSha1Authentication = new Feature("ScramSha1Authentication", WireVersion.Server30);
114115
private static readonly Feature __scramSha256Authentication = new Feature("ScramSha256Authentication", WireVersion.Server40);
115116
private static readonly Feature __serverExtractsUsernameFromX509Certificate = new Feature("ServerExtractsUsernameFromX509Certificate", WireVersion.Server34);
@@ -619,6 +620,11 @@ public class Feature
619620
[Obsolete("This property will be removed in a later release.")]
620621
public static Feature RetryableWrites => __retryableWrites;
621622

623+
/// <summary>
624+
/// Gets the $round feature.
625+
/// </summary>
626+
public static Feature Round => __round;
627+
622628
/// <summary>
623629
/// Gets the scram sha1 authentication feature.
624630
/// </summary>

src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstExpression.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,16 @@ public static AstExpression ReverseArray(AstExpression array)
668668
return new AstUnaryExpression(AstUnaryOperator.ReverseArray, array);
669669
}
670670

671+
public static AstExpression Round(AstExpression arg)
672+
{
673+
return new AstUnaryExpression(AstUnaryOperator.Round, arg);
674+
}
675+
676+
public static AstExpression Round(AstExpression arg, AstExpression place)
677+
{
678+
return new AstBinaryExpression(AstBinaryOperator.Round, arg, place);
679+
}
680+
671681
public static AstExpression RTrim(AstExpression input, AstExpression chars = null)
672682
{
673683
return new AstRTrimExpression(input, chars);

src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MathMethod.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ internal static class MathMethod
4646
private static readonly MethodInfo __logWithNewBase;
4747
private static readonly MethodInfo __log10;
4848
private static readonly MethodInfo __pow;
49+
private static readonly MethodInfo __roundWithDecimal;
50+
private static readonly MethodInfo __roundWithDecimalAndDecimals;
51+
private static readonly MethodInfo __roundWithDouble;
52+
private static readonly MethodInfo __roundWithDoubleAndDigits;
4953
private static readonly MethodInfo __sin;
5054
private static readonly MethodInfo __sinh;
5155
private static readonly MethodInfo __sqrt;
@@ -88,6 +92,10 @@ static MathMethod()
8892
__logWithNewBase = ReflectionInfo.Method((double a, double newBase) => Math.Log(a, newBase));
8993
__log10 = ReflectionInfo.Method((double d) => Math.Log10(d));
9094
__pow = ReflectionInfo.Method((double x, double y) => Math.Pow(x, y));
95+
__roundWithDecimal = ReflectionInfo.Method((decimal d) => Math.Round(d));
96+
__roundWithDecimalAndDecimals = ReflectionInfo.Method((decimal d, int decimals) => Math.Round(d, decimals));
97+
__roundWithDouble = ReflectionInfo.Method((double d) => Math.Round(d));
98+
__roundWithDoubleAndDigits = ReflectionInfo.Method((double d, int digits) => Math.Round(d, digits));
9199
__sin = ReflectionInfo.Method((double a) => Math.Sin(a));
92100
__sinh = ReflectionInfo.Method((double a) => Math.Sinh(a));
93101
__sqrt = ReflectionInfo.Method((double d) => Math.Sqrt(d));
@@ -123,6 +131,10 @@ static MathMethod()
123131
public static MethodInfo LogWithNewBase => __logWithNewBase;
124132
public static MethodInfo Log10 => __log10;
125133
public static MethodInfo Pow => __pow;
134+
public static MethodInfo RoundWithDecimal => __roundWithDecimal;
135+
public static MethodInfo RoundWithDecimalAndDecimals => __roundWithDecimalAndDecimals;
136+
public static MethodInfo RoundWithDouble => __roundWithDouble;
137+
public static MethodInfo RoundWithDoubleAndDigits => __roundWithDoubleAndDigits;
126138
public static MethodInfo Sin => __sin;
127139
public static MethodInfo Sinh => __sinh;
128140
public static MethodInfo Sqrt => __sqrt;

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodCallExpressionToAggregationExpressionTranslator.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ public static AggregationExpression Translate(TranslationContext context, Method
6464
case "Range": return RangeMethodToAggregationExpressionTranslator.Translate(context, expression);
6565
case "Rank": return RankMethodToAggregationExpressionTranslator.Translate(context, expression);
6666
case "Reverse": return ReverseMethodToAggregationExpressionTranslator.Translate(context, expression);
67+
case "Round": return RoundMethodToAggregationExpressionTranslator.Translate(context, expression);
6768
case "Select": return SelectMethodToAggregationExpressionTranslator.Translate(context, expression);
6869
case "SetEquals": return SetEqualsMethodToAggregationExpressionTranslator.Translate(context, expression);
6970
case "Shift": return ShiftMethodToAggregationExpressionTranslator.Translate(context, expression);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System.Linq.Expressions;
17+
using System.Reflection;
18+
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
19+
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
20+
using MongoDB.Driver.Linq.Linq3Implementation.Reflection;
21+
22+
namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators
23+
{
24+
internal static class RoundMethodToAggregationExpressionTranslator
25+
{
26+
private static readonly MethodInfo[] __roundMethods =
27+
{
28+
MathMethod.RoundWithDecimal,
29+
MathMethod.RoundWithDecimalAndDecimals,
30+
MathMethod.RoundWithDouble,
31+
MathMethod.RoundWithDoubleAndDigits
32+
};
33+
34+
private static readonly MethodInfo[] __roundWithPlaceMethods =
35+
{
36+
MathMethod.RoundWithDecimalAndDecimals,
37+
MathMethod.RoundWithDoubleAndDigits
38+
};
39+
40+
public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression)
41+
{
42+
var method = expression.Method;
43+
var arguments = expression.Arguments;
44+
45+
if (method.IsOneOf(__roundMethods))
46+
{
47+
var argumentExpression = ConvertHelper.RemoveWideningConvert(arguments[0]);
48+
var argumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, argumentExpression);
49+
50+
AstExpression ast;
51+
if (method.IsOneOf(__roundWithPlaceMethods))
52+
{
53+
var placeExpression = arguments[1];
54+
var placeTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, placeExpression);
55+
ast = AstExpression.Round(argumentTranslation.Ast, placeTranslation.Ast);
56+
}
57+
else
58+
{
59+
ast = AstExpression.Round(argumentTranslation.Ast);
60+
}
61+
62+
return new AggregationExpression(expression, ast, argumentTranslation.Serializer);
63+
}
64+
65+
throw new ExpressionNotSupportedException(expression);
66+
}
67+
}
68+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System;
17+
using System.Linq;
18+
using FluentAssertions;
19+
using MongoDB.Bson;
20+
using MongoDB.Bson.Serialization.Attributes;
21+
using MongoDB.Driver.Core.Misc;
22+
using MongoDB.Driver.Core.TestHelpers.XunitExtensions;
23+
using Xunit;
24+
25+
namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators
26+
{
27+
public class RoundMethodToAggregationExpressionTranslatorTests : Linq3IntegrationTest
28+
{
29+
[Fact]
30+
public void Math_round_double_should_work()
31+
{
32+
RequireServer.Check().Supports(Feature.Round);
33+
34+
var collection = CreateCollection();
35+
var queryable = collection.AsQueryable()
36+
.Select(i => Math.Round(i.Double));
37+
38+
var stages = Translate(collection, queryable);
39+
AssertStages(
40+
stages,
41+
"{ $project : { _v : { $round : '$Double' }, _id : 0 } }");
42+
43+
var results = queryable.ToList();
44+
results.Should().Equal(10.0, 10.0, 9.0);
45+
}
46+
47+
[Fact]
48+
public void Math_round_double_with_digits_should_work()
49+
{
50+
RequireServer.Check().Supports(Feature.Round);
51+
52+
var collection = CreateCollection();
53+
var queryable = collection.AsQueryable()
54+
.Select(i => Math.Round(i.Double, 1));
55+
56+
var stages = Translate(collection, queryable);
57+
AssertStages(
58+
stages,
59+
"{ $project : { _v : { $round : ['$Double', 1] }, _id : 0 } }");
60+
61+
var results = queryable.ToList();
62+
results.Should().Equal(10.2, 9.7, 9.2);
63+
}
64+
65+
[Fact]
66+
public void Math_round_decimal_should_work()
67+
{
68+
RequireServer.Check().Supports(Feature.Round);
69+
70+
var collection = CreateCollection();
71+
var queryable = collection.AsQueryable()
72+
.Select(i => Math.Round(i.Decimal));
73+
74+
var stages = Translate(collection, queryable);
75+
AssertStages(
76+
stages,
77+
"{ $project : { _v : { $round : '$Decimal' }, _id : 0 } }");
78+
79+
var results = queryable.ToList();
80+
results.Should().Equal(10.0m, 10.0m, 9.0m);
81+
}
82+
83+
[Fact]
84+
public void Math_round_decimal_with_decimals_should_work()
85+
{
86+
RequireServer.Check().Supports(Feature.Round);
87+
88+
var collection = CreateCollection();
89+
var queryable = collection.AsQueryable()
90+
.Select(i => Math.Round(i.Decimal, 1));
91+
92+
var stages = Translate(collection, queryable);
93+
AssertStages(
94+
stages,
95+
"{ $project : { _v : { $round : ['$Decimal', 1] }, _id : 0 } }");
96+
97+
var results = queryable.ToList();
98+
results.Should().Equal(10.2m, 9.7m, 9.2m);
99+
}
100+
101+
private IMongoCollection<Data> CreateCollection()
102+
{
103+
var collection = GetCollection<Data>("test");
104+
CreateCollection(
105+
collection,
106+
new Data { Double = 10.234, Decimal = 10.234m },
107+
new Data { Double = 9.66, Decimal = 9.66m },
108+
new Data { Double = 9.2, Decimal = 9.2m });
109+
return collection;
110+
}
111+
112+
private class Data
113+
{
114+
public double Double { get; set; }
115+
[BsonRepresentation(BsonType.Decimal128)]
116+
public decimal Decimal { get; set; }
117+
}
118+
}
119+
}

0 commit comments

Comments
 (0)