Skip to content

Commit 00ac83e

Browse files
committed
CSHARP-4066: Only use regex filters against string properties that are serialized as strings.
1 parent c0c973a commit 00ac83e

File tree

8 files changed

+144
-30
lines changed

8 files changed

+144
-30
lines changed

src/MongoDB.Driver/Linq/ExpressionNotSupportedException.cs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,20 @@ private static string FormatMessage(Expression expression)
2929
return $"Expression not supported: {expression}.";
3030
}
3131

32+
private static string FormatMessage(Expression expression, string because)
33+
{
34+
return $"Expression not supported: {expression} because {because}.";
35+
}
36+
3237
private static string FormatMessage(Expression expression, Expression containingExpression)
3338
{
3439
return $"Expression not supported: {expression} in {containingExpression}.";
3540
}
41+
42+
private static string FormatMessage(Expression expression, Expression containingExpression, string because)
43+
{
44+
return $"Expression not supported: {expression} in {containingExpression} because {because}.";
45+
}
3646
#endregion
3747

3848
// constructors
@@ -54,6 +64,16 @@ public ExpressionNotSupportedException(Expression expression)
5464
{
5565
}
5666

67+
/// <summary>
68+
/// Initializes an instance of an ExpressionNotSupportedException.
69+
/// </summary>
70+
/// <param name="expression">The expression.</param>
71+
/// <param name="because">The reason.</param>
72+
public ExpressionNotSupportedException(Expression expression, string because)
73+
: base(FormatMessage(expression, because))
74+
{
75+
}
76+
5777
/// <summary>
5878
/// Initializes an instance of an ExpressionNotSupportedException.
5979
/// </summary>
@@ -63,5 +83,16 @@ public ExpressionNotSupportedException(Expression expression, Expression contain
6383
: base(FormatMessage(expression, containingExpression))
6484
{
6585
}
86+
87+
/// <summary>
88+
/// Initializes an instance of an ExpressionNotSupportedException.
89+
/// </summary>
90+
/// <param name="expression">The expression.</param>
91+
/// <param name="containingExpression">The containing expression.</param>
92+
/// <param name="because">The reason.</param>
93+
public ExpressionNotSupportedException(Expression expression, Expression containingExpression, string because)
94+
: base(FormatMessage(expression, containingExpression, because))
95+
{
96+
}
6697
}
6798
}

src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Filters/AstFieldOperationFilter.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
* limitations under the License.
1414
*/
1515

16+
using System;
1617
using MongoDB.Bson;
18+
using MongoDB.Bson.Serialization;
1719
using MongoDB.Driver.Core.Misc;
1820
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Visitors;
1921

@@ -28,6 +30,14 @@ public AstFieldOperationFilter(AstFilterField field, AstFilterOperation operatio
2830
{
2931
_field = Ensure.IsNotNull(field, nameof(field));
3032
_operation = Ensure.IsNotNull(operation, nameof(operation));
33+
34+
if (operation.NodeType == AstNodeType.RegexFilterOperation &&
35+
field.Serializer is IRepresentationConfigurable representationConfigurable &&
36+
representationConfigurable.Representation != BsonType.String)
37+
{
38+
// normally an ExpressionNotSupported should have been thrown before reaching here
39+
throw new ArgumentException($"Field must be represented as a string for regex filter operations: {field.Path}", nameof(field));
40+
}
3141
}
3242

3343
public new AstFilterField Field => _field;

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ComparisonExpressionToFilterTranslator.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,9 @@ public static AstFilter Translate(TranslationContext context, BinaryExpression e
6161
return ModuloComparisonExpressionToFilterTranslator.Translate(context, expression, moduloExpression, remainderExpression);
6262
}
6363

64-
if (StringExpressionToRegexFilterTranslator.CanTranslateComparisonExpression(leftExpression, comparisonOperator, rightExpression))
64+
if (StringExpressionToRegexFilterTranslator.TryTranslateComparisonExpression(context, expression, leftExpression, comparisonOperator, rightExpression, out var filter))
6565
{
66-
return StringExpressionToRegexFilterTranslator.TranslateComparisonExpression(context, expression, leftExpression, comparisonOperator, rightExpression);
66+
return filter;
6767
}
6868

6969
var comparand = rightExpression.GetConstantValue<object>(containingExpression: expression);

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/StringExpressionToRegexFilterTranslator.cs

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
using System.Text;
2222
using System.Text.RegularExpressions;
2323
using MongoDB.Bson;
24+
using MongoDB.Bson.Serialization;
2425
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Filters;
2526
using MongoDB.Driver.Linq.Linq3Implementation.ExtensionMethods;
2627
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
@@ -147,33 +148,33 @@ public static bool CanTranslate(Expression expression)
147148
return false;
148149
}
149150

