Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,22 @@ public interface IKeyValuePairSerializer
BsonType Representation { get; }
}

/// <summary>
/// An extended interface for KeyValuePairSerializer that provides access to key and value serializers.
/// </summary>
public interface IKeyValuePairSerializerV2 : IKeyValuePairSerializer
{
/// <summary>
/// Gets the key serializer.
/// </summary>
IBsonSerializer KeySerializer { get; }

/// <summary>
/// Gets the value serializer.
/// </summary>
IBsonSerializer ValueSerializer { get; }
}

/// <summary>
/// Static factory class for KeyValuePairSerializers.
/// </summary>
Expand Down Expand Up @@ -61,7 +77,7 @@ public static IBsonSerializer Create(
public sealed class KeyValuePairSerializer<TKey, TValue> :
StructSerializerBase<KeyValuePair<TKey, TValue>>,
IBsonDocumentSerializer,
IKeyValuePairSerializer
IKeyValuePairSerializerV2
{
// private constants
private static class Flags
Expand Down Expand Up @@ -191,6 +207,16 @@ public IBsonSerializer<TValue> ValueSerializer
get { return _lazyValueSerializer.Value; }
}

/// <summary>
/// Gets the key serializer.
/// </summary>
IBsonSerializer IKeyValuePairSerializerV2.KeySerializer => KeySerializer;

/// <summary>
/// Gets the value serializer.
/// </summary>
IBsonSerializer IKeyValuePairSerializerV2.ValueSerializer => ValueSerializer;

// public methods
/// <summary>
/// Deserializes a value.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,11 @@ public static AstExpression ArrayElemAt(AstExpression array, AstExpression index
return new AstBinaryExpression(AstBinaryOperator.ArrayElemAt, array, index);
}

public static AstExpression ArrayToObject(AstExpression arg)
{
return new AstUnaryExpression(AstUnaryOperator.ArrayToObject, arg);
}

