Skip to content

Commit e74a0a9

Browse files
committed
CSHARP-4747: Add support for $set stage.
1 parent c1e3424 commit e74a0a9

File tree

17 files changed

+1269
-0
lines changed

17 files changed

+1269
-0
lines changed

src/MongoDB.Driver.Core/Core/Misc/Feature.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ public class Feature
116116
private static readonly Feature __serverExtractsUsernameFromX509Certificate = new Feature("ServerExtractsUsernameFromX509Certificate", WireVersion.Server34);
117117
private static readonly Feature __serverReturnsResumableChangeStreamErrorLabel = new Feature("ServerReturnsResumableChangeStreamErrorLabel", WireVersion.Server44);
118118
private static readonly Feature __serverReturnsRetryableWriteErrorLabel = new Feature("ServerReturnsRetryableWriteErrorLabel", WireVersion.Server44);
119+
private static readonly Feature __setStage = new Feature("SetStage", WireVersion.Server42);
119120
private static readonly Feature __setWindowFields = new Feature("SetWindowFields", WireVersion.Server50);
120121
private static readonly Feature __setWindowFieldsLocf = new Feature("SetWindowFieldsLocf", WireVersion.Server52);
121122
private static readonly Feature __shardedTransactions = new Feature("ShardedTransactions", WireVersion.Server42);
@@ -652,6 +653,11 @@ public class Feature
652653
/// </summary>
653654
public static Feature ServerReturnsRetryableWriteErrorLabel => __serverReturnsRetryableWriteErrorLabel;
654655

656+
/// <summary>
657+
/// Gets the $set stage feature.
658+
/// </summary>
659+
public static Feature SetStage => __setStage;
660+
655661
/// <summary>
656662
/// Gets the set window fields feature.
657663
/// </summary>

src/MongoDB.Driver/AggregateFluent.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,11 @@ public override IAggregateFluent<SearchMetaResult> SearchMeta(
274274
return WithPipeline(_pipeline.SearchMeta(searchDefinition, indexName, count));
275275
}
276276

