Skip to content

Commit a24f1cd

Browse files
committed
CSHARP-4691: Support GetType comparison in LINQ3.
1 parent 6c20aa7 commit a24f1cd

File tree

11 files changed

+870
-4
lines changed

11 files changed

+870
-4
lines changed

src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstExpression.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,11 @@ public static AstExpression IndexOfCP(AstExpression @string, AstExpression value
453453
return new AstIndexOfCPExpression(@string, value, start, end);
454454
}
455455

456+
public static AstExpression IsArray(AstExpression value)
457+
{
458+
return new AstUnaryExpression(AstUnaryOperator.IsArray, value);
459+
}
460+
456461
public static AstExpression Last(AstExpression array)
457462
{
458463
return new AstUnaryExpression(AstUnaryOperator.Last, array);

src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Filters/AstFilter.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,11 @@ public static AstFilter Not(AstFilter filter)
195195
return new AstNorFilter(new[] { filter });
196196
}
197197

198+
public static AstFieldOperationFilter NotExists(AstFilterField field)
199+
{
200+
return new AstFieldOperationFilter(field, new AstExistsFilterOperation(exists: false));
201+
}
202+
198203
public static AstFilter Or(params AstFilter[] filters)
199204
{
200205
Ensure.IsNotNull(filters, nameof(filters));
@@ -235,6 +240,11 @@ public static AstFieldOperationFilter Size(AstFilterField field, BsonValue size)
235240
{
236241
return new AstFieldOperationFilter(field, new AstSizeFilterOperation(size));
237242
}
243+
244+
public static AstFieldOperationFilter Type(AstFilterField field, BsonType type)
245+
{
246+
return new AstFieldOperationFilter(field, new AstTypeFilterOperation(type));
247+
}
238248
#endregion
239249

240250
// public properties

src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/ObjectMethod.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,20 @@ internal static class ObjectMethod
2121
{
2222
// private static fields
2323
private static readonly MethodInfo __equals;
24+
private static readonly MethodInfo __getType;
2425
private static readonly MethodInfo __toString;
2526

2627
// static constructor
2728
static ObjectMethod()
2829
{
2930
__equals = ReflectionInfo.Method((object o, object obj) => o.Equals(obj));
31+
__getType = ReflectionInfo.Method((object o) => o.GetType());
3032
__toString = ReflectionInfo.Method((object o) => o.ToString());
3133
}
3234

3335
// public properties
3436
public static new MethodInfo Equals => __equals;
37+
public static new MethodInfo GetType => __getType;
3538
public static new MethodInfo ToString => __toString;
3639
}
3740
}

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/BinaryExpressionToAggregationExpressionTranslator.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@ internal static class BinaryExpressionToAggregationExpressionTranslator
3030
{
3131
public static AggregationExpression Translate(TranslationContext context, BinaryExpression expression)
3232
{
33+
if (GetTypeComparisonExpressionToAggregationExpressionTranslator.CanTranslate(expression))
34+
{
35+
return GetTypeComparisonExpressionToAggregationExpressionTranslator.Translate(context, expression);
36+
}
37+
3338
if (StringGetCharsComparisonExpressionToAggregationExpressionTranslator.CanTranslate(expression, out var getCharsExpression))
3439
{
3540
return StringGetCharsComparisonExpressionToAggregationExpressionTranslator.Translate(context, expression, getCharsExpression);

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ public static AggregationExpression Translate(TranslationContext context, Expres
7777
return NewArrayInitExpressionToAggregationExpressionTranslator.Translate(context, (NewArrayExpression)expression);
7878
case ExpressionType.Parameter:
7979
return ParameterExpressionToAggregationExpressionTranslator.Translate(context, (ParameterExpression)expression);
80+
case ExpressionType.TypeIs:
81+
return TypeIsExpressionToAggregationExpressionTranslator.Translate(context, (TypeBinaryExpression)expression);
8082
}
8183

8284
throw new ExpressionNotSupportedException(expression);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System;
17+
using System.Linq.Expressions;
18+
using MongoDB.Bson.Serialization;
19+
using MongoDB.Bson.Serialization.Serializers;
20+
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
21+
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
22+
using MongoDB.Driver.Linq.Linq3Implementation.Reflection;
23+
24+
namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators
25+
{
26+
internal static class GetTypeComparisonExpressionToAggregationExpressionTranslator
27+
{
28+
// public static methods
29+
public static bool CanTranslate(BinaryExpression expression)
30+
{
31+
return CanTranslate(expression, out _, out _);
32+
}
33+
34+
public static AggregationExpression Translate(TranslationContext context, BinaryExpression expression)
35+
{
36+
if (CanTranslate(expression, out var getTypeMethodCallExpression, out var comparandType))
37+
{
38+
var objectExpression = getTypeMethodCallExpression.Object;
39+
var objectTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, objectExpression);
40+
var nominalType = objectExpression.Type;
41+
var actualType = comparandType;
42+
43+
var discriminatorConvention = objectTranslation.Serializer is ObjectSerializer objectSerializer ?
44+
objectSerializer.DiscriminatorConvention :
45+
BsonSerializer.LookupDiscriminatorConvention(nominalType);
46+
var discriminatorField = AstExpression.GetField(objectTranslation.Ast, discriminatorConvention.ElementName);
47+
var discriminatorValue = discriminatorConvention.GetDiscriminator(nominalType, actualType);
48+
49+
var ast = AstExpression.Eq(discriminatorField, discriminatorValue);
50+
return new AggregationExpression(expression, ast, BooleanSerializer.Instance);
51+
}
52+
53+
throw new ExpressionNotSupportedException(expression);
54+
}
55+
56+
// private static methods
57+
private static bool CanTranslate(BinaryExpression expression, out MethodCallExpression getTypeMethodCallExpression, out Type comparandType)
58+
{
59+
var leftExpression = expression.Left;
60+
var rightExpression = expression.Right;
61+
62+
if (leftExpression is MethodCallExpression methodCallExpression &&
63+
methodCallExpression.Method.Is(ObjectMethod.GetType) &&
64+
expression.NodeType == ExpressionType.Equal &&
65+
rightExpression is ConstantExpression constantExpression)
66+
{
67+
getTypeMethodCallExpression = methodCallExpression;
68+
comparandType = (Type)constantExpression.Value;
69+
return true;
70+
}
71+
72+
getTypeMethodCallExpression = null;
73+
comparandType = null;
74+
return false;
75+
}
76+
77+
}
78+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System.Linq;
17+
using System.Linq.Expressions;
18+
using MongoDB.Bson;
19+
using MongoDB.Bson.Serialization;
20+
using MongoDB.Bson.Serialization.Serializers;
21+
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
22+
23+
namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators
24+
{
25+
internal static class TypeIsExpressionToAggregationExpressionTranslator
26+
{
27+
// public static methods
28+
public static AggregationExpression Translate(TranslationContext context, TypeBinaryExpression expression)
29+
{
30+
var objectExpression = expression.Expression;
31+
var objectTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, objectExpression);
32+
var nominalType = objectExpression.Type;
33+
var actualType = expression.TypeOperand;
34+
35+
var discriminatorConvention = objectTranslation.Serializer is ObjectSerializer objectSerializer ?
36+
objectSerializer.DiscriminatorConvention :
37+
BsonSerializer.LookupDiscriminatorConvention(nominalType);
38+
var discriminatorField = AstExpression.GetField(objectTranslation.Ast, discriminatorConvention.ElementName);
39+
var discriminatorValue = discriminatorConvention.GetDiscriminator(nominalType, actualType);
40+
if (discriminatorValue is BsonArray array)
41+
{
42+
discriminatorValue = array.Last();
43+
}
44+
45+
var ast = AstExpression.Or(
46+
AstExpression.Eq(discriminatorField, discriminatorValue),
47+
AstExpression.And(
48+
AstExpression.IsArray(discriminatorField),
49+
AstExpression.In(discriminatorValue, discriminatorField)));
50+
51+
return new AggregationExpression(expression, ast, BooleanSerializer.Instance);
52+
}
53+
}
54+
}

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ComparisonExpressionToFilterTranslator.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,11 @@ public static AstFilter Translate(TranslationContext context, BinaryExpression e
5656
return CountComparisonExpressionToFilterTranslator.Translate(context, expression, countExpression, sizeExpression);
5757
}
5858

59+
if (GetTypeComparisonExpressionToFilterTranslator.CanTranslate(leftExpression, rightExpression))
60+
{
61+
return GetTypeComparisonExpressionToFilterTranslator.Translate(context, expression, (MethodCallExpression)leftExpression, (ConstantExpression)rightExpression);
62+
}
63+
5964
if (ModuloComparisonExpressionToFilterTranslator.CanTranslate(leftExpression, rightExpression, out var moduloExpression, out var remainderExpression))
6065
{
6166
return ModuloComparisonExpressionToFilterTranslator.Translate(context, expression, moduloExpression, remainderExpression);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System;
17+
using System.Linq.Expressions;
18+
using MongoDB.Bson.Serialization;
19+
using MongoDB.Bson.Serialization.Serializers;
20+
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Filters;
21+
using MongoDB.Driver.Linq.Linq3Implementation.ExtensionMethods;
22+
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
23+
using MongoDB.Driver.Linq.Linq3Implementation.Reflection;
24+
using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToFilterTranslators.ToFilterFieldTranslators;
25+
26+
namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToFilterTranslators.ExpressionTranslators
27+
{
28+
internal static class GetTypeComparisonExpressionToFilterTranslator
29+
{
30+
// caller is responsible for ensuring constant is on the right
31+
public static bool CanTranslate(Expression leftExpression, Expression rightExpression)
32+
{
33+
return
34+
leftExpression is MethodCallExpression methodCallExpression &&
35+
methodCallExpression.Method.Is(ObjectMethod.GetType) &&
36+
rightExpression is ConstantExpression;
37+
}
38+
39+
public static AstFilter Translate(TranslationContext context, BinaryExpression expression, MethodCallExpression getTypeExpression, Expression typeConstantExpression)
40+
{
41+
var field = ExpressionToFilterFieldTranslator.Translate(context, getTypeExpression.Object);
42+
var nominalType = field.Serializer.ValueType;
43+
var actualType = typeConstantExpression.GetConstantValue<Type>(expression);
44+
45+
var discriminatorConvention = field.Serializer is ObjectSerializer objectSerializer ?
46+
objectSerializer.DiscriminatorConvention :
47+
BsonSerializer.LookupDiscriminatorConvention(nominalType);
48+
var discriminatorField = field.SubField(discriminatorConvention.ElementName, BsonValueSerializer.Instance);
49+
var discriminatorValue = discriminatorConvention.GetDiscriminator(nominalType, actualType);
50+
51+
if (discriminatorValue.IsBsonArray)
52+
{
53+
var discriminatorValues = discriminatorValue.AsBsonArray;
54+
var filters = new AstFilter[discriminatorValues.Count + 1];
55+
filters[0] = AstFilter.Size(discriminatorField, discriminatorValues.Count); // don't match subclasses
56+
for (var i = 0; i < discriminatorValues.Count; i++)
57+
{
58+
var discriminatorItemField = discriminatorField.SubField(i.ToString(), BsonValueSerializer.Instance);
59+
filters[i + 1] = AstFilter.Eq(discriminatorItemField, discriminatorValues[i]);
60+
}
61+
62+
return AstFilter.And(filters);
63+
64+
}
65+
else
66+
{
67+
var discriminatorFieldElementZero = discriminatorField.SubField("0", BsonValueSerializer.Instance);
68+
return AstFilter.And(
69+
AstFilter.NotExists(discriminatorFieldElementZero), // required to avoid false matches on subclasses with hierarchical discriminators
70+
AstFilter.Eq(discriminatorField, discriminatorValue));
71+
}
72+
}
73+
}
74+
}

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/TypeIsExpressionToFilterTranslator.cs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
* limitations under the License.
1414
*/
1515

16+
using System.Linq;
1617
using System.Linq.Expressions;
18+
using MongoDB.Bson;
1719
using MongoDB.Bson.Serialization;
1820
using MongoDB.Bson.Serialization.Serializers;
1921
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Filters;
@@ -29,14 +31,20 @@ public static AstFilter Translate(TranslationContext context, TypeBinaryExpressi
2931
{
3032
var fieldExpression = expression.Expression;
3133
var field = ExpressionToFilterFieldTranslator.Translate(context, fieldExpression);
32-
3334
var nominalType = fieldExpression.Type;
3435
var actualType = expression.TypeOperand;
35-
var discriminatorConvention = BsonSerializer.LookupDiscriminatorConvention(actualType);
36+
37+
var discriminatorConvention = field.Serializer is ObjectSerializer objectSerializer ?
38+
objectSerializer.DiscriminatorConvention :
39+
BsonSerializer.LookupDiscriminatorConvention(actualType);
3640
var discriminatorField = field.SubField(discriminatorConvention.ElementName, BsonValueSerializer.Instance);
37-
var discriminator = discriminatorConvention.GetDiscriminator(nominalType, actualType);
41+
var discriminatorValue = discriminatorConvention.GetDiscriminator(nominalType, actualType);
42+
if (discriminatorValue is BsonArray array)
43+
{
44+
discriminatorValue = array.Last();
45+
}
3846

39-
return AstFilter.Eq(discriminatorField, discriminator);
47+
return AstFilter.Eq(discriminatorField, discriminatorValue); // will match subclasses also
4048
}
4149

4250
throw new ExpressionNotSupportedException(expression);

0 commit comments

Comments
 (0)