Skip to content

Commit da9d055

Browse files
rstamDmitryLukyanov
authored andcommitted
CSHARP-4517: Values being compared must be compatible.
1 parent 65e64d2 commit da9d055

File tree

7 files changed

+210
-6
lines changed

7 files changed

+210
-6
lines changed

src/MongoDB.Driver/Linq/Linq3Implementation/Misc/SerializationHelper.cs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,37 @@
1414
*/
1515

1616
using System.Collections;
17+
using System.Linq.Expressions;
1718
using MongoDB.Bson;
1819
using MongoDB.Bson.IO;
1920
using MongoDB.Bson.Serialization;
21+
using MongoDB.Driver.Linq.Linq3Implementation.Serializers;
2022

2123
namespace MongoDB.Driver.Linq.Linq3Implementation.Misc
2224
{
2325
internal static class SerializationHelper
2426
{
27+
public static BsonValue SerializeValue(IBsonSerializer serializer, ConstantExpression constantExpression, Expression containingExpression)
28+
{
29+
var value = constantExpression.Value;
30+
if (value == null || serializer.ValueType.IsAssignableFrom(value.GetType()))
31+
{
32+
return SerializeValue(serializer, value);
33+
}
34+
35+
if (value.GetType().ImplementsIEnumerable(out var itemType) &&
36+
serializer is IBsonArraySerializer arraySerializer &&
37+
arraySerializer.TryGetItemSerializationInfo(out var itemSerializationInfo) &&
38+
itemSerializationInfo.Serializer is var itemSerializer &&
39+
itemSerializer.ValueType.IsAssignableFrom(itemType))
40+
{
41+
var ienumerableSerializer = IEnumerableSerializer.Create(itemSerializer);
42+
return SerializeValue(ienumerableSerializer, value);
43+
}
44+
45+
throw new ExpressionNotSupportedException(constantExpression, containingExpression, because: "it was not possible to determine how to serialize the constant");
46+
}
47+
2548
public static BsonValue SerializeValue(IBsonSerializer serializer, object value)
2649
{
2750
var document = new BsonDocument();

src/MongoDB.Driver/Linq/Linq3Implementation/Misc/TypeExtensions.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,18 @@ public static bool Implements(this Type type, Type @interface)
5858
return false;
5959
}
6060

61+
public static bool ImplementsIEnumerable(this Type type, out Type itemType)
62+
{
63+
if (TryGetIEnumerableGenericInterface(type, out var ienumerableType))
64+
{
65+
itemType = ienumerableType.GetGenericArguments()[0];
66+
return true;
67+
}
68+
69+
itemType = null;
70+
return false;
71+
}
72+
6173
public static bool Is(this Type type, Type comparand)
6274
{
6375
if (type == comparand)

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/BinaryExpressionToAggregationExpressionTranslator.cs

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,12 @@ public static AggregationExpression Translate(TranslationContext context, Binary
3737

3838
var leftExpression = expression.Left;
3939
var rightExpression = expression.Right;
40+
41+
if (!AreOperandTypesCompatible(expression, leftExpression, rightExpression))
42+
{
43+
throw new ExpressionNotSupportedException(expression, because: "operand types are not compatible with each other");
44+
}
45+
4046
if (IsArithmeticExpression(expression))
4147
{
4248
leftExpression = ConvertHelper.RemoveWideningConvert(leftExpression);
@@ -48,8 +54,22 @@ public static AggregationExpression Translate(TranslationContext context, Binary
4854
return TranslateEnumExpression(context, expression);
4955
}
5056

51-
var leftTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, leftExpression);
52-
var rightTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, rightExpression);
57+
AggregationExpression leftTranslation, rightTranslation;
58+
if (leftExpression is ConstantExpression leftConstantExpresion)
59+
{
60+
rightTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, rightExpression);
61+
leftTranslation = TranslateConstant(expression, leftConstantExpresion, rightTranslation.Serializer);
62+
}
63+
else if (rightExpression is ConstantExpression rightConstantExpression)
64+
{
65+
leftTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, leftExpression);
66+
rightTranslation = TranslateConstant(expression, rightConstantExpression, leftTranslation.Serializer);
67+
}
68+
else
69+
{
70+
leftTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, leftExpression);
71+
rightTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, rightExpression);
72+
}
5373

