Skip to content

Commit 3c2d4ff

Browse files
committed
CSHARP-4567: Support implicit conversion when lambda body returns a type assignable to the lambda return type.
1 parent 9eaf604 commit 3c2d4ff

File tree

3 files changed

+107
-4
lines changed

3 files changed

+107
-4
lines changed

src/MongoDB.Bson/Serialization/Serializers/DowncastingSerializer.cs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ public static IBsonSerializer Create(
6565
/// </summary>
6666
/// <typeparam name="TBase">The base type.</typeparam>
6767
/// <typeparam name="TDerived">The derived type.</typeparam>
68-
public class DowncastingSerializer<TBase, TDerived> : SerializerBase<TBase>, IBsonDocumentSerializer, IDowncastingSerializer
68+
public class DowncastingSerializer<TBase, TDerived> : SerializerBase<TBase>, IBsonArraySerializer, IBsonDocumentSerializer, IDowncastingSerializer
6969
where TDerived : TBase
7070
{
7171
private readonly IBsonSerializer<TDerived> _derivedSerializer;
@@ -94,13 +94,24 @@ public DowncastingSerializer(IBsonSerializer<TDerived> derivedSerializer)
9494
/// <inheritdoc/>
9595
public override TBase Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args)
9696
{
97-
return _derivedSerializer.Deserialize(context, args);
97+
return _derivedSerializer.Deserialize(context);
9898
}
9999

100100
/// <inheritdoc/>
101101
public override void Serialize(BsonSerializationContext context, BsonSerializationArgs args, TBase value)
102102
{
103-
_derivedSerializer.Serialize(context, args, (TDerived)value);
103+
_derivedSerializer.Serialize(context, (TDerived)value);
104+
}
105+
106+
/// <inheritdoc/>
107+
public bool TryGetItemSerializationInfo(out BsonSerializationInfo serializationInfo)
108+
{
109+
if (_derivedSerializer is IBsonArraySerializer arraySerializer)
110+
{
111+
return arraySerializer.TryGetItemSerializationInfo(out serializationInfo);
112+
}
113+
114+
throw new InvalidOperationException($"The class {_derivedSerializer.GetType().FullName} does not implement IBsonArraySerializer.");
104115
}
105116

106117
/// <inheritdoc/>

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
using System.Linq;
1818
using System.Linq.Expressions;
1919
using MongoDB.Bson.Serialization;
20+
using MongoDB.Bson.Serialization.Serializers;
2021
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
2122
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
2223
using MongoDB.Driver.Linq.Linq3Implementation.Serializers;
@@ -115,7 +116,25 @@ public static AggregationExpression TranslateLambdaBody(
115116
context.CreateSymbolWithVarName(parameterExpression, varName: "ROOT", parameterSerializer, isCurrent: true) :
116117
context.CreateSymbol(parameterExpression, parameterSerializer, isCurrent: false);
117118
var lambdaContext = context.WithSymbol(parameterSymbol);
118-
return Translate(lambdaContext, lambdaExpression.Body);
119+
var translatedBody = Translate(lambdaContext, lambdaExpression.Body);
120+
121+
var lambdaReturnType = lambdaExpression.ReturnType;
122+
var bodySerializer = translatedBody.Serializer;
123+
var bodyType = bodySerializer.ValueType;
124+
if (bodyType != lambdaReturnType)
125+
{
126+
if (lambdaReturnType.IsAssignableFrom(bodyType))
127+
{
128+
var downcastingSerializer = DowncastingSerializer.Create(baseType: lambdaReturnType, derivedType: bodyType, derivedTypeSerializer: bodySerializer);
129+
translatedBody = new AggregationExpression(translatedBody.Expression, translatedBody.Ast, downcastingSerializer);
130+
}
131+
else
132+
{
133+
throw new ExpressionNotSupportedException(lambdaExpression, because: "lambda body type is not convertible to lambda return type");
134+
}
135+
}
136+
137+
return translatedBody;
119138
}
120139
}
121140
}
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System;
17+
using System.Linq.Expressions;
18+
using FluentAssertions;
19+
using MongoDB.Driver.Core.Misc;
20+
using MongoDB.Driver.Core.TestHelpers.XunitExtensions;
21+
using MongoDB.Driver.Linq;
22+
using MongoDB.TestHelpers.XunitExtensions;
23+
using Xunit;
24+
25+
namespace MongoDB.Driver.Tests.Linq.Linq3ImplementationTests.Jira
26+
{
27+
public class CSharp4567Tests : Linq3IntegrationTest
28+
{
29+
[Theory]
30+
[ParameterAttributeData]
31+
public void Projection_to_derived_type_should_work(
32+
[Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider)
33+
{
34+
RequireServer.Check().Supports(Feature.FindProjectionExpressions);
35+
var collection = GetCollection(linqProvider);
36+
Expression<Func<C, object>> projection = x => new R { X = x.Id };
37+
38+
var find = collection.Find("{}").Project(projection);
39+
40+
var translatedProjection = TranslateFindProjection(collection, find);
41+
if (linqProvider == LinqProvider.V2)
42+
{
43+
translatedProjection.Should().Be("{ _id : 1 }");
44+
}
45+
else
46+
{
47+
translatedProjection.Should().Be("{ X : '$_id', _id : 0 }");
48+
}
49+
50+
var result = (R)find.Single();
51+
result.X.Should().Be(1);
52+
}
53+
54+
private IMongoCollection<C> GetCollection(LinqProvider linqProvider)
55+
{
56+
var collection = GetCollection<C>("test", linqProvider);
57+
CreateCollection(
58+
collection,
59+
new C { Id = 1 });
60+
return collection;
61+
}
62+
63+
private class C
64+
{
65+
public int Id { get; set; }
66+
}
67+
68+
private class R
69+
{
70+
public int X { get; set; }
71+
}
72+
}
73+
}

0 commit comments

Comments
 (0)