Skip to content

Commit 4f9b7ad

Browse files
authored
CSHARP-4690: InvalidCastException on converting underlying type to enum (#1121)
1 parent 5185f71 commit 4f9b7ad

File tree

6 files changed

+439
-127
lines changed

6 files changed

+439
-127
lines changed

src/MongoDB.Bson/Serialization/Serializers/EnumSerializer.cs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,4 +274,26 @@ private TEnum ConvertStringToEnum(string value)
274274
return (TEnum)Enum.Parse(typeof(TEnum), value, ignoreCase: true);
275275
}
276276
}
277+
278+
/// <summary>
279+
/// Static factory class for EnumSerializer.
280+
/// </summary>
281+
public static class EnumSerializer
282+
{
283+
/// <summary>
284+
/// Creates a EnumSerializer.
285+
/// </summary>
286+
/// <param name="valueType">The value type.</param>
287+
/// <returns>A EnumSerializer</returns>
288+
public static IBsonSerializer Create(Type valueType)
289+
{
290+
if (!valueType.IsEnum)
291+
{
292+
throw new ArgumentException("Argument should be of enum type.", nameof(valueType));
293+
}
294+
295+
var enumSerializerType = typeof(EnumSerializer<>).MakeGenericType(valueType);
296+
return (IBsonSerializer)Activator.CreateInstance(enumSerializerType);
297+
}
298+
}
277299
}

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConvertExpressionToAggregationExpressionTranslator.cs

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -205,29 +205,33 @@ private static AggregationExpression TranslateConvertEnumToUnderlyingType(UnaryE
205205

206206
private static AggregationExpression TranslateConvertUnderlyingTypeToEnum(UnaryExpression expression, AggregationExpression operandTranslation)
207207
{
208-
var sourceType = expression.Operand.Type;
209208
var targetType = expression.Type;
210209

211-
IBsonSerializer enumUnderlyingTypeSerializer;
212-
if (sourceType.IsNullable())
210+
var valueSerializer = operandTranslation.Serializer;
211+
if (valueSerializer is INullableSerializer nullableSerializer)
213212
{
214-
var nullableSerializer = (INullableSerializer)operandTranslation.Serializer;
215-
enumUnderlyingTypeSerializer = nullableSerializer.ValueSerializer;
216-
}
217-
else
218-
{
219-
enumUnderlyingTypeSerializer = operandTranslation.Serializer;
213+
valueSerializer = nullableSerializer.ValueSerializer;
220214
}
221215

222216
IBsonSerializer targetSerializer;
223-
var enumSerializer = ((IEnumUnderlyingTypeSerializer)enumUnderlyingTypeSerializer).EnumSerializer;
224-
if (targetType.IsNullableEnum())
217+
if (valueSerializer is IEnumUnderlyingTypeSerializer enumUnderlyingTypeSerializer)
225218
{
226-
targetSerializer = NullableSerializer.Create(enumSerializer);
219+
targetSerializer = enumUnderlyingTypeSerializer.EnumSerializer;
227220
}
228221
else
229222
{
230-
targetSerializer = enumSerializer;
223+
var enumType = targetType;
224+
if (targetType.IsNullable(out var wrappedType))
225+
{
226+
enumType = wrappedType;
227+
}
228+
229+
targetSerializer = EnumSerializer.Create(enumType);
230+
}
231+
232+
if (targetType.IsNullableEnum())
233+
{
234+
targetSerializer = NullableSerializer.Create(targetSerializer);
231235
}
232236

233237
return new AggregationExpression(expression, operandTranslation.Ast, targetSerializer);

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ToFilterFieldTranslators/ConvertExpressionToFilterFieldTranslator.cs

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ public static AstFilterField Translate(TranslationContext context, UnaryExpressi
4444
return TranslateConvertEnumToUnderlyingType(field, targetType);
4545
}
4646

47+
if (IsConvertUnderlyingTypeToEnum(fieldType, targetType))
48+
{
49+
return TranslateConvertUnderlyingTypeToEnum(field, targetType);
50+
}
51+
4752
if (IsNumericConversion(fieldType, targetType))
4853
{
4954
return TranslateNumericConversion(field, targetType);
@@ -93,6 +98,13 @@ private static bool IsConvertToNullable(Type fieldType, Type targetType)
9398
targetType.GetGenericArguments()[0] == fieldType;
9499
}
95100

101+
private static bool IsConvertUnderlyingTypeToEnum(Type fieldType, Type targetType)
102+
{
103+
return
104+
targetType.IsEnumOrNullableEnum(out _, out var underlyingType) &&
105+
fieldType.IsSameAsOrNullableOf(underlyingType);
106+
}
107+
96108
private static bool IsNumericConversion(Type fieldType, Type targetType)
97109
{
98110
return IsNumericType(fieldType) && IsNumericType(targetType);
@@ -136,15 +148,10 @@ private static AstFilterField TranslateConvertEnumToUnderlyingType(AstFilterFiel
136148
enumSerializer = fieldSerializer;
137149
}
138150

139-
IBsonSerializer targetSerializer;
140-
var enumUnderlyingTypeSerializer = EnumUnderlyingTypeSerializer.Create(enumSerializer);
151+
var targetSerializer = EnumUnderlyingTypeSerializer.Create(enumSerializer);
141152
if (targetType.IsNullable())
142153
{
143-
targetSerializer = NullableSerializer.Create(enumUnderlyingTypeSerializer);
144-
}
145-
else
146-
{
147-
targetSerializer = enumUnderlyingTypeSerializer;
154+
targetSerializer = NullableSerializer.Create(targetSerializer);
148155
}
149156

150157
return AstFilter.Field(field.Path, targetSerializer);
@@ -170,6 +177,38 @@ private static AstFilterField TranslateConvertToNullable(AstFilterField field)
170177
return AstFilter.Field(field.Path, nullableSerializer);
171178
}
172179

180+
private static AstFilterField TranslateConvertUnderlyingTypeToEnum(AstFilterField field, Type targetType)
181+
{
182+
var valueSerializer = field.Serializer;
183+
if (valueSerializer is INullableSerializer nullableSerializer)
184+
{
185+
valueSerializer = nullableSerializer.ValueSerializer;
186+
}
187+
188+
IBsonSerializer targetSerializer;
189+
if (valueSerializer is IEnumUnderlyingTypeSerializer enumUnderlyingTypeSerializer)
190+
{
191+
targetSerializer = enumUnderlyingTypeSerializer.EnumSerializer;
192+
}
193+
else
194+
{
195+
var enumType = targetType;
196+
if (targetType.IsNullable(out var wrappedType))
197+
{
198+
enumType = wrappedType;
199+
}
200+
201+
targetSerializer = EnumSerializer.Create(enumType);
202+
}
203+
204+
if (targetType.IsNullableEnum())
205+
{
206+
targetSerializer = NullableSerializer.Create(targetSerializer);
207+
}
208+
209+
return AstFilter.Field(field.Path, targetSerializer);
210+
}
211+
173212
private static AstFilterField TranslateNumericConversion(AstFilterField field, Type targetType)
174213
{
175214
IBsonSerializer targetTypeSerializer = targetType switch
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using FluentAssertions;
17+
using MongoDB.Driver.Linq;
18+
using Xunit;
19+
20+
namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators
21+
{
22+
public class ConvertExpressionToAggregationExpressionTranslatorTests : Linq3IntegrationTest
23+
{
24+
[Fact]
25+
public void Should_translate_to_derived_class_on_method_call()
26+
{
27+
var collection = GetCollection();
28+
var queryable = collection.AsQueryable()
29+
.Select(p => new DerivedClass
30+
{
31+
Id = p.Id,
32+
A = ((DerivedClass)p).A.ToUpper()
33+
});
34+
35+
var stages = Translate(collection, queryable);
36+
AssertStages(
37+
stages,
38+
"{ '$project' : { _id : '$_id', A : { '$toUpper' : '$A' } } }");
39+
40+
var result = queryable.Single();
41+
result.Id.Should().Be(1);
42+
result.A.Should().Be("ABC");
43+
}
44+
45+
[Fact]
46+
public void Should_translate_to_derived_class_on_projection()
47+
{
48+
var collection = GetCollection();
49+
var queryable = collection.AsQueryable()
50+
.Select(p => new DerivedClass()
51+
{
52+
Id = p.Id,
53+
A = ((DerivedClass)p).A
54+
});
55+
56+
var stages = Translate(collection, queryable);
57+
AssertStages(
58+
stages,
59+
"{ '$project' : { _id : '$_id', A : '$A' } }");
60+
61+
var result = queryable.Single();
62+
result.Id.Should().Be(1);
63+
result.A.Should().Be("abc");
64+
}
65+
66+
[Fact]
67+
public void Project_using_convert_underlying_type_to_enum_should_work()
68+
{
69+
var collection = GetCollection();
70+
var queryable = collection.AsQueryable()
71+
.Select(p => new ProjectedModel
72+
{
73+
Id = p.Id,
74+
Enum = (Enum)p.EnumAsInt,
75+
EnumComparisonResult = (Enum)p.EnumAsInt == Enum.Two,
76+
});
77+
78+
var stages = Translate(collection, queryable);
79+
AssertStages(
80+
stages,
81+
"{ '$project' : { _id : '$_id', Enum : '$EnumAsInt', EnumComparisonResult : { $eq : ['$EnumAsInt', 2] } } }");
82+
83+
var result = queryable.Single();
84+
result.Id.Should().Be(1);
85+
result.Enum.Should().Be(Enum.Two);
86+
}
87+
88+
[Fact]
89+
public void Project_using_convert_nullable_underlying_type_to_enum_should_work()
90+
{
91+
var collection = GetCollection();
92+
var queryable = collection.AsQueryable()
93+
.Select(p => new ProjectedModel
94+
{
95+
Id = p.Id,
96+
Enum = (Enum)p.EnumAsNullableInt,
97+
EnumComparisonResult = (Enum)p.EnumAsNullableInt == Enum.Two,
98+
});
99+
100+
var stages = Translate(collection, queryable);
101+
AssertStages(
102+
stages,
103+
"{ '$project' : { _id : '$_id', Enum : '$EnumAsNullableInt', EnumComparisonResult : { $eq : ['$EnumAsNullableInt', 2] } } }");
104+
105+
var result = queryable.Single();
106+
result.Id.Should().Be(1);
107+
result.Enum.Should().Be(Enum.Two);
108+
}
109+
110+
[Fact]
111+
public void Project_using_convert_nullable_underlying_type_to_nullable_enum_should_work()
112+
{
113+
var collection = GetCollection();
114+
var queryable = collection.AsQueryable()
115+
.Select(p => new ProjectedModel
116+
{
117+
Id = p.Id,
118+
NullableEnum = (Enum?)p.EnumAsNullableInt
119+
});
120+
121+
var stages = Translate(collection, queryable);
122+
AssertStages(
123+
stages,
124+
"{ '$project' : { _id : '$_id', NullableEnum : '$EnumAsNullableInt' } }");
125+
126+
var result = queryable.Single();
127+
result.Id.Should().Be(1);
128+
result.NullableEnum.Should().Be(Enum.Two);
129+
}
130+
131+
[Fact]
132+
public void Project_using_convert_enum_to_underlying_type_should_work()
133+
{
134+
var collection = GetCollection();
135+
var queryable = collection.AsQueryable()
136+
.Select(p => new ProjectedModel
137+
{
138+
Id = p.Id,
139+
EnumAsInt = (int)p.Enum
140+
});
141+
142+
var stages = Translate(collection, queryable);
143+
AssertStages(
144+
stages,
145+
"{ '$project' : { _id : '$_id', EnumAsInt : '$Enum' } }");
146+
147+
var result = queryable.Single();
148+
result.Id.Should().Be(1);
149+
result.EnumAsInt.Should().Be(2);
150+
}
151+
152+
[Fact]
153+
public void Project_using_convert_nullable_enum_to_underlying_type_work()
154+
{
155+
var collection = GetCollection();
156+
var queryable = collection.AsQueryable()
157+
.Select(p => new ProjectedModel
158+
{
159+
Id = p.Id,
160+
EnumAsInt = (int)p.NullableEnum
161+
});
162+
163+
var stages = Translate(collection, queryable);
164+
AssertStages(
165+
stages,
166+
"{ '$project' : { _id : '$_id', EnumAsInt : '$NullableEnum' } }");
167+
168+
var result = queryable.Single();
169+
result.Id.Should().Be(1);
170+
result.EnumAsInt.Should().Be(2);
171+
}
172+
173+
[Fact]
174+
public void Project_using_convert_nullable_enum_to_nullable_underlying_type_work()
175+
{
176+
var collection = GetCollection();
177+
var queryable = collection.AsQueryable()
178+
.Select(p => new ProjectedModel
179+
{
180+
Id = p.Id,
181+
EnumAsNullableInt = (int)p.NullableEnum
182+
});
183+
184+
var stages = Translate(collection, queryable);
185+
AssertStages(
186+
stages,
187+
"{ '$project' : { _id : '$_id', EnumAsNullableInt : '$NullableEnum' } }");
188+
189+
var result = queryable.Single();
190+
result.Id.Should().Be(1);
191+
result.EnumAsNullableInt.Should().Be(2);
192+
}
193+
194+
195+
196+
private IMongoCollection<BaseClass> GetCollection()
197+
{
198+
var collection = GetCollection<BaseClass>("test");
199+
CreateCollection(collection, new DerivedClass()
200+
{
201+
Id = 1,
202+
A = "abc",
203+
Enum = Enum.Two,
204+
NullableEnum = Enum.Two,
205+
EnumAsInt = 2,
206+
EnumAsNullableInt = 2
207+
});
208+
return collection;
209+
}
210+
211+
private class BaseClass
212+
{
213+
public int Id { get; set; }
214+
public Enum Enum { get; set; }
215+
public Enum? NullableEnum { get; set; }
216+
public int EnumAsInt { get; set; }
217+
public int? EnumAsNullableInt { get; set; }
218+
}
219+
220+
private class DerivedClass : BaseClass
221+
{
222+
public string A { get; set; }
223+
}
224+
225+
private class ProjectedModel
226+
{
227+
public int Id { get; set; }
228+
public Enum Enum { get; set; }
229+
public Enum? NullableEnum { get; set; }
230+
public int EnumAsInt { get; set; }
231+
public int? EnumAsNullableInt { get; set; }
232+
public bool EnumComparisonResult { get; set; }
233+
}
234+
235+
private enum Enum
236+
{
237+
One = 1,
238+
Two = 2
239+
}
240+
}
241+
}

0 commit comments

Comments
 (0)