150-
// caller is responsible for ensuring constant is on the right
151-
public static bool CanTranslateComparisonExpression(Expression leftExpression, AstComparisonFilterOperator comparisonOperator, Expression rightExpression)
151+
public static bool TryTranslate(TranslationContext context, Expression expression, out AstFilter filter)
152152
{
153-
// (int)document.S[i] == c
154-
if (IsGetCharsComparison(leftExpression))
153+
try
155154
{
155+
filter = Translate(context, expression);
156156
return true;
157157
}
158-
159-
// document.S == "abc"
160-
if (IsStringComparison(leftExpression))
158+
catch (ExpressionNotSupportedException)
161159
{
162-
return true;
160+
filter = null;
161+
return false;
163162
}
163+
}
164164

165-
// document.S.IndexOf('a') == n etc...
166-
if (IsStringIndexOfComparison(leftExpression))
165+
// caller is responsible for ensuring constant is on the right
166+
public static bool TryTranslateComparisonExpression(TranslationContext context, Expression expression, Expression leftExpression, AstComparisonFilterOperator comparisonOperator, Expression rightExpression, out AstFilter filter)
167+
{
168+
try
167169
{
170+
filter = TranslateComparisonExpression(context, expression, leftExpression, comparisonOperator, rightExpression);
168171
return true;
169172
}
170-
// document.S.Length == n or document.S.Count() == n
171-
if (IsStringLengthComparison(leftExpression) || IsStringCountComparison(leftExpression))
173+
catch (ExpressionNotSupportedException)
172174
{
173-
return true;
175+
filter = null;
176+
return false;
174177
}
175-
176-
return false;
177178
}
178179

