Skip to content

Commit 57bae35

Browse files
authored
CSHARP-2003: $bitsAnySet etc should work with Enums. (#839)
1 parent 2610660 commit 57bae35

File tree

4 files changed

+223
-8
lines changed

4 files changed

+223
-8
lines changed

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/BitMaskComparisonExpressionToFilterTranslator.cs

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,49 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToFilter
2424
{
2525
internal static class BitMaskComparisonExpressionToFilterTranslator
2626
{
27-
public static bool CanTranslate(Expression leftExpression)
27+
public static bool CanTranslate(Expression leftExpression, Expression rightExpression)
2828
{
29-
return
30-
leftExpression is BinaryExpression leftBinaryExpression &&
31-
leftBinaryExpression.NodeType == ExpressionType.And;
29+
return CanTranslate(leftExpression, rightExpression, out _);
30+
}
31+
32+
public static bool CanTranslate(Expression leftExpression, Expression rightExpression, out BinaryExpression leftBinaryExpression)
33+
{
34+
if (rightExpression.NodeType == ExpressionType.Constant)
35+
{
36+
// a leftExpression with an & operation with an enum looks like:
37+
// Convert(Convert((Convert(x.E, Int32) & mask), E), Int32)
38+
if (leftExpression is UnaryExpression outerToUnderlyingTypeConvertExpression &&
39+
outerToUnderlyingTypeConvertExpression.NodeType == ExpressionType.Convert &&
40+
outerToUnderlyingTypeConvertExpression.Operand is UnaryExpression innerToEnumConvertExpression &&
41+
innerToEnumConvertExpression.NodeType == ExpressionType.Convert &&
42+
innerToEnumConvertExpression.Operand is BinaryExpression innerBinaryExpression &&
43+
innerBinaryExpression.NodeType == ExpressionType.And &&
44+
innerBinaryExpression.Left is UnaryExpression innerToUnderlyingTypeConvertExpression &&
45+
innerToUnderlyingTypeConvertExpression.NodeType == ExpressionType.Convert)
46+
{
47+
var enumType = innerToEnumConvertExpression.Type;
48+
if (enumType.IsEnum)
49+
{
50+
var underlyingType = enumType.GetEnumUnderlyingType();
51+
if (outerToUnderlyingTypeConvertExpression.Type == underlyingType &&
52+
innerToEnumConvertExpression.Type == enumType &&
53+
innerToUnderlyingTypeConvertExpression.Type == underlyingType)
54+
{
55+
leftExpression = innerBinaryExpression; // Convert(x.E, Int32) & mask
56+
}
57+
}
58+
}
59+
60+
leftBinaryExpression = leftExpression as BinaryExpression;
61+
if (leftBinaryExpression != null &&
62+
leftBinaryExpression.NodeType == ExpressionType.And)
63+
{
64+
return true;
65+
}
66+
}
67+
68+
leftBinaryExpression = null;
69+
return false;
3270
}
3371

3472
// caller is responsible for ensuring constant is on the right
@@ -39,10 +77,9 @@ public static AstFilter Translate(
3977
AstComparisonFilterOperator comparisonOperator,
4078
Expression rightExpression)
4179
{
42-
if (leftExpression is BinaryExpression leftBinaryExpression &&
43-
leftBinaryExpression.NodeType == ExpressionType.And)
80+
if (CanTranslate(leftExpression, rightExpression, out var leftBinaryExpression))
4481
{
45-
var fieldExpression = leftBinaryExpression.Left;
82+
var fieldExpression = ConvertHelper.RemoveConvertToEnumUnderlyingType(leftBinaryExpression.Left);
4683
var field = ExpressionToFilterFieldTranslator.Translate(context, fieldExpression);
4784

4885
var bitMaskExpression = leftBinaryExpression.Right;

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ComparisonExpressionToFilterTranslator.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ public static AstFilter Translate(TranslationContext context, BinaryExpression e
4141
return ArrayLengthComparisonExpressionToFilterTranslator.Translate(context, expression, arrayLengthExpression, sizeExpression);
4242
}
4343

44-
if (BitMaskComparisonExpressionToFilterTranslator.CanTranslate(leftExpression))
44+
if (BitMaskComparisonExpressionToFilterTranslator.CanTranslate(leftExpression, rightExpression))
4545
{
4646
return BitMaskComparisonExpressionToFilterTranslator.Translate(context, expression, leftExpression, comparisonOperator, rightExpression);
4747
}
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System;
17+
using System.Linq;
18+
using FluentAssertions;
19+
using Xunit;
20+
21+
namespace MongoDB.Driver.Tests.Linq.Linq3ImplementationTests.Jira
22+
{
23+
public class CSharp2003Tests : Linq3IntegrationTest
24+
{
25+
[Fact]
26+
public void Find_BitsAllClear_should_work()
27+
{
28+
var collection = CreateCollection();
29+
var mask = E.E2 | E.E4;
30+
var find = collection.Find(x => (x.E & mask) == 0);
31+
32+
var filter = TranslateFilter(collection, find);
33+
filter.Should().Be("{ E : { $bitsAllClear : 6 } }");
34+
35+
var results = find.ToList().OrderBy(x => x.Id).ToList();
36+
results.Select(x => x.Id).Should().Equal(1, 8);
37+
}
38+
39+
[Fact]
40+
public void Find_BitsAllSet_should_work()
41+
{
42+
var collection = CreateCollection();
43+
var mask = E.E2 | E.E4;
44+
var find = collection.Find(x => (x.E & mask) == mask);
45+
46+
var filter = TranslateFilter(collection, find);
47+
filter.Should().Be("{ E : { $bitsAllSet : 6 } }");
48+
49+
var results = find.ToList().OrderBy(x => x.Id).ToList();
50+
results.Select(x => x.Id).Should().Equal(6);
51+
}
52+
53+
[Fact]
54+
public void Find_BitsAnyClear_should_work()
55+
{
56+
var collection = CreateCollection();
57+
var mask = E.E2 | E.E4;
58+
var find = collection.Find(x => (x.E & mask) != mask);
59+
60+
var filter = TranslateFilter(collection, find);
61+
filter.Should().Be("{ E : { $bitsAnyClear : 6 } }");
62+
63+
var results = find.ToList().OrderBy(x => x.Id).ToList();
64+
results.Select(x => x.Id).Should().Equal(1, 2, 4, 8);
65+
}
66+
67+
[Fact]
68+
public void Find_BitsAnySet_should_work()
69+
{
70+
var collection = CreateCollection();
71+
var mask = E.E2 | E.E4;
72+
var find = collection.Find(x => (x.E & mask) != 0);
73+
74+
var filter = TranslateFilter(collection, find);
75+
filter.Should().Be("{ E : { $bitsAnySet : 6 } }");
76+
77+
var results = find.ToList().OrderBy(x => x.Id).ToList();
78+
results.Select(x => x.Id).Should().Equal(2, 4, 6);
79+
}
80+
81+
[Fact]
82+
public void Where_BitsAllClear_should_work()
83+
{
84+
var collection = CreateCollection();
85+
var mask = E.E2 | E.E4;
86+
var queryable = collection.AsQueryable().Where(x => (x.E & mask) == 0);
87+
88+
var stages = Translate(collection, queryable);
89+
AssertStages(stages, "{ $match : { E : { $bitsAllClear : 6 } } }");
90+
91+
var results = queryable.ToList().OrderBy(x => x.Id).ToList();
92+
results.Select(x => x.Id).Should().Equal(1, 8);
93+
}
94+
95+
[Fact]
96+
public void Where_BitsAllSet_should_work()
97+
{
98+
var collection = CreateCollection();
99+
var mask = E.E2 | E.E4;
100+
var queryable = collection.AsQueryable().Where(x => (x.E & mask) == mask);
101+
102+
var stages = Translate(collection, queryable);
103+
AssertStages(stages, "{ $match : { E : { $bitsAllSet : 6 } } }");
104+
105+
var results = queryable.ToList().OrderBy(x => x.Id).ToList();
106+
results.Select(x => x.Id).Should().Equal(6);
107+
}
108+
109+
[Fact]
110+
public void Where_BitsAnyClear_should_work()
111+
{
112+
var collection = CreateCollection();
113+
var mask = E.E2 | E.E4;
114+
var queryable = collection.AsQueryable().Where(x => (x.E & mask) != mask);
115+
116+
var stages = Translate(collection, queryable);
117+
AssertStages(stages, "{ $match : { E : { $bitsAnyClear : 6 } } }");
118+
119+
var results = queryable.ToList().OrderBy(x => x.Id).ToList();
120+
results.Select(x => x.Id).Should().Equal(1, 2, 4, 8);
121+
}
122+
123+
[Fact]
124+
public void Where_BitsAnySet_should_work()
125+
{
126+
var collection = CreateCollection();
127+
var mask = E.E2 | E.E4;
128+
var queryable = collection.AsQueryable().Where(x => (x.E & mask) != 0);
129+
130+
var stages = Translate(collection, queryable);
131+
AssertStages(stages, "{ $match : { E : { $bitsAnySet : 6 } } }");
132+
133+
var results = queryable.ToList().OrderBy(x => x.Id).ToList();
134+
results.Select(x => x.Id).Should().Equal(2, 4, 6);
135+
}
136+
137+
private IMongoCollection<C> CreateCollection()
138+
{
139+
var collection = GetCollection<C>();
140+
141+
var documents = new[]
142+
{
143+
new C { Id = 1, E = E.E1 },
144+
new C { Id = 2, E = E.E2 },
145+
new C { Id = 4, E = E.E4 },
146+
new C { Id = 6, E = E.E2 | E.E4 },
147+
new C { Id = 8, E = E.E8 }
148+
};
149+
CreateCollection(collection, documents);
150+
151+
return collection;
152+
}
153+
154+
[Flags]
155+
private enum E
156+
{
157+
E1 = 1,
158+
E2 = 2,
159+
E4 = 4,
160+
E8 = 8
161+
}
162+
163+
private class C
164+
{
165+
public int Id { get; set; }
166+
public E E;
167+
}
168+
}
169+
}

tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationTests/Linq3IntegrationTest.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,5 +106,14 @@ protected List<BsonDocument> Translate<TDocument, TResult>(IQueryable<TResult> q
106106
var stages = executableQuery.Pipeline.Stages;
107107
return stages.Select(s => s.Render().AsBsonDocument).ToList();
108108
}
109+
110+
protected BsonDocument TranslateFilter<TDocument>(IMongoCollection<TDocument> collection, IFindFluent<TDocument, TDocument> find)
111+
{
112+
var filterDefinition = find.Filter;
113+
var documentSerializer = collection.DocumentSerializer;
114+
var serializerRegistry = BsonSerializer.SerializerRegistry;
115+
var linqProvider = collection.Database.Client.Settings.LinqProvider;
116+
return filterDefinition.Render(documentSerializer, serializerRegistry, linqProvider);
117+
}
109118
}
110119
}

0 commit comments

Comments
 (0)