Skip to content

Commit 17a55f0

Browse files
committed
CSHARP-5416: GetType not equals comparison translates incorrectly to MQL.
1 parent 6a33c27 commit 17a55f0

File tree

4 files changed

+269
-16
lines changed

4 files changed

+269
-16
lines changed

src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstSimplifier.cs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ node.Then is AstConstantExpression constantThenExpression &&
6868
static bool OperatorMapsNullToNull(AstUnaryOperator @operator)
6969
{
7070
return @operator switch
71-
{
71+
{
7272
AstUnaryOperator.ToDecimal => true,
7373
AstUnaryOperator.ToDouble => true,
7474
AstUnaryOperator.ToInt => true,
@@ -364,6 +364,16 @@ static AstExpression UltimateGetFieldInput(AstGetFieldExpression getField)
364364
}
365365
}
366366

367+
public override AstNode VisitNotFilterOperation(AstNotFilterOperation node)
368+
{
369+
if (node.Operation is AstExistsFilterOperation existsFilterOperation)
370+
{
371+
return new AstExistsFilterOperation(!existsFilterOperation.Exists);
372+
}
373+
374+
return base.VisitNotFilterOperation(node);
375+
}
376+
367377
public override AstNode VisitUnaryExpression(AstUnaryExpression node)
368378
{
369379
// { $first : <arg> } => { $arrayElemAt : [<arg>, 0] } (or -1 for $last)

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ComparisonExpressionToFilterTranslator.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,9 @@ public static AstFilter Translate(TranslationContext context, BinaryExpression e
5656
return CountComparisonExpressionToFilterTranslator.Translate(context, expression, comparisonOperator, countExpression, sizeExpression);
5757
}
5858

59-
if (GetTypeComparisonExpressionToFilterTranslator.CanTranslate(leftExpression, rightExpression))
59+
if (GetTypeComparisonExpressionToFilterTranslator.CanTranslate(leftExpression, comparisonOperator, rightExpression))
6060
{
61-
return GetTypeComparisonExpressionToFilterTranslator.Translate(context, expression, (MethodCallExpression)leftExpression, (ConstantExpression)rightExpression);
61+
return GetTypeComparisonExpressionToFilterTranslator.Translate(context, expression, (MethodCallExpression)leftExpression, comparisonOperator, (ConstantExpression)rightExpression);
6262
}
6363

6464
if (ModuloComparisonExpressionToFilterTranslator.CanTranslate(leftExpression, rightExpression, out var moduloExpression, out var remainderExpression))

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/GetTypeComparisonExpressionToFilterTranslator.cs

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
using System;
1717
using System.Linq.Expressions;
18-
using MongoDB.Bson;
1918
using MongoDB.Bson.Serialization;
2019
using MongoDB.Bson.Serialization.Conventions;
2120
using MongoDB.Bson.Serialization.Serializers;
@@ -30,28 +29,49 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToFilter
3029
internal static class GetTypeComparisonExpressionToFilterTranslator
3130
{
3231
// caller is responsible for ensuring constant is on the right
33-
public static bool CanTranslate(Expression leftExpression, Expression rightExpression)
32+
public static bool CanTranslate(
33+
Expression leftExpression,
34+
AstComparisonFilterOperator comparisonOperator,
35+
Expression rightExpression)
3436
{
3537
return
3638
leftExpression is MethodCallExpression methodCallExpression &&
3739
methodCallExpression.Method.Is(ObjectMethod.GetType) &&
40+
(comparisonOperator == AstComparisonFilterOperator.Eq || comparisonOperator == AstComparisonFilterOperator.Ne) &&
3841
rightExpression is ConstantExpression;
3942
}
4043

41-
public static AstFilter Translate(TranslationContext context, BinaryExpression expression, MethodCallExpression getTypeExpression, Expression typeConstantExpression)
44+
public static AstFilter Translate(
45+
TranslationContext context,
46+
BinaryExpression expression,
47+
MethodCallExpression getTypeExpression,
48+
AstComparisonFilterOperator comparisonOperator,
49+
Expression typeConstantExpression)
4250
{
43-
var field = ExpressionToFilterFieldTranslator.Translate(context, getTypeExpression.Object);
44-
var nominalType = field.Serializer.ValueType;
45-
var actualType = typeConstantExpression.GetConstantValue<Type>(expression);
51+
if (CanTranslate(getTypeExpression, comparisonOperator, typeConstantExpression))
52+
{
53+
var field = ExpressionToFilterFieldTranslator.Translate(context, getTypeExpression.Object);
54+
var nominalType = field.Serializer.ValueType;
55+
var actualType = typeConstantExpression.GetConstantValue<Type>(expression);
4656

47-
var discriminatorConvention = field.Serializer.GetDiscriminatorConvention();
48-
var discriminatorField = field.SubField(discriminatorConvention.ElementName, BsonValueSerializer.Instance);
57+
var discriminatorConvention = field.Serializer.GetDiscriminatorConvention();
58+
var discriminatorField = field.SubField(discriminatorConvention.ElementName, BsonValueSerializer.Instance);
4959

50-
return discriminatorConvention switch
51-
{
52-
IHierarchicalDiscriminatorConvention hierarchicalDiscriminatorConvention => DiscriminatorAstFilter.TypeEquals(discriminatorField, hierarchicalDiscriminatorConvention, nominalType, actualType),
53-
_ => DiscriminatorAstFilter.TypeEquals(discriminatorField, discriminatorConvention, nominalType, actualType),
54-
};
60+
var filter = discriminatorConvention switch
61+
{
62+
IHierarchicalDiscriminatorConvention hierarchicalDiscriminatorConvention => DiscriminatorAstFilter.TypeEquals(discriminatorField, hierarchicalDiscriminatorConvention, nominalType, actualType),
63+
_ => DiscriminatorAstFilter.TypeEquals(discriminatorField, discriminatorConvention, nominalType, actualType),
64+
};
65+
66+
return comparisonOperator switch
67+
{
68+
AstComparisonFilterOperator.Eq => filter,
69+
AstComparisonFilterOperator.Ne => AstFilter.Not(filter),
70+
_ => throw new ExpressionNotSupportedException(expression, because: $"comparison operator {comparisonOperator} is not supported")
71+
};
72+
}
73+
74+
throw new ExpressionNotSupportedException(expression);
5575
}
5676
}
5777
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System.Linq;
17+
using FluentAssertions;
18+
using MongoDB.Bson.Serialization.Attributes;
19+
using MongoDB.Bson.Serialization.Serializers;
20+
using MongoDB.Driver.Linq;
21+
using Xunit;
22+
23+
namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Translators.ExpressionToFilterTranslators.ExpressionTranslators
24+
{
25+
public class GetTypeComparisonExpressionToFilterTranslatorTests : Linq3IntegrationTest
26+
{
27+
[Fact]
28+
public void Hierarchical_documents_should_be_serialized_as_expected()
29+
{
30+
var collection = GetHierarchicalCollection();
31+
32+
var documents = collection.AsQueryable().As(BsonDocumentSerializer.Instance).ToArray();
33+
34+
documents.Should().HaveCount(3);
35+
documents[0].Should().Be("{ _id : 1, _t : 'HierarchicalBaseClass' }");
36+
documents[1].Should().Be("{ _id : 2, _t : ['HierarchicalBaseClass', 'HierarchicalInheritedClass1'] }");
37+
documents[2].Should().Be("{ _id : 3, _t : ['HierarchicalBaseClass', 'HierarchicalInheritedClass2'] }");
38+
}
39+
40+
[Fact]
41+
public void Hierarchical_Where_GetType_Equals_HierarchicalBaseClass_should_work()
42+
{
43+
var collection = GetHierarchicalCollection();
44+
45+
var queryable = collection.AsQueryable()
46+
.Where(x => x.GetType() == typeof(HierarchicalBaseClass));
47+
48+
var stages = Translate(collection, queryable);
49+
AssertStages(stages, "{ $match : { '_t.0' : { $exists : false }, _t : 'HierarchicalBaseClass' } }");
50+
51+
var results = queryable.ToList();
52+
results.Select(x => x.Id).Should().Equal(1);
53+
}
54+
55+
[Fact]
56+
public void Hierarchical_Where_GetType_Equals_HierarchicalInheritedClass1_should_work()
57+
{
58+
var collection = GetHierarchicalCollection();
59+
60+
var queryable = collection.AsQueryable()
61+
.Where(x => x.GetType() == typeof(HierarchicalInheritedClass1));
62+
63+
var stages = Translate(collection, queryable);
64+
AssertStages(stages, "{ $match : { _t : ['HierarchicalBaseClass', 'HierarchicalInheritedClass1'] } }");
65+
66+
var results = queryable.ToList();
67+
results.Select(x => x.Id).Should().Equal(2);
68+
}
69+
70+
[Fact]
71+
public void Hierarchical_Where_GetType_NotEquals_HierarchicalBaseClass_should_work()
72+
{
73+
var collection = GetHierarchicalCollection();
74+
75+
var queryable = collection.AsQueryable()
76+
.Where(x => x.GetType() != typeof(HierarchicalBaseClass));
77+
78+
var stages = Translate(collection, queryable);
79+
AssertStages(stages, "{ $match : { $nor : [{'_t.0' : { $exists : false }, _t : 'HierarchicalBaseClass' }] } }");
80+
81+
var results = queryable.ToList();
82+
results.Select(x => x.Id).Should().Equal(2, 3);
83+
}
84+
85+
[Fact]
86+
public void Hierarchical_Where_GetType_NotEquals_HierarchicalInheritedClass1_should_work()
87+
{
88+
var collection = GetHierarchicalCollection();
89+
90+
var queryable = collection.AsQueryable()
91+
.Where(x => x.GetType() != typeof(HierarchicalInheritedClass1));
92+
93+
var stages = Translate(collection, queryable);
94+
AssertStages(stages, "{ $match : { _t : { $ne : ['HierarchicalBaseClass', 'HierarchicalInheritedClass1'] } } }");
95+
96+
var results = queryable.ToList();
97+
results.Select(x => x.Id).Should().Equal(1, 3);
98+
}
99+
100+
[Fact]
101+
public void Scalar_documents_should_be_serialized_as_expected()
102+
{
103+
var collection = GetScalarCollection();
104+
105+
var documents = collection.AsQueryable().As(BsonDocumentSerializer.Instance).ToArray();
106+
107+
documents.Should().HaveCount(3);
108+
documents[0].Should().Be("{ _id : 1 }");
109+
documents[1].Should().Be("{ _id : 2, _t : 'ScalarInheritedClass1' }");
110+
documents[2].Should().Be("{ _id : 3, _t : 'ScalarInheritedClass2' }");
111+
}
112+
113+
[Fact]
114+
public void Scalar_Where_GetType_Equals_ScalarBaseClass_should_work()
115+
{
116+
var collection = GetScalarCollection();
117+
118+
var queryable = collection.AsQueryable()
119+
.Where(x => x.GetType() == typeof(ScalarBaseClass));
120+
121+
var stages = Translate(collection, queryable);
122+
AssertStages(stages, "{ $match : { _t : { $exists : false } } }");
123+
124+
var results = queryable.ToList();
125+
results.Select(x => x.Id).Should().Equal(1);
126+
}
127+
128+
[Fact]
129+
public void Scalar_Where_GetType_Equals_ScalarInheritedClass1_should_work()
130+
{
131+
var collection = GetScalarCollection();
132+
133+
var queryable = collection.AsQueryable()
134+
.Where(x => x.GetType() == typeof(ScalarInheritedClass1));
135+
136+
var stages = Translate(collection, queryable);
137+
AssertStages(stages, "{ $match : { _t : 'ScalarInheritedClass1' } }");
138+
139+
var results = queryable.ToList();
140+
results.Select(x => x.Id).Should().Equal(2);
141+
}
142+
143+
[Fact]
144+
public void Scalar_Where_GetType_NotEquals_ScalarBaseClass_should_work()
145+
{
146+
var collection = GetScalarCollection();
147+
148+
var queryable = collection.AsQueryable()
149+
.Where(x => x.GetType() != typeof(ScalarBaseClass));
150+
151+
var stages = Translate(collection, queryable);
152+
AssertStages(stages, "{ $match : { _t : { $exists : true } } }");
153+
154+
var results = queryable.ToList();
155+
results.Select(x => x.Id).Should().Equal(2, 3);
156+
}
157+
158+
[Fact]
159+
public void Scalar_Where_GetType_NotEquals_ScalarInheritedClass1_should_work()
160+
{
161+
var collection = GetScalarCollection();
162+
163+
var queryable = collection.AsQueryable()
164+
.Where(x => x.GetType() != typeof(ScalarInheritedClass1));
165+
166+
var stages = Translate(collection, queryable);
167+
AssertStages(stages, "{ $match : { _t : { $ne : 'ScalarInheritedClass1' } } }");
168+
169+
var results = queryable.ToList();
170+
results.Select(x => x.Id).Should().Equal(1, 3);
171+
}
172+
173+
private IMongoCollection<HierarchicalBaseClass> GetHierarchicalCollection()
174+
{
175+
var collection = GetCollection<HierarchicalBaseClass>("test");
176+
CreateCollection(
177+
collection,
178+
new HierarchicalBaseClass { Id = 1 },
179+
new HierarchicalInheritedClass1 { Id = 2 },
180+
new HierarchicalInheritedClass2 { Id = 3 });
181+
return collection;
182+
}
183+
184+
private IMongoCollection<ScalarBaseClass> GetScalarCollection()
185+
{
186+
var collection = GetCollection<ScalarBaseClass>("test");
187+
CreateCollection(
188+
collection,
189+
new ScalarBaseClass { Id = 1 },
190+
new ScalarInheritedClass1 { Id = 2 },
191+
new ScalarInheritedClass2 { Id = 3 });
192+
return collection;
193+
}
194+
195+
[BsonDiscriminator(RootClass = true)]
196+
[BsonKnownTypes(typeof(HierarchicalInheritedClass1), typeof(HierarchicalInheritedClass2))]
197+
private class HierarchicalBaseClass
198+
{
199+
public int Id { get; set; }
200+
}
201+
202+
private class HierarchicalInheritedClass1 : HierarchicalBaseClass
203+
{
204+
}
205+
206+
private class HierarchicalInheritedClass2 : HierarchicalBaseClass
207+
{
208+
}
209+
210+
private class ScalarBaseClass
211+
{
212+
public int Id { get; set; }
213+
}
214+
215+
private class ScalarInheritedClass1 : ScalarBaseClass
216+
{
217+
}
218+
219+
private class ScalarInheritedClass2 : ScalarBaseClass
220+
{
221+
}
222+
}
223+
}

0 commit comments

Comments
 (0)