public static AstExpression Avg(AstExpression array)
{
return new AstUnaryExpression(AstUnaryOperator.Avg, array);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
*/

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq.Expressions;
using System.Reflection;
Expand Down Expand Up @@ -71,11 +70,16 @@ public static TranslatedExpression Translate(TranslationContext context, MemberE

if (!DocumentSerializerHelper.AreMembersRepresentedAsFields(containerTranslation.Serializer, out _))
{
if (member is PropertyInfo propertyInfo && propertyInfo.Name == "Length")
if (member is PropertyInfo propertyInfo && propertyInfo.Name == "Length")
{
return LengthPropertyToAggregationExpressionTranslator.Translate(context, expression);
}

if (TryTranslateKeyValuePairProperty(expression, containerTranslation, member, out var translatedKeyValuePairProperty))
{
return translatedKeyValuePairProperty;
}

if (TryTranslateCollectionCountProperty(expression, containerTranslation, member, out var translatedCount))
{
return translatedCount;
Expand Down Expand Up @@ -126,11 +130,20 @@ private static bool TryTranslateCollectionCountProperty(MemberExpression express
{
if (EnumerableProperty.IsCountProperty(expression))
{
SerializationHelper.EnsureRepresentationIsArray(expression, container.Serializer);
AstExpression ast;

var ast = AstExpression.Size(container.Ast);
var serializer = Int32Serializer.Instance;
if (container.Serializer is IBsonDictionarySerializer dictionarySerializer &&
dictionarySerializer.DictionaryRepresentation == DictionaryRepresentation.Document)
{
ast = AstExpression.Size(AstExpression.ObjectToArray(container.Ast));
}
else
{
SerializationHelper.EnsureRepresentationIsArray(expression, container.Serializer);
ast = AstExpression.Size(container.Ast);
}

var serializer = Int32Serializer.Instance;
result = new TranslatedExpression(expression, ast, serializer);
return true;
}
Expand Down Expand Up @@ -213,6 +226,16 @@ private static bool TryTranslateDictionaryProperty(TranslationContext context, M

switch (propertyInfo.Name)
{
case "Count":
var countAst = dictionaryRepresentation switch
{
DictionaryRepresentation.ArrayOfDocuments or DictionaryRepresentation.ArrayOfArrays => AstExpression.Size(containerAst),
_ => throw new ExpressionNotSupportedException(expression, $"Unexpected dictionary representation: {dictionaryRepresentation}")
};
var countSerializer = Int32Serializer.Instance;
translatedDictionaryProperty = new TranslatedExpression(expression, countAst, countSerializer);
return true;

case "Keys":
var keysAst = dictionaryRepresentation switch
{
Expand Down Expand Up @@ -261,5 +284,36 @@ private static bool TryTranslateDictionaryProperty(TranslationContext context, M
translatedDictionaryProperty = null;
return false;
}

private static bool TryTranslateKeyValuePairProperty(MemberExpression expression, TranslatedExpression container, MemberInfo memberInfo, out TranslatedExpression result)
{
result = null;

if (container.Expression.Type.IsGenericType &&
container.Expression.Type.GetGenericTypeDefinition() == typeof(KeyValuePair<,>) &&
container.Serializer is IKeyValuePairSerializerV2 { Representation: BsonType.Array } kvpSerializer)
{
AstExpression ast;
IBsonSerializer serializer;

switch (memberInfo.Name)
{
case "Key":
ast = AstExpression.ArrayElemAt(container.Ast, 0);
serializer = kvpSerializer.KeySerializer;
break;
case "Value":
ast = AstExpression.ArrayElemAt(container.Ast, 1);
serializer = kvpSerializer.ValueSerializer;
break;
default:
throw new ExpressionNotSupportedException(expression);
}
result = new TranslatedExpression(expression, ast, serializer);
return true;
}

return false;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
using MongoDB.Driver.Linq.Linq3Implementation.Reflection;
using MongoDB.Driver.Linq.Linq3Implementation.Serializers;

namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators
{
Expand Down Expand Up @@ -123,7 +124,21 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC
}
else
{
ast = AstExpression.Avg(sourceTranslation.Ast);
var sourceItemSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer);
if (sourceItemSerializer is IWrappedValueSerializer wrappedValueSerializer)
{
var itemVar = AstExpression.Var("item");
var unwrappedItemAst = AstExpression.GetField(itemVar, wrappedValueSerializer.FieldName);
ast = AstExpression.Avg(
AstExpression.Map(
input: sourceTranslation.Ast,
@as: itemVar,
@in: unwrappedItemAst));
}
else
{
ast = AstExpression.Avg(sourceTranslation.Ast);
}
}
IBsonSerializer serializer = expression.Type switch
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,27 +36,57 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC
{
var dictionaryExpression = expression.Object;
var keyExpression = arguments[0];
return TranslateContainsKey(context, expression, dictionaryExpression, keyExpression);
}

throw new ExpressionNotSupportedException(expression);
}

var dictionaryTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, dictionaryExpression);
var dictionarySerializer = GetDictionarySerializer(expression, dictionaryTranslation);
var dictionaryRepresentation = dictionarySerializer.DictionaryRepresentation;
public static TranslatedExpression TranslateContainsKey(TranslationContext context, Expression expression, Expression dictionaryExpression, Expression keyExpression)
{
var dictionaryTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, dictionaryExpression);
var dictionarySerializer = GetDictionarySerializer(expression, dictionaryTranslation);
var dictionaryRepresentation = dictionarySerializer.DictionaryRepresentation;

AstExpression ast;
switch (dictionaryRepresentation)
{
case DictionaryRepresentation.Document:
AstExpression ast;
switch (dictionaryRepresentation)
{
case DictionaryRepresentation.Document:
{
var keyFieldName = GetKeyFieldName(context, expression, keyExpression, dictionarySerializer.KeySerializer);
ast = AstExpression.IsNotMissing(AstExpression.GetField(dictionaryTranslation.Ast, keyFieldName));
break;
}

default:
throw new ExpressionNotSupportedException(expression, because: $"ContainsKey is not supported when DictionaryRepresentation is: {dictionaryRepresentation}");
}
case DictionaryRepresentation.ArrayOfDocuments:
{
var keyFieldName = GetKeyFieldName(context, expression, keyExpression, dictionarySerializer.KeySerializer);
var kvpVar = AstExpression.Var("kvp");
var keysArray = AstExpression.Map(
input: dictionaryTranslation.Ast,
@as: kvpVar,
@in: AstExpression.GetField(kvpVar, "k"));
ast = AstExpression.In(keyFieldName, keysArray);
break;
}

case DictionaryRepresentation.ArrayOfArrays:
{
var keyFieldName = GetKeyFieldName(context, expression, keyExpression, dictionarySerializer.KeySerializer);
var kvpVar = AstExpression.Var("kvp");
var keysArray = AstExpression.Map(
input: dictionaryTranslation.Ast,
@as: kvpVar,
@in: AstExpression.ArrayElemAt(kvpVar, 0));
ast = AstExpression.In(keyFieldName, keysArray);
break;
}

return new TranslatedExpression(expression, ast, BooleanSerializer.Instance);
default:
throw new ExpressionNotSupportedException(expression, because: $"DictionaryRepresentation: {dictionaryRepresentation} is not supported.");
}

throw new ExpressionNotSupportedException(expression);
return new TranslatedExpression(expression, ast, BooleanSerializer.Instance);
}

private static AstExpression GetKeyFieldName(TranslationContext context, Expression expression, Expression keyExpression, IBsonSerializer keySerializer)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
* limitations under the License.
*/

using System.Collections.Generic;
using System.Linq.Expressions;
using MongoDB.Bson.Serialization.Serializers;
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
Expand All @@ -33,6 +34,11 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC

if (IsEnumerableContainsMethod(expression, out var sourceExpression, out var valueExpression))
{
if (TryTranslateDictionaryKeysOrValuesContains(context, expression, sourceExpression, valueExpression, out var dictionaryTranslation))
{
return dictionaryTranslation;
}

return TranslateEnumerableContains(context, expression, sourceExpression, valueExpression);
}

Expand Down Expand Up @@ -83,5 +89,48 @@ private static TranslatedExpression TranslateEnumerableContains(TranslationConte

return new TranslatedExpression(expression, ast, BooleanSerializer.Instance);
}

private static bool TryTranslateDictionaryKeysOrValuesContains(
TranslationContext context,
Expression expression,
Expression sourceExpression,
Expression valueExpression,
out TranslatedExpression translation)
{
translation = null;

if (sourceExpression is not MemberExpression memberExpression)
{
return false;
}

var memberName = memberExpression.Member.Name;
var declaringType = memberExpression.Member.DeclaringType;

if (!declaringType.IsGenericType ||
(declaringType.GetGenericTypeDefinition() != typeof(Dictionary<,>) &&
declaringType.GetGenericTypeDefinition() != typeof(IDictionary<,>)))
{
return false;
}

switch (memberName)
{
case "Keys":
{
var dictionaryExpression = memberExpression.Expression;
translation = ContainsKeyMethodToAggregationExpressionTranslator.TranslateContainsKey(context, expression, dictionaryExpression, valueExpression);
return true;
}
case "Values":
{
var dictionaryExpression = memberExpression.Expression;
translation = ContainsValueMethodToAggregationExpressionTranslator.TranslateContainsValue(context, expression, dictionaryExpression, valueExpression);
return true;
}
default:
return false;
}
}
}
}
Loading