277+
public override IAggregateFluent<TResult> Set(SetFieldDefinitions<TResult> fields)
278+
{
279+
return WithPipeline(_pipeline.Set(fields));
280+
}
281+
277282
public override IAggregateFluent<BsonDocument> SetWindowFields<TWindowFields>(
278283
AggregateExpressionDefinition<ISetWindowFieldsPartition<TResult>, TWindowFields> output)
279284
{

src/MongoDB.Driver/AggregateFluentBase.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,9 @@ public virtual IAggregateFluent<SearchMetaResult> SearchMeta(
245245
throw new NotImplementedException();
246246
}
247247

248+
/// <inheritdoc />
249+
public virtual IAggregateFluent<TResult> Set(SetFieldDefinitions<TResult> fields) => throw new NotImplementedException();
250+
248251
/// <inheritdoc />
249252
public virtual IAggregateFluent<BsonDocument> SetWindowFields<TWindowFields>(
250253
AggregateExpressionDefinition<ISetWindowFieldsPartition<TResult>, TWindowFields> output)

src/MongoDB.Driver/Builders.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ public static class Builders<TDocument>
3232
/// <summary>Gets a <see cref="ProjectionDefinitionBuilder{TDocument}"/>.</summary>
3333
public static ProjectionDefinitionBuilder<TDocument> Projection { get; } = new ProjectionDefinitionBuilder<TDocument>();
3434

35+
/// <summary>Gets a <see cref="SetFieldDefinitionsBuilder{TDocument}"/>.</summary>
36+
public static SetFieldDefinitionsBuilder<TDocument> SetFields { get; } = new SetFieldDefinitionsBuilder<TDocument>();
37+
3538
/// <summary>Gets a <see cref="SortDefinitionBuilder{TDocument}"/>.</summary>
3639
public static SortDefinitionBuilder<TDocument> Sort { get; } = new SortDefinitionBuilder<TDocument>();
3740

src/MongoDB.Driver/IAggregateFluent.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,13 @@ IAggregateFluent<TNewResult> Lookup<TForeignDocument, TAsElement, TAs, TNewResul
345345
/// <returns>The fluent aggregate interface.</returns>
346346
IAggregateFluent<TNewResult> ReplaceWith<TNewResult>(AggregateExpressionDefinition<TResult, TNewResult> newRoot);
347347

348+
/// <summary>
349+
/// Appends a $set stage to the pipeline.
350+
/// </summary>
351+
/// <param name="fields">The fields to set.</param>
352+
/// <returns>The fluent aggregate interface.</returns>
353+
IAggregateFluent<TResult> Set(SetFieldDefinitions<TResult> fields);
354+
348355
/// <summary>
349356
/// Appends a $setWindowFields to the pipeline.
350357
/// </summary>

src/MongoDB.Driver/IAggregateFluentExtensions.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -621,6 +621,22 @@ public static IAggregateFluent<TNewResult> ReplaceWith<TResult, TNewResult>(
621621
return aggregate.AppendStage(PipelineStageDefinitionBuilder.ReplaceWith(newRoot));
622622
}
623623

624+
/// <summary>
625+
/// Appends a $set stage to the pipeline.
626+
/// </summary>
627+
/// <typeparam name="TResult">The type of the result.</typeparam>
628+
/// <typeparam name="TFields">The type of object specifying the fields to set.</typeparam>
629+
/// <param name="aggregate">The aggregate.</param>
630+
/// <param name="fields">The fields to set.</param>
631+
/// <returns>The fluent aggregate interface.</returns>
632+
public static IAggregateFluent<TResult> Set<TResult, TFields>(
633+
this IAggregateFluent<TResult> aggregate,
634+
Expression<Func<TResult, TFields>> fields)
635+
{
636+
Ensure.IsNotNull(aggregate, nameof(aggregate));
637+
return aggregate.AppendStage(PipelineStageDefinitionBuilder.Set(fields));
638+
}
639+
624640
/// <summary>
625641
/// Appends a $setWindowFields to the pipeline.
626642
/// </summary>

src/MongoDB.Driver/Linq/Linq2Implementation/LinqProviderAdapterV2.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,5 +160,13 @@ internal override RenderedProjectionDefinition<TOutput> TranslateExpressionToPro
160160

161161
return AggregateProjectTranslator.Translate<TInput, TOutput>(expression, inputSerializer, serializerRegistry, translationOptions);
162162
}
163+
164+
internal override BsonDocument TranslateExpressionToSetStage<TDocument, TFields>(
165+
Expression<Func<TDocument, TFields>> expression,
166+
IBsonSerializer<TDocument> documentSerializer,
167+
IBsonSerializerRegistry serializerRegistry)
168+
{
169+
throw new NotSupportedException("Set with an Expression is only supported when using LINQ3.");
170+
}
163171
}
164172
}

src/MongoDB.Driver/Linq/Linq3Implementation/LinqProviderAdapterV3.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators;
2727
using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToFilterTranslators;
2828
using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToFilterTranslators.ToFilterFieldTranslators;
29+
using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToSetStageTranslators;
2930

3031
namespace MongoDB.Driver.Linq.Linq3Implementation
3132
{
@@ -180,5 +181,19 @@ private RenderedProjectionDefinition<TOutput> TranslateExpressionToProjection<TI
180181
var renderedProjection = new BsonDocument(specifications.Select(specification => specification.RenderAsElement()));
181182
return new RenderedProjectionDefinition<TOutput>(renderedProjection, (IBsonSerializer<TOutput>)projectionSerializer);
182183
}
184+
185+
internal override BsonDocument TranslateExpressionToSetStage<TDocument, TFields>(
186+
Expression<Func<TDocument, TFields>> expression,
187+
IBsonSerializer<TDocument> documentSerializer,
188+
IBsonSerializerRegistry serializerRegistry)
189+
{
190+
var context = TranslationContext.Create(expression, documentSerializer); // do not partially evaluate expression
191+
var parameter = expression.Parameters.Single();
192+
var symbol = context.CreateSymbolWithVarName(parameter, varName: "ROOT", documentSerializer, isCurrent: true);
193+
context = context.WithSymbol(symbol);
194+
var setStage = ExpressionToSetStageTranslator.Translate(context, documentSerializer, expression);
195+
var simplifiedSetStage = AstSimplifier.SimplifyAndConvert(setStage);
196+
return simplifiedSetStage.Render().AsBsonDocument;
197+
}
183198
}
184199
}