5474
var ast = expression.NodeType switch
5575
{
@@ -96,6 +116,29 @@ public static AggregationExpression Translate(TranslationContext context, Binary
96116
return new AggregationExpression(expression, ast, serializer);
97117
}
98118

119+
public static bool AreOperandTypesCompatible(Expression expression, Expression leftExpression, Expression rightExpression)
120+
{
121+
if (leftExpression is ConstantExpression leftConstantExpression &&
122+
leftConstantExpression.Value == null)
123+
{
124+
return true;
125+
}
126+
127+
if (rightExpression is ConstantExpression rightConstantExpression &&
128+
rightConstantExpression.Value == null)
129+
{
130+
return true;
131+
}
132+
133+
if (leftExpression.Type.IsAssignableFrom(rightExpression.Type) ||
134+
rightExpression.Type.IsAssignableFrom(leftExpression.Type))
135+
{
136+
return true;
137+
}
138+
139+
return false;
140+
}
141+
99142
private static bool IsAddOrSubtractExpression(Expression expression)
100143
{
101144
return expression.NodeType switch
@@ -190,6 +233,13 @@ private static AstBinaryOperator ToBinaryOperator(ExpressionType nodeType)
190233
};
191234
}
192235

236+
private static AggregationExpression TranslateConstant(BinaryExpression containingExpression, ConstantExpression constantExpression, IBsonSerializer otherSerializer)
237+
{
238+
var serializedValue = SerializationHelper.SerializeValue(otherSerializer, constantExpression, containingExpression);
239+
var ast = AstExpression.Constant(serializedValue);
240+
return new AggregationExpression(constantExpression, ast, otherSerializer);
241+
}
242+
193243
private static AggregationExpression TranslateEnumExpression(TranslationContext context, BinaryExpression expression)
194244
{
195245
var leftExpression = expression.Left;

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ComparisonExpressionToFilterTranslator.cs

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515

1616
using System.Linq.Expressions;
1717
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Filters;
18-
using MongoDB.Driver.Linq.Linq3Implementation.ExtensionMethods;
1918
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
19+
using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators;
2020
using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToFilterTranslators.MethodTranslators;
2121
using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToFilterTranslators.ToFilterFieldTranslators;
2222

@@ -66,17 +66,26 @@ public static AstFilter Translate(TranslationContext context, BinaryExpression e
6666
return filter;
6767
}
6868

69-
var comparand = rightExpression.GetConstantValue<object>(containingExpression: expression);
69+
if (!BinaryExpressionToAggregationExpressionTranslator.AreOperandTypesCompatible(expression, leftExpression, rightExpression))
70+
{
71+
throw new ExpressionNotSupportedException(expression, because: "operand types are not compatible with each other");
72+
}
73+
74+
var comparandExpression = rightExpression as ConstantExpression;
75+
if (comparandExpression == null)
76+
{
77+
throw new ExpressionNotSupportedException(expression, because: "comparand must be a constant");
78+
}
7079

7180
if (leftExpression.Type == typeof(bool) &&
7281
(comparisonOperator == AstComparisonFilterOperator.Eq || comparisonOperator == AstComparisonFilterOperator.Ne) &&
7382
rightExpression.Type == typeof(bool))
7483
{
75-
return TranslateComparisonToBooleanConstant(context, expression, leftExpression, comparisonOperator, (bool)comparand);
84+
return TranslateComparisonToBooleanConstant(context, expression, leftExpression, comparisonOperator, (bool)comparandExpression.Value);
7685
}
7786

7887
var field = ExpressionToFilterFieldTranslator.Translate(context, leftExpression);
79-
var serializedComparand = SerializationHelper.SerializeValue(field.Serializer, comparand);
88+
var serializedComparand = SerializationHelper.SerializeValue(field.Serializer, comparandExpression, expression);
8089
return AstFilter.Compare(field, comparisonOperator, serializedComparand);
8190
}
8291