179180
public static AstFilter Translate(TranslationContext context, Expression expression)
@@ -362,7 +363,7 @@ private static AstFilter TranslateGetCharsComparison(TranslationContext context,
362363
var leftConvertExpression = (UnaryExpression)leftExpression;
363364
var leftGetCharsExpression = (MethodCallExpression)leftConvertExpression.Operand;
364365
var fieldExpression = leftGetCharsExpression.Object;
365-
var (field, modifiers) = TranslateField(context, fieldExpression);
366+
var (field, modifiers) = TranslateField(context, expression, fieldExpression);
366367

367368
var indexExpression = leftGetCharsExpression.Arguments[0];
368369
var index = indexExpression.GetConstantValue<int>(containingExpression: expression);
@@ -383,18 +384,25 @@ private static AstFilter TranslateGetCharsComparison(TranslationContext context,
383384
throw new ExpressionNotSupportedException(expression);
384385
}
385386

386-
private static (AstFilterField, Modifiers) TranslateField(TranslationContext context, Expression fieldExpression)
387+
private static (AstFilterField, Modifiers) TranslateField(TranslationContext context, Expression expression, Expression fieldExpression)
387388
{
388389
if (fieldExpression is MethodCallExpression fieldMethodCallExpression &&
389390
fieldMethodCallExpression.Method.IsOneOf(__modifierMethods))
390391
{
391-
var (field, modifiers) = TranslateField(context, fieldMethodCallExpression.Object);
392+
var (field, modifiers) = TranslateField(context, expression, fieldMethodCallExpression.Object);
392393
modifiers = TranslateModifier(modifiers, fieldMethodCallExpression);
393394
return (field, modifiers);
394395
}
395396
else
396397
{
397398
var field = ExpressionToFilterFieldTranslator.Translate(context, fieldExpression);
399+
400+
if (field.Serializer is IRepresentationConfigurable representationConfigurable &&
401+
representationConfigurable.Representation != BsonType.String)
402+
{
403+
throw new ExpressionNotSupportedException(fieldExpression, expression, because: $"field \"{field.Path}\" is not represented as a string");
404+
}
405+
398406
return (field, new Modifiers());
399407
}
400408
}
@@ -424,7 +432,7 @@ private static AstFilter TranslateStartsWithOrContainsOrEndsWith(TranslationCont
424432
var method = expression.Method;
425433
var arguments = expression.Arguments;
426434

427-
var (field, modifiers) = TranslateField(context, expression.Object);
435+
var (field, modifiers) = TranslateField(context, expression, expression.Object);
428436
var value = arguments[0].GetConstantValue<string>(containingExpression: expression);
429437
if (method.IsOneOf(StringMethod.StartsWithWithComparisonType, StringMethod.EndsWithWithComparisonType))
430438
{
@@ -466,7 +474,7 @@ bool IsImpossibleMatch(Modifiers modifiers, string value)
466474

467475
private static AstFilter TranslateStringComparison(TranslationContext context, Expression expression, Expression leftExpression, AstComparisonFilterOperator comparisonOperator, Expression rightExpression)
468476
{
469-
var (field, modifiers) = TranslateField(context, leftExpression);
477+
var (field, modifiers) = TranslateField(context, expression, leftExpression);
470478
var comparand = rightExpression.GetConstantValue<string>(containingExpression: expression);
471479

472480
if (comparisonOperator == AstComparisonFilterOperator.Eq || comparisonOperator == AstComparisonFilterOperator.Ne)
@@ -513,7 +521,7 @@ private static AstFilter TranslateStringIndexOfComparison(TranslationContext con
513521
var arguments = leftMethodCallExpression.Arguments;
514522

515523
var fieldExpression = leftMethodCallExpression.Object;
516-
var (field, modifiers) = TranslateField(context, fieldExpression);
524+
var (field, modifiers) = TranslateField(context, expression, fieldExpression);
517525

518526
var startIndex = 0;
519527
if (method.IsOneOf(__indexOfWithStartIndexMethods))
@@ -653,7 +661,7 @@ private static AstFilter TranslateStringLengthComparison(TranslationContext cont
653661
throw new ExpressionNotSupportedException(expression);
654662
}
655663

656-
var (field, modifiers) = TranslateField(context, fieldExpression);
664+
var (field, modifiers) = TranslateField(context, expression, fieldExpression);
657665

658666
var comparand = rightExpression.GetConstantValue<int>(containingExpression: expression);
659667
var pattern = comparisonOperator switch

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/ContainsMethodToFilterTranslator.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ internal static class ContainsMethodToFilterTranslator
2828
{
2929
public static AstFilter Translate(TranslationContext context, MethodCallExpression expression)
3030
{
31-
if (StringExpressionToRegexFilterTranslator.CanTranslate(expression))
31+
if (StringExpressionToRegexFilterTranslator.TryTranslate(context, expression, out var filter))
3232
{
33-
return StringExpressionToRegexFilterTranslator.Translate(context, expression);
33+
return filter;
3434
}
3535

3636
var method = expression.Method;

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/EqualsMethodToFilterTranslator.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,9 @@ private static AstFilter Translate(TranslationContext context, Expression expres
6565
(fieldExpression, valueExpression) = (expression1, expression2);
6666
}
6767

68-
if (StringExpressionToRegexFilterTranslator.CanTranslateComparisonExpression(fieldExpression, AstComparisonFilterOperator.Eq, valueExpression))
68+
if (StringExpressionToRegexFilterTranslator.TryTranslateComparisonExpression(context, expression, fieldExpression, AstComparisonFilterOperator.Eq, valueExpression, out var filter))
6969
{
70-
return StringExpressionToRegexFilterTranslator.TranslateComparisonExpression(context, expression, fieldExpression, AstComparisonFilterOperator.Eq, valueExpression);
70+
return filter;
7171
}
7272

7373
var field = ExpressionToFilterFieldTranslator.Translate(context, fieldExpression);

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/IsMatchMethodToFilterTranslator.cs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
using System.Linq.Expressions;
1717
using System.Text.RegularExpressions;
1818
using MongoDB.Bson;
19+
using MongoDB.Bson.Serialization;
1920
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Filters;
2021
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
2122
using MongoDB.Driver.Linq.Linq3Implementation.Reflection;
@@ -30,7 +31,7 @@ public static AstFilter Translate(TranslationContext context, MethodCallExpressi
3031
{
3132
if (IsMatchMethod(expression, out var inputExpression, out var regularExpression))
3233
{
33-
return Translate(context, inputExpression, regularExpression);
34+
return Translate(context, expression, inputExpression, regularExpression);
3435
}
3536

3637
throw new ExpressionNotSupportedException(expression);
@@ -92,10 +93,17 @@ private static bool IsMatchMethod(MethodCallExpression expression, out Expressio
9293
return false;
9394
}
9495

95-
private static AstFilter Translate(TranslationContext context, Expression inputExpression, Regex regex)
96+
private static AstFilter Translate(TranslationContext context, Expression expression, Expression inputExpression, Regex regex)
9697
{
9798
var inputFieldAst = ExpressionToFilterFieldTranslator.Translate(context, inputExpression);
9899
var regularExpression = new BsonRegularExpression(regex);
100+
101+
if (inputFieldAst.Serializer is IRepresentationConfigurable representationConfigurable &&
102+
representationConfigurable.Representation != BsonType.String)
103+
{
104+
throw new ExpressionNotSupportedException(inputExpression, expression, because: $"field \"{inputFieldAst.Path}\" is not represented as a string");
105+
}
106+
99107
return AstFilter.Regex(inputFieldAst, regularExpression.Pattern, regularExpression.Options);
100108
}
101109
}
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using FluentAssertions;
17+
using MongoDB.Bson;
18+
using MongoDB.Bson.Serialization.Attributes;
19+
using Xunit;
20+
21+
namespace MongoDB.Driver.Tests.Linq.Linq3ImplementationTests.Jira
22+
{
23+
public class CSharp4066Tests : Linq3IntegrationTest
24+
{
25+
[Fact]
26+
public void String_comparison_in_filter_should_use_custom_serializer()
27+
{
28+
var collection = GetCollection<C>();
29+
collection.Database.DropCollection(collection.CollectionNamespace.CollectionName);
30+
31+
var id = "0102030405060708090a0b0c";
32+
collection.InsertMany(
33+
new[]
34+
{
35+
new C { Id = id, X = 1 },
36+
new C { Id = "000000000000000000000000", X = 2 }
37+
});
38+
39+
var find = collection.Find(x => x.Id == id);
40+
41+
var rendered = find.ToString();
42+
rendered.Should().Be("find({ \"_id\" : ObjectId(\"0102030405060708090a0b0c\") })");
43+
44+
var results = find.ToList();
45+
results.Count.Should().Be(1);
46+
results[0].Id.Should().Be(id);
47+
results[0].X.Should().Be(1);
48+
}
49+
50+
public class C
51+
{
52+
[BsonRepresentation(BsonType.ObjectId)]
53+
public string Id { get; set; }
54+
public int X { get; set; }
55+
}
56+
}
57+
}

0 commit comments

Comments
 (0)