src/MongoDB.Driver/Linq/Linq3Implementation/Misc/TypeExtensions.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
using System;
1717
using System.Collections.Generic;
1818
using System.Linq;
19+
using System.Runtime.CompilerServices;
1920

2021
namespace MongoDB.Driver.Linq.Linq3Implementation.Misc
2122
{
@@ -147,6 +148,15 @@ public static bool Is(this Type type, Type comparand)
147148
return false;
148149
}
149150

151+
public static bool IsAnonymous(this Type type)
152+
{
153+
// don't test for too many things in case implementation details change in the future
154+
return
155+
type.GetCustomAttributes(false).Any(x => x is CompilerGeneratedAttribute) &&
156+
(type.IsGenericType || type.GetProperties().Length == 0) && // type is not generic for "new { }"
157+
type.Name.Contains("Anon"); // don't check for more than "Anon" so it works in mono also
158+
}
159+
150160
public static bool IsEnum(this Type type, out Type underlyingType)
151161
{
152162
if (type.IsEnum)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System.Collections.Generic;
17+
using System.Linq.Expressions;
18+
using System.Reflection;
19+
using MongoDB.Bson.Serialization;
20+
using MongoDB.Driver.Linq.Linq3Implementation.Ast;
21+
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
22+
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Stages;
23+
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
24+
using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators;
25+
26+
namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToSetStageTranslators
27+
{
28+
internal static class ExpressionToSetStageTranslator
29+
{
30+
public static AstStage Translate(TranslationContext context, IBsonSerializer inputSerializer, LambdaExpression expression)
31+
{
32+
if (inputSerializer is not IBsonDocumentSerializer documentSerializer)
33+
{
34+
throw new ExpressionNotSupportedException(expression, because: $"serializer {inputSerializer.GetType()} does not implement IBsonDocumentSerializer");
35+
}
36+
37+
if (IsNewAnonymousClass(expression, out var newExpression))
38+
{
39+
return TranslateNewAnonymousClass(context, documentSerializer, newExpression);
40+
}
41+
42+
if (IsNewWithOptionalMemberInitializers(expression, out var memberInitExpression))
43+
{
44+
return TranslateNewWithOptionalMemberInitializers(context, documentSerializer, memberInitExpression);
45+
}
46+
47+
throw new ExpressionNotSupportedException(expression, because: "expression is not valid for Set");
48+
}
49+
50+
private static bool IsNewAnonymousClass(LambdaExpression expression, out NewExpression newExpression)
51+
{
52+
if (expression.Body is NewExpression tempNewExpression &&
53+
tempNewExpression.Type.IsAnonymous())
54+
{
55+
newExpression = tempNewExpression;
56+
return true;
57+
}
58+
59+
newExpression = null;
60+
return false;
61+
}
62+
63+
private static bool IsNewWithOptionalMemberInitializers(LambdaExpression expression, out MemberInitExpression memberInitExpression)
64+
{
65+
if (expression.Body.NodeType == ExpressionType.New)
66+
{
67+
memberInitExpression = null;
68+
return true;
69+
}
70+
71+
if (expression.Body is MemberInitExpression tempMemberInitExpression)
72+
{
73+
var constructor = tempMemberInitExpression.NewExpression.Constructor; // will be null for default constructor of struct
74+
if (constructor == null || IsDefaultConstructor(constructor) || IsCopyConstructor(constructor))
75+
{
76+
memberInitExpression = tempMemberInitExpression;
77+
return true;
78+
}
79+
}
80+
81+
memberInitExpression = null;
82+
return false;
83+
84+
static bool IsDefaultConstructor(ConstructorInfo constructor)
85+
=> constructor.GetParameters().Length == 0;
86+
87+
static bool IsCopyConstructor(ConstructorInfo constructor)
88+
=>
89+
constructor.GetParameters() is var parameters &&
90+
parameters.Length == 1 &&
91+
parameters[0].ParameterType == constructor.DeclaringType;
92+
}
93+
94+
private static AstStage TranslateNewAnonymousClass(TranslationContext context, IBsonDocumentSerializer documentSerializer, NewExpression newExpression)
95+
{
96+
var members = newExpression.Members; // will be null in the case of "new { }"
97+
var arguments = newExpression.Arguments;
98+
99+
var fields = new List<AstComputedField>();
100+
if (members != null)
101+
{
102+
for (var i = 0; i < members.Count; i++)
103+
{
104+
var member = members[i];
105+
var valueExpression = PartialEvaluator.EvaluatePartially(arguments[i]);
106+
var computedField = CreateComputedField(context, documentSerializer, member, valueExpression);
107+
fields.Add(computedField);
108+
}
109+
}
110+
111+
return AstStage.Set(fields);
112+
}
113+
114+
private static AstStage TranslateNewWithOptionalMemberInitializers(TranslationContext context, IBsonDocumentSerializer documentSerializer, MemberInitExpression memberInitExpression)
115+
{
116+
var fields = new List<AstComputedField>();
117+
if (memberInitExpression != null)
118+
{
119+
var bindings = memberInitExpression.Bindings;
120+
121+
for (var i = 0; i < bindings.Count; i++)
122+
{
123+
var binding = bindings[i];
124+
if (binding is not MemberAssignment assignment)
125+
{
126+
throw new ExpressionNotSupportedException(memberInitExpression, because: $"the member initializer for {binding.Member.Name} is not a simple assignment");
127+
}
128+
129+
var member = binding.Member;
130+
var valueExpression = PartialEvaluator.EvaluatePartially(assignment.Expression);
131+
var computedField = CreateComputedField(context, documentSerializer, member, valueExpression);
132+
fields.Add(computedField);
133+
}
134+
}
135+
136+
return AstStage.Set(fields);
137+
}
138+
139+
private static AstComputedField CreateComputedField(TranslationContext context, IBsonDocumentSerializer documentSerializer, MemberInfo member, Expression valueExpression)
140+
{
141+
string elementName;
142+
AstExpression valueAst;
143+
if (documentSerializer.TryGetMemberSerializationInfo(member.Name, out var serializationInfo))
144+
{
145+
elementName = serializationInfo.ElementName;
146+
var memberSerializer = serializationInfo.Serializer;
147+
148+
if (valueExpression is ConstantExpression constantValueExpression)
149+
{
150+
var value = constantValueExpression.Value;
151+
var serializedValue = SerializationHelper.SerializeValue(memberSerializer, value);
152+
valueAst = AstExpression.Constant(serializedValue);
153+
}
154+
else
155+
{
156+
var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression);
157+
ThrowIfMemberAndValueSerializersAreNotCompatible(valueExpression, memberSerializer, valueTranslation.Serializer);
158+
valueAst = valueTranslation.Ast;
159+
}
160+
}
161+
else
162+
{
163+
elementName = member.Name;
164+
var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression);
165+
valueAst = valueTranslation.Ast;
166+
}
167+
168+
return AstExpression.ComputedField(elementName, valueAst);
169+
}
170+
171+
private static void ThrowIfMemberAndValueSerializersAreNotCompatible(Expression expression, IBsonSerializer memberSerializer, IBsonSerializer valueSerializer)
172+
{
173+
// TODO: depends on CSHARP-3315
174+
if (!memberSerializer.Equals(valueSerializer))
175+
{
176+
throw new ExpressionNotSupportedException(expression, because: "member and value serializers are not compatible");
177+
}
178+
}
179+
}
180+
}

0 commit comments

Comments
 (0)