tests/MongoDB.Driver.Tests/Linq/Linq2ImplementationTests/MongoQueryableIntArrayComparedToEnumerableIntTests.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ public class C
4343
public void Where_operator_equal_should_render_correctly(IEnumerable<int> value, string expectedFilter)
4444
{
4545
var subject = __collection.AsQueryable();
46+
value = new List<int>(value); // not an array
4647

4748
var queryable = subject.Where(x => x.A == value);
4849

@@ -54,6 +55,7 @@ public void Where_operator_equal_should_render_correctly(IEnumerable<int> value,
5455
public void Where_operator_not_equal_should_render_correctly(IEnumerable<int> value, string expectedFilter)
5556
{
5657
var subject = __collection.AsQueryable();
58+
value = new List<int>(value); // not an array
5759

5860
var queryable = subject.Where(x => x.A != value);
5961

tests/MongoDB.Driver.Tests/Linq/Linq2ImplementationTestsOnLinq3/MongoQueryableIntArrayComparedToEnumerableIntTests.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ public class C
4646
public void Where_operator_equal_should_render_correctly(IEnumerable<int> value, string expectedFilter)
4747
{
4848
var subject = __collection.AsQueryable();
49+
value = new List<int>(value); // not an array
4950

5051
var queryable = subject.Where(x => x.A == value);
5152

@@ -57,6 +58,7 @@ public void Where_operator_equal_should_render_correctly(IEnumerable<int> value,
5758
public void Where_operator_not_equal_should_render_correctly(IEnumerable<int> value, string expectedFilter)
5859
{
5960
var subject = __collection.AsQueryable();
61+
value = new List<int>(value); // not an array
6062

6163
var queryable = subject.Where(x => x.A != value);
6264

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System;
17+
using System.Linq;
18+
using System.Linq.Expressions;
19+
using FluentAssertions;
20+
using MongoDB.Bson.Serialization;
21+
using MongoDB.Bson.Serialization.Attributes;
22+
using MongoDB.Bson.Serialization.Serializers;
23+
using MongoDB.Driver.Linq;
24+
using Xunit;
25+
26+
namespace MongoDB.Driver.Tests.Linq.Linq3ImplementationTests.Jira
27+
{
28+
public class CSharp4517Tests : Linq3IntegrationTest
29+
{
30+
[Fact]
31+
public void Filter_with_comparison_of_different_types_should_throw()
32+
{
33+
var collection = CreateCollection();
34+
35+
var queryable =
36+
collection.AsQueryable()
37+
.Where(x => x.Id == 1);
38+
39+
var exception = Record.Exception(() => Translate(collection, queryable));
40+
41+
exception.Should().BeOfType<ExpressionNotSupportedException>();
42+
exception.Message.Should().Contain("because operand types are not compatible with each other");
43+
}
44+
45+
[Fact]
46+
public void Expression_with_comparison_of_different_types_should_throw()
47+
{
48+
var collection = CreateCollection();
49+
50+
var queryable =
51+
collection.AsQueryable()
52+
.Select(x => new { R = x.Id == 1 });
53+
54+
var exception = Record.Exception(() => Translate(collection, queryable));
55+
56+
exception.Should().BeOfType<ExpressionNotSupportedException>();
57+
exception.Message.Should().Contain("because operand types are not compatible with each other");
58+
}
59+
60+
private IMongoCollection<MyDocument> CreateCollection()
61+
{
62+
var collection = GetCollection<MyDocument>("test");
63+
return collection;
64+
}
65+
66+
public class MyDocument
67+
{
68+
public MyId Id { get; set; }
69+
public string Name { get; set; }
70+
}
71+
72+
[BsonSerializer(typeof(MyIdSerializer))]
73+
#pragma warning disable CS0660 // Type defines operator == or operator != but does not override Object.Equals(object o)
74+
#pragma warning disable CS0661 // Type defines operator == or operator != but does not override Object.GetHashCode()
75+
public class MyId
76+
#pragma warning restore CS0661 // Type defines operator == or operator != but does not override Object.GetHashCode()
77+
#pragma warning restore CS0660 // Type defines operator == or operator != but does not override Object.Equals(object o)
78+
{
79+
80+
public MyId(int id)
81+
{
82+
Id = id;
83+
}
84+
85+
public int Id { get; }
86+
87+
public static bool operator ==(int id, MyId other) => id == other.Id;
88+
public static bool operator ==(MyId id, int other) => id.Id == other;
89+
public static bool operator !=(int id, MyId other) => !(id == other);
90+
public static bool operator !=(MyId id, int other) => !(id == other);
91+
}
92+
93+
public class MyIdSerializer : SerializerBase<MyId>
94+
{
95+
public override MyId Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args)
96+
{
97+
return new MyId(context.Reader.ReadInt32());
98+
}
99+
100+
public override void Serialize(BsonSerializationContext context, BsonSerializationArgs args, MyId value)
101+
{
102+
context.Writer.WriteInt32(value.Id);
103+
}
104+
}
105+
}
106+
}

0 commit comments

Comments
 (0)