Skip to content

Commit d3a0baa

Browse files
damiengrstam
authored andcommitted
CSHARP-5171: Add IReadOnlyDictionary to existing IDictionary indexer translation.
1 parent 74fba13 commit d3a0baa

File tree

5 files changed

+231
-41
lines changed

5 files changed

+231
-41
lines changed

src/MongoDB.Driver/Linq/Linq3Implementation/Misc/TypeExtensions.cs

Lines changed: 27 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
/* Copyright 2010-present MongoDB Inc.
2-
*
3-
* Licensed under the Apache License, Version 2.0 (the "License");
4-
* you may not use this file except in compliance with the License.
5-
* You may obtain a copy of the License at
6-
*
7-
* http://www.apache.org/licenses/LICENSE-2.0
8-
*
9-
* Unless required by applicable law or agreed to in writing, software
10-
* distributed under the License is distributed on an "AS IS" BASIS,
11-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-
* See the License for the specific language governing permissions and
13-
* limitations under the License.
14-
*/
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
1515

1616
using System;
1717
using System.Collections.Generic;
@@ -22,6 +22,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Misc
2222
{
2323
internal static class TypeExtensions
2424
{
25+
private static readonly Type[] __dictionaryInterfaces =
26+
{
27+
typeof(IDictionary<,>),
28+
typeof(IReadOnlyDictionary<,>)
29+
};
30+
2531
private static Type[] __tupleTypeDefinitions =
2632
{
2733
typeof(Tuple<>),
@@ -84,11 +90,11 @@ public static bool Implements(this Type type, Type @interface)
8490
return false;
8591
}
8692

87-
public static bool ImplementsIDictionary(this Type type, out Type keyType, out Type valueType)
93+
public static bool ImplementsDictionaryInterface(this Type type, out Type keyType, out Type valueType)
8894
{
89-
if (TryGetIDictionaryGenericInterface(type, out var idictionaryType))
95+
if (TryGetGenericInterface(type, __dictionaryInterfaces, out var dictionaryInterface))
9096
{
91-
var genericArguments = idictionaryType.GetGenericArguments();
97+
var genericArguments = dictionaryInterface.GetGenericArguments();
9298
keyType = genericArguments[0];
9399
valueType = genericArguments[1];
94100
return true;
@@ -255,28 +261,14 @@ public static bool IsValueTuple(this Type type)
255261
type.IsConstructedGenericType &&
256262
type.GetGenericTypeDefinition() is var typeDefinition &&
257263
__valueTupleTypeDefinitions.Contains(typeDefinition);
258-
259264
}
260265

261-
public static bool TryGetIDictionaryGenericInterface(this Type type, out Type idictionaryGenericInterface)
266+
public static bool TryGetGenericInterface(this Type type, Type[] interfaceDefinitions, out Type genericInterface)
262267
{
263-
if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(IDictionary<,>))
264-
{
265-
idictionaryGenericInterface = type;
266-
return true;
267-
}
268-
269-
foreach (var interfaceType in type.GetInterfaces())
270-
{
271-
if (interfaceType.IsGenericType && interfaceType.GetGenericTypeDefinition() == typeof(IDictionary<,>))
272-
{
273-
idictionaryGenericInterface = interfaceType;
274-
return true;
275-
}
276-
}
277-
278-
idictionaryGenericInterface = null;
279-
return false;
268+
genericInterface = type.IsConstructedGenericType && interfaceDefinitions.Contains(type.GetGenericTypeDefinition())
269+
? type
270+
: type.GetInterfaces().FirstOrDefault(i => i.IsGenericType && interfaceDefinitions.Contains(i.GetGenericTypeDefinition()));
271+
return genericInterface != null;
280272
}
281273

282274
public static bool TryGetIEnumerableGenericInterface(this Type type, out Type ienumerableGenericInterface)

src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/IDictionaryMethod.cs renamed to src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/DictionaryMethod.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
namespace MongoDB.Driver.Linq.Linq3Implementation.Reflection
2020
{
21-
internal static class IDictionaryMethod
21+
internal static class DictionaryMethod
2222
{
2323
// public static methods
2424
public static bool IsGetItemWithStringMethod(MethodInfo method)
@@ -29,7 +29,7 @@ public static bool IsGetItemWithStringMethod(MethodInfo method)
2929
method.GetParameters() is var parameters &&
3030
parameters.Length == 1 &&
3131
parameters[0].ParameterType == typeof(string) &&
32-
method.DeclaringType.ImplementsIDictionary(out var keyType, out var valueType) &&
32+
method.DeclaringType.ImplementsDictionaryInterface(out var keyType, out var valueType) &&
3333
keyType == typeof(string) &&
3434
method.ReturnType == valueType;
3535
}

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/GetItemMethodToAggregationExpressionTranslator.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ public static AggregationExpression Translate(TranslationContext context, Expres
5555
return TranslateIListGetItemWithInt(context, expression, sourceExpression, arguments[0]);
5656
}
5757

58-
if (IDictionaryMethod.IsGetItemWithStringMethod(method))
58+
if (DictionaryMethod.IsGetItemWithStringMethod(method))
5959
{
6060
return TranslateIDictionaryGetItemWithString(context, expression, sourceExpression, arguments[0]);
6161
}

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ToFilterFieldTranslators/GetItemMethodToFilterFieldTranslator.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,9 @@ public static AstFilterField Translate(TranslationContext context, Expression ex
5353
return TranslateIListGetItemWithInt(context, expression, fieldExpression, arguments[0]);
5454
}
5555

56-
if (IDictionaryMethod.IsGetItemWithStringMethod(method))
56+
if (DictionaryMethod.IsGetItemWithStringMethod(method))
5757
{
58-
return TranslateIDictionaryGetItemWithString(context, expression, fieldExpression, arguments[0]);
58+
return TranslateDictionaryGetItemWithString(context, expression, fieldExpression, arguments[0]);
5959
}
6060

6161
throw new ExpressionNotSupportedException(expression);
@@ -80,7 +80,7 @@ private static AstFilterField TranslateIListGetItemWithInt(TranslationContext co
8080
return ArrayIndexExpressionToFilterFieldTranslator.Translate(context, expression, fieldExpression, indexExpression);
8181
}
8282

83-
private static AstFilterField TranslateIDictionaryGetItemWithString(TranslationContext context, Expression expression, Expression fieldExpression, Expression keyExpression)
83+
private static AstFilterField TranslateDictionaryGetItemWithString(TranslationContext context, Expression expression, Expression fieldExpression, Expression keyExpression)
8484
{
8585
var field = ExpressionToFilterFieldTranslator.Translate(context, fieldExpression);
8686
var key = keyExpression.GetConstantValue<string>(containingExpression: expression);
Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System;
17+
using System.Collections.Generic;
18+
using System.Collections.ObjectModel;
19+
using System.Linq;
20+
using System.Linq.Expressions;
21+
using FluentAssertions;
22+
using MongoDB.Driver.Linq;
23+
using MongoDB.TestHelpers.XunitExtensions;
24+
using Xunit;
25+
26+
namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira
27+
{
28+
public class CSharp5171Tests : Linq3IntegrationTest
29+
{
30+
[Theory]
31+
[ParameterAttributeData]
32+
public void Select_ReadOnlyDictionary_item_with_string_using_compiler_generated_expression_should_work(
33+
[Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider)
34+
{
35+
var collection = GetCollection(linqProvider);
36+
37+
var queryable = collection.AsQueryable()
38+
.Select(x => x.Dictionary["a"]);
39+
40+
var stages = Translate(collection, queryable);
41+
var expectedStage = linqProvider == LinqProvider.V2 ?
42+
"{ $project : { a : '$Dictionary.a', _id : 0 } }" :
43+
"{ $project : { _v : '$Dictionary.a', _id : 0 } }";
44+
AssertStages(stages, expectedStage);
45+
46+
var results = queryable.ToList();
47+
results.Should().Equal(1, 0);
48+
}
49+
50+
[Theory]
51+
[ParameterAttributeData]
52+
public void Select_ReadOnlyDictionary_item_with_string_using_call_to_get_item_should_work(
53+
[Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider)
54+
{
55+
var collection = GetCollection(linqProvider);
56+
var x = Expression.Parameter(typeof(C), "x");
57+
var body = Expression.Call(
58+
Expression.Property(x, typeof(C).GetProperty("Dictionary")!),
59+
typeof(IReadOnlyDictionary<string, int>).GetProperty("Item")!.GetGetMethod(),
60+
Expression.Constant("a"));
61+
var selector = Expression.Lambda<Func<C, int>>(body, [x]);
62+
63+
var queryable = collection.AsQueryable()
64+
.Select(selector);
65+
66+
var stages = Translate(collection, queryable);
67+
var expectedStage = linqProvider == LinqProvider.V2 ?
68+
"{ $project : { a : '$Dictionary.a', _id : 0 } }" :
69+
"{ $project : { _v : '$Dictionary.a', _id : 0 } }";
70+
AssertStages(stages, expectedStage);
71+
72+
var results = queryable.ToList();
73+
results.Should().Equal(1, 0);
74+
}
75+
76+
[Theory]
77+
[ParameterAttributeData]
78+
public void Select_ReadOnlyDictionary_item_with_string_using_MakeIndex_should_work(
79+
[Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider)
80+
{
81+
var collection = GetCollection(linqProvider);
82+
var x = Expression.Parameter(typeof(C), "x");
83+
var body = Expression.MakeIndex(
84+
Expression.Property(x, typeof(C).GetProperty("Dictionary")!),
85+
typeof(IReadOnlyDictionary<string, int>).GetProperty("Item"),
86+
new Expression[] {Expression.Constant("a")});
87+
var selector = Expression.Lambda<Func<C, int>>(body, [x]);
88+
89+
var queryable = collection.AsQueryable()
90+
.Select(selector);
91+
92+
if (linqProvider == LinqProvider.V2)
93+
{
94+
var exception = Record.Exception(() => Translate(collection, queryable));
95+
exception.Should().BeOfType<NotSupportedException>();
96+
}
97+
else
98+
{
99+
var stages = Translate(collection, queryable);
100+
AssertStages(stages, "{ $project : { _v : '$Dictionary.a', _id : 0 } }");
101+
102+
var results = queryable.ToList();
103+
results.Should().Equal(1, 0);
104+
}
105+
}
106+
107+
[Theory]
108+
[ParameterAttributeData]
109+
public void Where_ReadOnlyDictionary_item_with_string_using_compiler_generated_expression_should_work(
110+
[Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider)
111+
{
112+
var collection = GetCollection(linqProvider);
113+
114+
var queryable = collection.AsQueryable()
115+
.Where(x => x.Dictionary["a"] == 1);
116+
117+
var stages = Translate(collection, queryable);
118+
AssertStages(stages, "{ $match : { 'Dictionary.a' : 1 } }");
119+
120+
var results = queryable.ToList();
121+
results.Select(x => x.Id).Should().Equal(1);
122+
}
123+
124+
[Theory]
125+
[ParameterAttributeData]
126+
public void Where_ReadOnlyDictionary_item_with_string_using_call_to_get_item_should_work(
127+
[Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider)
128+
{
129+
var collection = GetCollection(linqProvider);
130+
var x = Expression.Parameter(typeof(C), "x");
131+
var body = Expression.Equal(
132+
Expression.Call(
133+
Expression.Property(x, typeof(C).GetProperty("Dictionary")!),
134+
typeof(IReadOnlyDictionary<string, int>).GetProperty("Item")!.GetGetMethod(),
135+
Expression.Constant("a")),
136+
Expression.Constant(1));
137+
var predicate = Expression.Lambda<Func<C, bool>>(body, [x]);
138+
139+
var queryable = collection.AsQueryable()
140+
.Where(predicate);
141+
142+
var stages = Translate(collection, queryable);
143+
AssertStages(stages, "{ $match : { 'Dictionary.a' : 1 } }");
144+
145+
var results = queryable.ToList();
146+
results.Select(r => r.Id).Should().Equal(1);
147+
}
148+
149+
[Theory]
150+
[ParameterAttributeData]
151+
public void Where_ReadOnlyDictionary_item_with_string_using_MakeIndex_should_work(
152+
[Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider)
153+
{
154+
var collection = GetCollection(linqProvider);
155+
var x = Expression.Parameter(typeof(C), "x");
156+
var body = Expression.Equal(
157+
Expression.MakeIndex(
158+
Expression.Property(x, typeof(C).GetProperty("Dictionary")!),
159+
typeof(IReadOnlyDictionary<string, int>).GetProperty("Item")!,
160+
new Expression[] {Expression.Constant("a")}),
161+
Expression.Constant(1));
162+
var predicate = Expression.Lambda<Func<C, bool>>(body, [x]);
163+
164+
var queryable = collection.AsQueryable()
165+
.Where(predicate);
166+
167+
if (linqProvider == LinqProvider.V2)
168+
{
169+
var exception = Record.Exception(() => Translate(collection, queryable));
170+
exception.Should().BeOfType<InvalidOperationException>();
171+
}
172+
else
173+
{
174+
var stages = Translate(collection, queryable);
175+
AssertStages(stages, "{ $match : { 'Dictionary.a' : 1 } }");
176+
177+
var results = queryable.ToList();
178+
results.Select(r => r.Id).Should().Equal(1);
179+
}
180+
}
181+
182+
private IMongoCollection<C> GetCollection(LinqProvider linqProvider)
183+
{
184+
var collection = GetCollection<C>("test", linqProvider);
185+
CreateCollection(
186+
collection,
187+
new C { Id = 1, Dictionary = new ReadOnlyDictionary<string, int>(new Dictionary<string, int> { ["a"] = 1 }) },
188+
new C { Id = 2, Dictionary = new ReadOnlyDictionary<string, int>(new Dictionary<string, int> { ["b"] = 2 }) });
189+
return collection;
190+
}
191+
192+
private class C
193+
{
194+
public int Id { get; set; }
195+
public IReadOnlyDictionary<string, int> Dictionary { get; set; }
196+
}
197+
}
198+
}

0 commit comments

Comments
 (0)