From 601adaab2da041ff7f64ea137cfd8a2f7bc1f18e Mon Sep 17 00:00:00 2001 From: JoasE <32096708+JoasE@users.noreply.github.com> Date: Tue, 9 Dec 2025 12:20:33 +0100 Subject: [PATCH 01/23] WIP --- ...osmosProjectionBindingExpressionVisitor.cs | 73 +++- ...ionBindingRemovingExpressionVisitorBase.cs | 243 +++++++---- .../ComplexPropertiesCosmosFixture.cs | 38 ++ .../ComplexPropertiesProjectionCosmosTest.cs | 397 ++++++++++++++++++ 4 files changed, 672 insertions(+), 79 deletions(-) create mode 100644 test/EFCore.Cosmos.FunctionalTests/Query/Associations/ComplexProperties/ComplexPropertiesCosmosFixture.cs create mode 100644 test/EFCore.Cosmos.FunctionalTests/Query/Associations/ComplexProperties/ComplexPropertiesProjectionCosmosTest.cs diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosProjectionBindingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosProjectionBindingExpressionVisitor.cs index 72abbf75413..7dbe2f5d703 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosProjectionBindingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosProjectionBindingExpressionVisitor.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Diagnostics.CodeAnalysis; +using System.Linq.Expressions; using JetBrains.Annotations; using Microsoft.EntityFrameworkCore.Cosmos.Internal; using Microsoft.EntityFrameworkCore.Cosmos.Metadata.Internal; @@ -352,24 +353,67 @@ UnaryExpression unaryExpression Expression NullSafeUpdate(Expression? expression) { - Expression updatedMemberExpression = memberExpression.Update( - expression != null ? MatchTypes(expression, memberExpression.Expression!.Type) : expression); + if (expression is null) + { + return memberExpression.Update(expression); + } + + var expressionValue = Expression.Parameter(expression.Type); + var assignment = Expression.Assign(expressionValue, expression); + + if (expression.Type.IsNullableType() == true + && !memberExpression.Type.IsNullableType() + && memberExpression.Expression is MemberExpression innerMember + && innerMember.Type.IsNullableValueType() == true + && memberExpression.Member.Name == nameof(Nullable<>.Value)) + { + var nullCheck = Expression.Not( + Expression.Property(expressionValue, nameof(Nullable<>.HasValue))); + var conditionalExpression = Expression.Condition( + nullCheck, + Expression.Default(memberExpression.Type), + Expression.Property(expressionValue, nameof(Nullable<>.Value))); + + return Expression.Block( + [expressionValue], + assignment, + conditionalExpression); + } + + Expression updatedMemberExpression = memberExpression.Update(MatchTypes(expressionValue, memberExpression.Expression!.Type)); - if (expression?.Type.IsNullableType() == true) + if (expression.Type.IsNullableType() == true) { var nullableReturnType = memberExpression.Type.MakeNullable(); - if (!memberExpression.Type.IsNullableType()) + + if (!updatedMemberExpression.Type.IsNullableType()) { updatedMemberExpression = Expression.Convert(updatedMemberExpression, nullableReturnType); } + Expression nullCheck; + if (expression.Type.IsNullableValueType()) + { + // For Nullable, use HasValue property instead of equality comparison + // to avoid issues with value types that don't define the == operator + nullCheck = Expression.Not( + Expression.Property(expressionValue, nameof(Nullable<>.HasValue))); + } + else + { + nullCheck = Expression.Equal(expressionValue, Expression.Default(expression.Type)); + } + updatedMemberExpression = Expression.Condition( - Expression.Equal(expression, Expression.Default(expression.Type)), - Expression.Constant(null, nullableReturnType), + nullCheck, + Expression.Default(nullableReturnType), updatedMemberExpression); } - return updatedMemberExpression; + return Expression.Block( + [expressionValue], + assignment, + updatedMemberExpression); } } @@ -639,8 +683,21 @@ UnaryExpression unaryExpression updatedMethodCallExpression = Expression.Convert(updatedMethodCallExpression, nullableReturnType); } + Expression nullCheck; + if (@object.Type.IsNullableValueType()) + { + // For Nullable, use HasValue property instead of equality comparison + // to avoid issues with value types that don't define the == operator + nullCheck = Expression.Not( + Expression.Property(@object, nameof(Nullable<>.HasValue))); + } + else + { + nullCheck = Expression.Equal(@object, Expression.Constant(null, @object.Type)); + } + return Expression.Condition( - Expression.Equal(@object, Expression.Default(@object.Type)), + nullCheck, Expression.Constant(null, nullableReturnType), updatedMethodCallExpression); } diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs index de333edebc8..705d1bec9c7 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs @@ -41,32 +41,102 @@ private static readonly MethodInfo CollectionAccessorGetOrCreateMethodInfo = typeof(IClrCollectionAccessor).GetTypeInfo() .GetDeclaredMethod(nameof(IClrCollectionAccessor.GetOrCreate)); + // [MaterializationContext] = CosmosQueryObject (c or c[Prop]) private readonly IDictionary _materializationContextBindings = new Dictionary(); + // [CosmosQueryObject (c or c[Prop])] = jObject private readonly IDictionary _projectionBindings = new Dictionary(); - private readonly IDictionary _ownerMappings - = new Dictionary(); + // [CosmosQueryObject (c[Prop])] = [(OwnerType (type of c), CosmosQueryObject (c))] + private readonly IDictionary _ownerMappings // @TODO: This can stay IEntityType probably? + = new Dictionary(); + + // [$instance] = (OwnerType (type of c), CosmosQueryObject (c)) + private readonly Dictionary _instanceMaps = new(); + + // [CosmosQueryObject (c or c[Prop])] = [IComplexProperty] = jObject + private readonly Dictionary> _complexMappings = new(); + + // [$entry] = $materializationContext + private readonly Dictionary _entryMaps = new(); private readonly IDictionary _ordinalParameterBindings = new Dictionary(); private List _pendingIncludes = []; - + private int _currentComplexIndex; private static readonly MethodInfo ToObjectWithSerializerMethodInfo = typeof(CosmosProjectionBindingRemovingExpressionVisitorBase) .GetRuntimeMethods().Single(mi => mi.Name == nameof(SafeToObjectWithSerializer)); + private class MaterializationContextExtractorExpressionVisitor : ExpressionVisitor + { + private ParameterExpression _materializationContextParameter; + + public ParameterExpression Extract(Expression expression) + { + _materializationContextParameter = null; + Visit(expression); + return _materializationContextParameter; + } + + protected override Expression VisitParameter(ParameterExpression node) + { + if (node.Type == typeof(MaterializationContext)) + { + _materializationContextParameter = node; + return node; + } + return base.VisitParameter(node); + } + } + + private readonly Dictionary _structuralTypeBlocks = new(); + + protected override Expression VisitBlock(BlockExpression node) + { + if (node.Variables.Any(x => x.Type == typeof(MaterializationContext)) && + node.Variables.Any(x => x.Type == typeof(IEntityType)) && +#pragma warning disable EF1001 // Internal EF Core API usage. + node.Variables.Any(x => x.Type == typeof(InternalEntityEntry))) +#pragma warning restore EF1001 // Internal EF Core API usage. + + { + _structuralTypeBlocks.Add(node, null); + } + + return base.VisitBlock(node); + } + protected override Expression VisitBinary(BinaryExpression binaryExpression) { if (binaryExpression.NodeType == ExpressionType.Assign) { - if (binaryExpression.Left is ParameterExpression parameterExpression) + + } + + if (binaryExpression.NodeType == ExpressionType.Assign) + { +#pragma warning disable EF1001 // Internal EF Core API usage. + var internalEntryEntityMemberInfo = (MemberInfo)typeof(IInternalEntry).GetProperty(nameof(IInternalEntry.Entity)); +#pragma warning restore EF1001 // Internal EF Core API usage. + if (binaryExpression.Right is MemberExpression { Expression: ParameterExpression entryParameter } entityMember && entityMember.Member == internalEntryEntityMemberInfo && entryParameter.Type.IsAssignableTo(typeof(IUpdateEntry))) { - if (parameterExpression.Type == typeof(JObject) + var materializationContext = _entryMaps[entryParameter]; + var cosmosQueryObject = _materializationContextBindings[materializationContext]; + + _instanceMaps[binaryExpression.Left] = (null, cosmosQueryObject); + } + else if (binaryExpression.Left is ParameterExpression parameterExpression) + { + if (parameterExpression.Type.IsAssignableTo(typeof(IUpdateEntry))) + { + _entryMaps[parameterExpression] = new MaterializationContextExtractorExpressionVisitor().Extract(binaryExpression.Right); + } + else if (parameterExpression.Type == typeof(JObject) || parameterExpression.Type == typeof(JArray)) { string storeName = null; @@ -189,8 +259,28 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) } } - if (binaryExpression.Left is MemberExpression { Member: FieldInfo { IsInitOnly: true } } memberExpression) + if (binaryExpression.Left is MemberExpression memberExpression) { + var instance = _instanceMaps[memberExpression.Expression]; + var complexProperty = instance.StructuralType.GetComplexProperties().FirstOrDefault(x => x.GetMemberInfo(true, true) == memberExpression.Member); + + if (complexProperty != null) + { + _currentComplexIndex++; + var complexJObjectVariable = Variable( + typeof(JObject), + "complexJObject" + _currentComplexIndex); + + var assignVariable = Assign(complexJObjectVariable, + Call( + ToObjectWithSerializerMethodInfo.MakeGenericMethod(typeof(JObject)), + Call(instance.JObjectExpression, GetItemMethodInfo, + Constant(complexProperty.Name) // @TODO: Get json property name + ) + ) + ); + } + return memberExpression.Assign(Visit(binaryExpression.Right)); } } @@ -220,6 +310,14 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp (ParameterExpression)((MethodCallExpression)methodCallExpression.Arguments[0]).Object]; } + var declaringType = property.DeclaringType; + if (declaringType is IComplexType complexType) + { + //declaringType = complexType.ComplexProperty.DeclaringType; + // @TODO: Will complexMappings contain all levels of nesting? Or do we need to traverse up? + innerExpression = _complexMappings[innerExpression][complexType.ComplexProperty]; + } + return CreateGetValueExpression(innerExpression, property, methodCallExpression.Type); } @@ -631,88 +729,91 @@ private Expression CreateGetValueExpression( return _projectionBindings[jTokenExpression]; } - var entityType = property.DeclaringType as IEntityType; - var ownership = entityType?.FindOwnership(); var storeName = property.GetJsonPropertyName(); - if (storeName.Length == 0) + + if (property.DeclaringType is IEntityType entityType) { - if (entityType == null - || !entityType.IsDocumentRoot()) + var ownership = entityType.FindOwnership(); + if (storeName.Length == 0) { - if (ownership is { IsUnique: false } && property.IsOrdinalKeyProperty()) + if (entityType == null + || !entityType.IsDocumentRoot()) { - var ordinalExpression = _ordinalParameterBindings[jTokenExpression]; - if (ordinalExpression.Type != type) + if (ownership is { IsUnique: false } && property.IsOrdinalKeyProperty()) { - ordinalExpression = Convert(ordinalExpression, type); - } + var ordinalExpression = _ordinalParameterBindings[jTokenExpression]; + if (ordinalExpression.Type != type) + { + ordinalExpression = Convert(ordinalExpression, type); + } - return ordinalExpression; - } + return ordinalExpression; + } - var principalProperty = property.FindFirstPrincipal(); - if (principalProperty != null) - { - Expression ownerJObjectExpression = null; - if (_ownerMappings.TryGetValue(jTokenExpression, out var ownerInfo)) + var principalProperty = property.FindFirstPrincipal(); + if (principalProperty != null) { - Check.DebugAssert( - principalProperty.DeclaringType.IsAssignableFrom(ownerInfo.EntityType), - $"{principalProperty.DeclaringType} is not assignable from {ownerInfo.EntityType}"); + Expression ownerJObjectExpression = null; + if (_ownerMappings.TryGetValue(jTokenExpression, out var ownerInfo)) + { + Check.DebugAssert( + principalProperty.DeclaringType.IsAssignableFrom(ownerInfo.StructuralType), + $"{principalProperty.DeclaringType} is not assignable from {ownerInfo.StructuralType}"); - ownerJObjectExpression = ownerInfo.JObjectExpression; - } - else if (jTokenExpression is ObjectReferenceExpression objectReferenceExpression) - { - ownerJObjectExpression = objectReferenceExpression; - } - else if (jTokenExpression is ObjectAccessExpression objectAccessExpression) - { - ownerJObjectExpression = objectAccessExpression.Object; - } + ownerJObjectExpression = ownerInfo.JObjectExpression; + } + else if (jTokenExpression is ObjectReferenceExpression objectReferenceExpression) + { + ownerJObjectExpression = objectReferenceExpression; + } + else if (jTokenExpression is ObjectAccessExpression objectAccessExpression) + { + ownerJObjectExpression = objectAccessExpression.Object; + } - if (ownerJObjectExpression != null) - { - return CreateGetValueExpression(ownerJObjectExpression, principalProperty, type); + if (ownerJObjectExpression != null) + { + return CreateGetValueExpression(ownerJObjectExpression, principalProperty, type); + } } } - } - - return Default(type); - } - - // Workaround for old databases that didn't store the key property - if (ownership is { IsUnique: false } - && !entityType.IsDocumentRoot() - && property.ClrType == typeof(int) - && !property.IsForeignKey() - && property.FindContainingPrimaryKey() is { Properties.Count: > 1 } - && property.GetJsonPropertyName().Length != 0 - && !property.IsShadowProperty()) - { - var readExpression = CreateGetValueExpression( - jTokenExpression, - storeName, - type.MakeNullable(), - property.GetTypeMapping(), - isNonNullableScalar: false); - var nonNullReadExpression = readExpression; - if (nonNullReadExpression.Type != type) - { - nonNullReadExpression = Convert(nonNullReadExpression, type); + return Default(type); } - var ordinalExpression = _ordinalParameterBindings[jTokenExpression]; - if (ordinalExpression.Type != type) + // Workaround for old databases that didn't store the key property + if (ownership is { IsUnique: false } + && !entityType.IsDocumentRoot() + && property.ClrType == typeof(int) + && !property.IsForeignKey() + && property.FindContainingPrimaryKey() is { Properties.Count: > 1 } + && property.GetJsonPropertyName().Length != 0 + && !property.IsShadowProperty()) { - ordinalExpression = Convert(ordinalExpression, type); - } + var readExpression = CreateGetValueExpression( + jTokenExpression, + storeName, + type.MakeNullable(), + property.GetTypeMapping(), + isNonNullableScalar: false); + + var nonNullReadExpression = readExpression; + if (nonNullReadExpression.Type != type) + { + nonNullReadExpression = Convert(nonNullReadExpression, type); + } - return Condition( - Equal(readExpression, Constant(null, readExpression.Type)), - ordinalExpression, - nonNullReadExpression); + var ordinalExpression = _ordinalParameterBindings[jTokenExpression]; + if (ordinalExpression.Type != type) + { + ordinalExpression = Convert(ordinalExpression, type); + } + + return Condition( + Equal(readExpression, Constant(null, readExpression.Type)), + ordinalExpression, + nonNullReadExpression); + } } return Convert( diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/Associations/ComplexProperties/ComplexPropertiesCosmosFixture.cs b/test/EFCore.Cosmos.FunctionalTests/Query/Associations/ComplexProperties/ComplexPropertiesCosmosFixture.cs new file mode 100644 index 00000000000..dbfbd64ea0a --- /dev/null +++ b/test/EFCore.Cosmos.FunctionalTests/Query/Associations/ComplexProperties/ComplexPropertiesCosmosFixture.cs @@ -0,0 +1,38 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.EntityFrameworkCore.Query.Associations.ComplexProperties; + +public class ComplexPropertiesCosmosFixture : ComplexPropertiesFixtureBase +{ + public TestSqlLoggerFactory TestSqlLoggerFactory + => (TestSqlLoggerFactory)ListLoggerFactory; + + protected override ITestStoreFactory TestStoreFactory + => CosmosTestStoreFactory.Instance; + + public override DbContextOptionsBuilder AddOptions(DbContextOptionsBuilder builder) + => base.AddOptions(builder) + .ConfigureWarnings(w => w.Ignore(CosmosEventId.NoPartitionKeyDefined).Ignore(CoreEventId.MappedEntityTypeIgnoredWarning)); + + public Task NoSyncTest(bool async, Func testCode) + => CosmosTestHelpers.Instance.NoSyncTest(async, testCode); + + public void NoSyncTest(Action testCode) + => CosmosTestHelpers.Instance.NoSyncTest(testCode); + + protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext context) + { + base.OnModelCreating(modelBuilder, context); + + modelBuilder.Ignore(); + + modelBuilder.Entity() + .ToContainer("RootEntities") + .HasNoDiscriminator(); + + modelBuilder.Entity() + .ToContainer("ValueRootEntities") + .HasNoDiscriminator(); + } +} diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/Associations/ComplexProperties/ComplexPropertiesProjectionCosmosTest.cs b/test/EFCore.Cosmos.FunctionalTests/Query/Associations/ComplexProperties/ComplexPropertiesProjectionCosmosTest.cs new file mode 100644 index 00000000000..fc9d8af5e10 --- /dev/null +++ b/test/EFCore.Cosmos.FunctionalTests/Query/Associations/ComplexProperties/ComplexPropertiesProjectionCosmosTest.cs @@ -0,0 +1,397 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.EntityFrameworkCore.Query.Associations.ComplexProperties; + +public class ComplexPropertiesProjectionCosmosTest : ComplexPropertiesProjectionTestBase +{ + public ComplexPropertiesProjectionCosmosTest(ComplexPropertiesCosmosFixture fixture, ITestOutputHelper outputHelper) : base(fixture) + { + Fixture.TestSqlLoggerFactory.Clear(); + Fixture.TestSqlLoggerFactory.SetTestOutputHelper(outputHelper); + } + + public override async Task Select_root(QueryTrackingBehavior queryTrackingBehavior) + { + await base.Select_root(queryTrackingBehavior); + + AssertSql( + """ +SELECT VALUE c +FROM root c +"""); + } + + #region Scalar properties + + public override async Task Select_scalar_property_on_required_associate(QueryTrackingBehavior queryTrackingBehavior) + { + await base.Select_scalar_property_on_required_associate(queryTrackingBehavior); + + AssertSql( + """ +SELECT VALUE c["RequiredAssociate"]["String"] +FROM root c +"""); + } + + public override async Task Select_property_on_optional_associate(QueryTrackingBehavior queryTrackingBehavior) + { + // When OptionalAssociate is null, the property access on it evaluates to undefined in Cosmos, causing the + // result to be filtered out entirely. + await AssertQuery( + ss => ss.Set().Select(x => x.OptionalAssociate!.String), + ss => ss.Set().Where(x => x.OptionalAssociate != null).Select(x => x.OptionalAssociate!.String), + queryTrackingBehavior: queryTrackingBehavior); + + AssertSql( + """ +SELECT VALUE c["OptionalAssociate"]["String"] +FROM root c +"""); + } + + public override async Task Select_value_type_property_on_null_associate_throws(QueryTrackingBehavior queryTrackingBehavior) + { + // When OptionalAssociate is null, the property access on it evaluates to undefined in Cosmos, causing the + // result to be filtered out entirely. + await AssertQuery( + ss => ss.Set().Select(x => x.OptionalAssociate!.Int), + ss => ss.Set().Where(x => x.OptionalAssociate != null).Select(x => x.OptionalAssociate!.Int), + queryTrackingBehavior: queryTrackingBehavior); + + AssertSql( + """ +SELECT VALUE c["OptionalAssociate"]["Int"] +FROM root c +"""); + } + + public override async Task Select_nullable_value_type_property_on_null_associate(QueryTrackingBehavior queryTrackingBehavior) + { + // When OptionalAssociate is null, the property access on it evaluates to undefined in Cosmos, causing the + // result to be filtered out entirely. + await AssertQuery( + ss => ss.Set().Select(x => (int?)x.OptionalAssociate!.Int), + ss => ss.Set().Where(x => x.OptionalAssociate != null).Select(x => (int?)x.OptionalAssociate!.Int), + queryTrackingBehavior: queryTrackingBehavior); + + AssertSql( + """ +SELECT VALUE c["OptionalAssociate"]["Int"] +FROM root c +"""); + } + + #endregion Scalar properties + + #region Structural properties + + public override async Task Select_associate(QueryTrackingBehavior queryTrackingBehavior) + { + await base.Select_associate(queryTrackingBehavior); + + if (queryTrackingBehavior is not QueryTrackingBehavior.TrackAll) + { + AssertSql( + """ +SELECT VALUE c +FROM root c +"""); + } + } + + public override async Task Select_optional_associate(QueryTrackingBehavior queryTrackingBehavior) + { + await base.Select_optional_associate(queryTrackingBehavior); + + if (queryTrackingBehavior is not QueryTrackingBehavior.TrackAll) + { + AssertSql( + """ +SELECT VALUE c +FROM root c +"""); + } + } + + public override async Task Select_required_nested_on_required_associate(QueryTrackingBehavior queryTrackingBehavior) + { + await base.Select_required_nested_on_required_associate(queryTrackingBehavior); + + if (queryTrackingBehavior is not QueryTrackingBehavior.TrackAll) + { + AssertSql( + """ +SELECT VALUE c +FROM root c +"""); + } + } + + public override async Task Select_optional_nested_on_required_associate(QueryTrackingBehavior queryTrackingBehavior) + { + await base.Select_optional_nested_on_required_associate(queryTrackingBehavior); + + if (queryTrackingBehavior is not QueryTrackingBehavior.TrackAll) + { + AssertSql( + """ +SELECT VALUE c +FROM root c +"""); + } + } + + public override async Task Select_required_nested_on_optional_associate(QueryTrackingBehavior queryTrackingBehavior) + { + if (queryTrackingBehavior is not QueryTrackingBehavior.TrackAll) + { + await base.Select_required_nested_on_optional_associate(queryTrackingBehavior); + + AssertSql( + """ +SELECT VALUE c +FROM root c +"""); + } + } + + public override async Task Select_optional_nested_on_optional_associate(QueryTrackingBehavior queryTrackingBehavior) + { + if (queryTrackingBehavior is not QueryTrackingBehavior.TrackAll) + { + await base.Select_optional_nested_on_optional_associate(queryTrackingBehavior); + + if (queryTrackingBehavior is not QueryTrackingBehavior.TrackAll) + { + AssertSql( + """ +SELECT VALUE c +FROM root c +"""); + } + } + } + + public override Task Select_required_associate_via_optional_navigation(QueryTrackingBehavior queryTrackingBehavior) + // We don't support (inter-document) navigations with Cosmos. + => Assert.ThrowsAsync(() => base.Select_required_associate_via_optional_navigation(queryTrackingBehavior)); + + public override async Task Select_unmapped_associate_scalar_property(QueryTrackingBehavior queryTrackingBehavior) + { + await base.Select_unmapped_associate_scalar_property(queryTrackingBehavior); + + if (queryTrackingBehavior is not QueryTrackingBehavior.TrackAll) + { + AssertSql( + """ +SELECT VALUE c +FROM root c +"""); + } + } + + public override async Task Select_untranslatable_method_on_associate_scalar_property(QueryTrackingBehavior queryTrackingBehavior) + { + await base.Select_untranslatable_method_on_associate_scalar_property(queryTrackingBehavior); + + AssertSql( + """ +SELECT VALUE c["RequiredAssociate"]["Int"] +FROM root c +"""); + } + + #endregion Structural properties + + #region Structural collection properties + + public override async Task Select_associate_collection(QueryTrackingBehavior queryTrackingBehavior) + { + await base.Select_associate_collection(queryTrackingBehavior); + + if (queryTrackingBehavior is not QueryTrackingBehavior.TrackAll) + { + AssertSql( + """ +SELECT VALUE c +FROM root c +ORDER BY c["Id"] +"""); + } + } + + public override async Task Select_nested_collection_on_required_associate(QueryTrackingBehavior queryTrackingBehavior) + { + if (queryTrackingBehavior is not QueryTrackingBehavior.TrackAll) + { + await base.Select_nested_collection_on_required_associate(queryTrackingBehavior); + + AssertSql( + """ +SELECT VALUE c +FROM root c +ORDER BY c["Id"] +"""); + } + } + + public override async Task Select_nested_collection_on_optional_associate(QueryTrackingBehavior queryTrackingBehavior) + { + if (queryTrackingBehavior is not QueryTrackingBehavior.TrackAll) + { + await base.Select_nested_collection_on_optional_associate(queryTrackingBehavior); + + AssertSql( + """ +SELECT VALUE c +FROM root c +ORDER BY c["Id"] +"""); + } + } + + public override async Task SelectMany_associate_collection(QueryTrackingBehavior queryTrackingBehavior) + { + if (queryTrackingBehavior is not QueryTrackingBehavior.TrackAll) + { + await base.SelectMany_associate_collection(queryTrackingBehavior); + + AssertSql( + """ +SELECT VALUE a +FROM root c +JOIN a IN c["AssociateCollection"] +"""); + } + } + + public override async Task SelectMany_nested_collection_on_required_associate(QueryTrackingBehavior queryTrackingBehavior) + { + if (queryTrackingBehavior is not QueryTrackingBehavior.TrackAll) + { + await base.SelectMany_nested_collection_on_required_associate(queryTrackingBehavior); + + AssertSql( + """ +SELECT VALUE n +FROM root c +JOIN n IN c["RequiredAssociate"]["NestedCollection"] +"""); + } + } + + public override async Task SelectMany_nested_collection_on_optional_associate(QueryTrackingBehavior queryTrackingBehavior) + { + if (queryTrackingBehavior is not QueryTrackingBehavior.TrackAll) + { + // The given key 'n' was not present in the dictionary + await base.SelectMany_nested_collection_on_optional_associate(queryTrackingBehavior); + + AssertSql( + """ +SELECT VALUE n +FROM root c +JOIN n IN c["OptionalAssociate"]["NestedCollection"] +"""); + } + } + + #endregion Structural collection properties + + #region Multiple + + public override async Task Select_root_duplicated(QueryTrackingBehavior queryTrackingBehavior) + { + await base.Select_root_duplicated(queryTrackingBehavior); + + AssertSql( + """ +SELECT VALUE c +FROM root c +"""); + } + + #endregion Multiple + + #region Subquery + + public override async Task Select_subquery_required_related_FirstOrDefault(QueryTrackingBehavior queryTrackingBehavior) + { + if (queryTrackingBehavior is not QueryTrackingBehavior.TrackAll) + { + await AssertTranslationFailed(() => base.Select_subquery_required_related_FirstOrDefault(queryTrackingBehavior)); + } + } + + public override async Task Select_subquery_optional_related_FirstOrDefault(QueryTrackingBehavior queryTrackingBehavior) + { + if (queryTrackingBehavior is not QueryTrackingBehavior.TrackAll) + { + await AssertTranslationFailed(() => base.Select_subquery_required_related_FirstOrDefault(queryTrackingBehavior)); + } + } + + #endregion Subquery + + #region Value types + + public override async Task Select_root_with_value_types(QueryTrackingBehavior queryTrackingBehavior) + { + await base.Select_root_with_value_types(queryTrackingBehavior); + + AssertSql( + """ +SELECT VALUE c +FROM root c +"""); + } + + public override async Task Select_non_nullable_value_type(QueryTrackingBehavior queryTrackingBehavior) + { + await base.Select_non_nullable_value_type(queryTrackingBehavior); + + AssertSql( + """ +SELECT VALUE c +FROM root c +ORDER BY c["Id"] +"""); + } + + + public override async Task Select_nullable_value_type(QueryTrackingBehavior queryTrackingBehavior) + { + await base.Select_nullable_value_type(queryTrackingBehavior); + + AssertSql( + """ +SELECT VALUE c +FROM root c +ORDER BY c["Id"] +"""); + } + + public override async Task Select_nullable_value_type_with_Value(QueryTrackingBehavior queryTrackingBehavior) + { + await base.Select_nullable_value_type_with_Value(queryTrackingBehavior); + + AssertSql( + """ +SELECT VALUE c +FROM root c +ORDER BY c["Id"] +"""); + } + + #endregion Value types + + + [ConditionalFact] + public virtual void Check_all_tests_overridden() + => TestHelpers.AssertAllMethodsOverridden(GetType()); + + private void AssertSql(params string[] expected) + => Fixture.TestSqlLoggerFactory.AssertBaseline(expected); +} From 072889df5cdfbea000f61035756083fda4c672e0 Mon Sep 17 00:00:00 2001 From: JoasE <32096708+JoasE@users.noreply.github.com> Date: Tue, 9 Dec 2025 17:07:34 +0100 Subject: [PATCH 02/23] WIP --- ...ionBindingRemovingExpressionVisitorBase.cs | 563 +++++------------- 1 file changed, 153 insertions(+), 410 deletions(-) diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs index 705d1bec9c7..3990effde49 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs @@ -1,8 +1,6 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -#nullable disable - using Microsoft.EntityFrameworkCore.ChangeTracking.Internal; using Microsoft.EntityFrameworkCore.Cosmos.Internal; using Microsoft.EntityFrameworkCore.Cosmos.Metadata.Internal; @@ -15,16 +13,11 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal; public partial class CosmosShapedQueryCompilingExpressionVisitor { - private abstract class CosmosProjectionBindingRemovingExpressionVisitorBase( - ParameterExpression jTokenParameter, - bool trackQueryResults) + // Removes assignments of MaterializationContext + // Rewrites usages of MaterializationContext to use JObject variable injected by JObjectInjectingExpressionVisitor instead. + private abstract class CosmosProjectionBindingRemovingExpressionVisitorBase(ParameterExpression jTokenParameter, bool trackQueryResults) : ExpressionVisitor { - private static readonly MethodInfo GetItemMethodInfo - = typeof(JObject).GetRuntimeProperties() - .Single(pi => pi.Name == "Item" && pi.GetIndexParameters()[0].ParameterType == typeof(string)) - .GetMethod; - private static readonly PropertyInfo JTokenTypePropertyInfo = typeof(JToken).GetRuntimeProperties() .Single(mi => mi.Name == nameof(JToken.Type)); @@ -33,223 +26,31 @@ private static readonly MethodInfo JTokenToObjectWithSerializerMethodInfo = typeof(JToken).GetRuntimeMethods() .Single(mi => mi.Name == nameof(JToken.ToObject) && mi.GetParameters().Length == 1 && mi.IsGenericMethodDefinition); - private static readonly MethodInfo CollectionAccessorAddMethodInfo - = typeof(IClrCollectionAccessor).GetTypeInfo() - .GetDeclaredMethod(nameof(IClrCollectionAccessor.Add)); - - private static readonly MethodInfo CollectionAccessorGetOrCreateMethodInfo - = typeof(IClrCollectionAccessor).GetTypeInfo() - .GetDeclaredMethod(nameof(IClrCollectionAccessor.GetOrCreate)); - - // [MaterializationContext] = CosmosQueryObject (c or c[Prop]) - private readonly IDictionary _materializationContextBindings - = new Dictionary(); - - // [CosmosQueryObject (c or c[Prop])] = jObject - private readonly IDictionary _projectionBindings - = new Dictionary(); - - // [CosmosQueryObject (c[Prop])] = [(OwnerType (type of c), CosmosQueryObject (c))] - private readonly IDictionary _ownerMappings // @TODO: This can stay IEntityType probably? - = new Dictionary(); - - // [$instance] = (OwnerType (type of c), CosmosQueryObject (c)) - private readonly Dictionary _instanceMaps = new(); - - // [CosmosQueryObject (c or c[Prop])] = [IComplexProperty] = jObject - private readonly Dictionary> _complexMappings = new(); - - // [$entry] = $materializationContext - private readonly Dictionary _entryMaps = new(); - - private readonly IDictionary _ordinalParameterBindings - = new Dictionary(); + private static readonly MethodInfo GetItemMethodInfo + = typeof(JToken).GetRuntimeProperties() + .Single(pi => pi.Name == "Item" && pi.GetIndexParameters()[0].ParameterType == typeof(object)) + .GetMethod!; - private List _pendingIncludes - = []; - private int _currentComplexIndex; private static readonly MethodInfo ToObjectWithSerializerMethodInfo = typeof(CosmosProjectionBindingRemovingExpressionVisitorBase) .GetRuntimeMethods().Single(mi => mi.Name == nameof(SafeToObjectWithSerializer)); - private class MaterializationContextExtractorExpressionVisitor : ExpressionVisitor - { - private ParameterExpression _materializationContextParameter; - - public ParameterExpression Extract(Expression expression) - { - _materializationContextParameter = null; - Visit(expression); - return _materializationContextParameter; - } - - protected override Expression VisitParameter(ParameterExpression node) - { - if (node.Type == typeof(MaterializationContext)) - { - _materializationContextParameter = node; - return node; - } - return base.VisitParameter(node); - } - } - - private readonly Dictionary _structuralTypeBlocks = new(); - - protected override Expression VisitBlock(BlockExpression node) - { - if (node.Variables.Any(x => x.Type == typeof(MaterializationContext)) && - node.Variables.Any(x => x.Type == typeof(IEntityType)) && -#pragma warning disable EF1001 // Internal EF Core API usage. - node.Variables.Any(x => x.Type == typeof(InternalEntityEntry))) -#pragma warning restore EF1001 // Internal EF Core API usage. - - { - _structuralTypeBlocks.Add(node, null); - } - - return base.VisitBlock(node); - } + private ParameterExpression? _entityTypeBlockJObject; + private ConcreteStructuralTypeBlock? _concreteStructuralTypeBlock; + private List _pendingIncludes = []; + private int _currentComplexIndex = 1; protected override Expression VisitBinary(BinaryExpression binaryExpression) { if (binaryExpression.NodeType == ExpressionType.Assign) { - - } - - if (binaryExpression.NodeType == ExpressionType.Assign) - { -#pragma warning disable EF1001 // Internal EF Core API usage. - var internalEntryEntityMemberInfo = (MemberInfo)typeof(IInternalEntry).GetProperty(nameof(IInternalEntry.Entity)); -#pragma warning restore EF1001 // Internal EF Core API usage. - if (binaryExpression.Right is MemberExpression { Expression: ParameterExpression entryParameter } entityMember && entityMember.Member == internalEntryEntityMemberInfo && entryParameter.Type.IsAssignableTo(typeof(IUpdateEntry))) + if (binaryExpression.Left is ParameterExpression parameterExpression) { - var materializationContext = _entryMaps[entryParameter]; - var cosmosQueryObject = _materializationContextBindings[materializationContext]; - - _instanceMaps[binaryExpression.Left] = (null, cosmosQueryObject); - } - else if (binaryExpression.Left is ParameterExpression parameterExpression) - { - if (parameterExpression.Type.IsAssignableTo(typeof(IUpdateEntry))) - { - _entryMaps[parameterExpression] = new MaterializationContextExtractorExpressionVisitor().Extract(binaryExpression.Right); - } - else if (parameterExpression.Type == typeof(JObject) - || parameterExpression.Type == typeof(JArray)) - { - string storeName = null; - - // Values injected by JObjectInjectingExpressionVisitor - var projectionExpression = ((UnaryExpression)binaryExpression.Right).Operand; - - if (projectionExpression is UnaryExpression - { - NodeType: ExpressionType.Convert, - Operand: UnaryExpression operand - }) - { - // Unwrap EntityProjectionExpression when the root entity is not projected - // That is, this is handling the projection of a non-root entity type. - projectionExpression = operand.Operand; - } - - switch (projectionExpression) - { - // ProjectionBindingExpression may represent a named token to be obtained from a containing JObject, or - // it may be that the token is not nested in a JObject if the query was generated using the SQL VALUE clause. - case ProjectionBindingExpression projectionBindingExpression: - { - var projection = GetProjection(projectionBindingExpression); - projectionExpression = projection.Expression; - if (!projection.IsValueProjection) - { - storeName = projection.Alias; - } - - break; - } - - case ObjectArrayAccessExpression e: - storeName = e.PropertyName; - break; - - case EntityProjectionExpression e: - storeName = e.PropertyName; - break; - } - - Expression valueExpression; - switch (projectionExpression) - { - case ObjectArrayAccessExpression objectArrayProjectionExpression: - _projectionBindings[objectArrayProjectionExpression] = parameterExpression; - valueExpression = CreateGetValueExpression( - objectArrayProjectionExpression.Object, storeName, parameterExpression.Type); - break; - - case EntityProjectionExpression entityProjectionExpression: - var accessExpression = entityProjectionExpression.Object; - _projectionBindings[accessExpression] = parameterExpression; - - switch (accessExpression) - { - case ObjectReferenceExpression: - valueExpression = CreateGetValueExpression(jTokenParameter, storeName, parameterExpression.Type); - break; - - case ObjectAccessExpression: - // Access to an owned type may be nested inside another owned type, so collect the store names - // and add owner mappings for each. - var storeNames = new List(); - while (accessExpression is ObjectAccessExpression objectAccessExpression) - { - accessExpression = objectAccessExpression.Object; - storeNames.Add(objectAccessExpression.PropertyName); - _ownerMappings[objectAccessExpression] - = (objectAccessExpression.Navigation.DeclaringEntityType, accessExpression); - } - - valueExpression = CreateGetValueExpression(accessExpression, (string)null, typeof(JObject)); - for (var i = storeNames.Count - 1; i >= 0; i--) - { - valueExpression = CreateGetValueExpression(valueExpression, storeNames[i], typeof(JObject)); - } - - break; - default: - throw new InvalidOperationException( - CoreStrings.TranslationFailed(binaryExpression.Print())); - } - - break; - - default: - throw new UnreachableException(); - } - - return MakeBinary(ExpressionType.Assign, binaryExpression.Left, valueExpression); - } - + // Overwrite any creations of MaterializationContext if (parameterExpression.Type == typeof(MaterializationContext)) { var newExpression = (NewExpression)binaryExpression.Right; - - EntityProjectionExpression entityProjectionExpression; - if (newExpression.Arguments[0] is ProjectionBindingExpression projectionBindingExpression) - { - var projection = GetProjection(projectionBindingExpression); - entityProjectionExpression = (EntityProjectionExpression)projection.Expression; - } - else - { - var projection = ((UnaryExpression)((UnaryExpression)newExpression.Arguments[0]).Operand).Operand; - entityProjectionExpression = (EntityProjectionExpression)projection; - } - - _materializationContextBindings[parameterExpression] = entityProjectionExpression.Object; - + Debug.Assert(newExpression.Constructor != null, "Materialization assignment must always be via constructor"); var updatedExpression = New( newExpression.Constructor, Constant(ValueBuffer.Empty), @@ -258,27 +59,13 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) return MakeBinary(ExpressionType.Assign, binaryExpression.Left, updatedExpression); } } - - if (binaryExpression.Left is MemberExpression memberExpression) + else if (binaryExpression.Left is MemberExpression memberExpression) { - var instance = _instanceMaps[memberExpression.Expression]; - var complexProperty = instance.StructuralType.GetComplexProperties().FirstOrDefault(x => x.GetMemberInfo(true, true) == memberExpression.Member); - + Debug.Assert(_concreteStructuralTypeBlock != null, "Assignments to properties can only happen inside a structural type block."); + var complexProperty = _concreteStructuralTypeBlock.StructuralType.GetComplexProperties().FirstOrDefault(x => x.GetMemberInfo(true, true) == memberExpression.Member); if (complexProperty != null) { - _currentComplexIndex++; - var complexJObjectVariable = Variable( - typeof(JObject), - "complexJObject" + _currentComplexIndex); - - var assignVariable = Assign(complexJObjectVariable, - Call( - ToObjectWithSerializerMethodInfo.MakeGenericMethod(typeof(JObject)), - Call(instance.JObjectExpression, GetItemMethodInfo, - Constant(complexProperty.Name) // @TODO: Get json property name - ) - ) - ); + return CreateComplexPropertyAssignmentBlock(memberExpression, binaryExpression.Right, complexProperty); } return memberExpression.Assign(Visit(binaryExpression.Right)); @@ -288,37 +75,60 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) return base.VisitBinary(binaryExpression); } + private BlockExpression CreateComplexPropertyAssignmentBlock(MemberExpression memberExpression, Expression valueExpression, IComplexProperty complexProperty) + { + Debug.Assert(_concreteStructuralTypeBlock != null, "Complex property assignments can only happen inside a structural type block."); + + var complexJObjectVariableExpression = Variable( + typeof(JObject), + "complexJObject" + _currentComplexIndex++); + var assignComplexJObjectVariableExpression = Assign(complexJObjectVariableExpression, Call( // @TODO: Can we reuse get property value? + ToObjectWithSerializerMethodInfo.MakeGenericMethod(typeof(JObject)), + Call(_concreteStructuralTypeBlock.JObject, GetItemMethodInfo, + Constant(complexProperty.Name) + ) + )); + + if (complexProperty.IsNullable) + { + var condition = (ConditionalExpression)valueExpression; + valueExpression = Condition( + Equal(complexJObjectVariableExpression, Constant(null)), + condition.IfTrue, + condition.IfFalse); + } + + valueExpression = EnterScope(ref _concreteStructuralTypeBlock, new ConcreteStructuralTypeBlock(complexJObjectVariableExpression, complexProperty.ComplexType), + () => Visit(valueExpression)); + + return Block( + [complexJObjectVariableExpression], + assignComplexJObjectVariableExpression, + memberExpression.Assign(valueExpression) + ); + } + + /// + /// Overwrites usages of MaterializationContext to get property values from JObject + /// Handles IncludeExpressions to track included entities + /// protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) { var method = methodCallExpression.Method; var genericMethod = method.IsGenericMethod ? method.GetGenericMethodDefinition() : null; + + // Use jObject instead of MaterializationContext to get property values if (genericMethod == EntityFrameworkCore.Infrastructure.ExpressionExtensions.ValueBufferTryReadValueMethod) { var property = methodCallExpression.Arguments[2].GetConstantValue(); - Expression innerExpression; if (methodCallExpression.Arguments[0] is ProjectionBindingExpression projectionBindingExpression) { + //@TODO: When is this needed? var projection = GetProjection(projectionBindingExpression); - - innerExpression = Convert( - CreateReadJTokenExpression(jTokenParameter, projection.Alias), - typeof(JObject)); - } - else - { - innerExpression = _materializationContextBindings[ - (ParameterExpression)((MethodCallExpression)methodCallExpression.Arguments[0]).Object]; - } - - var declaringType = property.DeclaringType; - if (declaringType is IComplexType complexType) - { - //declaringType = complexType.ComplexProperty.DeclaringType; - // @TODO: Will complexMappings contain all levels of nesting? Or do we need to traverse up? - innerExpression = _complexMappings[innerExpression][complexType.ComplexProperty]; + return CreateGetJTokenExpression(jTokenParameter, projection.Alias); } - return CreateGetValueExpression(innerExpression, property, methodCallExpression.Type); + return CreateGetValueExpression(property, methodCallExpression.Type); } if (method.DeclaringType == typeof(Enumerable) @@ -336,6 +146,10 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp CosmosStrings.NonEmbeddedIncludeNotSupported(includeExpression.Navigation)); } + if (trackQueryResults) + { + + } _pendingIncludes.Add(includeExpression); Visit(includeExpression.EntityExpression); @@ -348,6 +162,57 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp return base.VisitMethodCall(methodCallExpression); } + #region Context + protected override Expression VisitBlock(BlockExpression blockExpression) + { + var jObject = blockExpression.Variables.SingleOrDefault(x => x.Type == typeof(JObject)); + if (jObject != null) + { + return EnterScope(ref _entityTypeBlockJObject, jObject, () => base.VisitBlock(blockExpression)); + } + + return base.VisitBlock(blockExpression); + } + + protected override SwitchCase VisitSwitchCase(SwitchCase switchCaseExpression) + { + if (switchCaseExpression.TestValues.SingleOrDefault() is ConstantExpression constantExpression + && constantExpression.Value is ITypeBase structuralType) + { + Debug.Assert(_entityTypeBlockJObject != null, "Concrete structural type swith case can not be outside of an entity type block."); + var jObjectVariable = _entityTypeBlockJObject; + return EnterScope(ref _concreteStructuralTypeBlock, new ConcreteStructuralTypeBlock(jObjectVariable, structuralType), () => base.VisitSwitchCase(switchCaseExpression)); + } + + return base.VisitSwitchCase(switchCaseExpression); + } + + private class ConcreteStructuralTypeBlock + { + public ConcreteStructuralTypeBlock(ParameterExpression jObject, ITypeBase structuralType) + { + JObject = jObject; + StructuralType = structuralType; + } + + public ITypeBase StructuralType { get; } + + public ParameterExpression JObject { get; } + } + + private static TReturn EnterScope(ref TScope scope, TScope newValue, Func action) + { + var oldValue = scope; +#pragma warning disable IDE0059 // Unnecessary assignment of a value + scope = newValue; +#pragma warning restore IDE0059 // Unnecessary assignment of a value + var result = action(); + scope = oldValue; + return result; + } + #endregion + + #region Include protected override Expression VisitExtension(Expression extensionExpression) { switch (extensionExpression) @@ -356,10 +221,10 @@ protected override Expression VisitExtension(Expression extensionExpression) { var projection = GetProjection(projectionBindingExpression); - return CreateGetValueExpression( - jTokenParameter, + return CreateGetValueExpression(jTokenParameter, projection.IsValueProjection ? null : projection.Alias, projectionBindingExpression.Type, + false, (projection.Expression as SqlExpression)?.TypeMapping); } @@ -379,16 +244,11 @@ protected override Expression VisitExtension(Expression extensionExpression) throw new InvalidOperationException(CoreStrings.TranslationFailed(extensionExpression.Print())); } - var jArray = _projectionBindings[objectArrayAccess]; + var jObjectParameter = Parameter(typeof(JObject), jArray.Name + "Object"); var ordinalParameter = Parameter(typeof(int), jArray.Name + "Ordinal"); var accessExpression = objectArrayAccess.InnerProjection.Object; - _projectionBindings[accessExpression] = jObjectParameter; - _ownerMappings[accessExpression] = - (objectArrayAccess.Navigation.DeclaringEntityType, objectArrayAccess.Object); - _ordinalParameterBindings[accessExpression] = Add( - ordinalParameter, Constant(1, typeof(int))); var innerShaper = (BlockExpression)Visit(collectionShaperExpression.InnerShaper); @@ -425,7 +285,7 @@ protected override Expression VisitExtension(Expression extensionExpression) if (!isFirstInclude) { - return jObjectBlock; + return jObjectBlock!; } Check.DebugAssert(jObjectBlock != null, "The first include must end up on a valid shaper block"); @@ -521,7 +381,7 @@ private void AddInclude( private static readonly MethodInfo IncludeReferenceMethodInfo = typeof(CosmosProjectionBindingRemovingExpressionVisitorBase).GetTypeInfo() - .GetDeclaredMethod(nameof(IncludeReference)); + .GetDeclaredMethod(nameof(IncludeReference))!; private static void IncludeReference( #pragma warning disable EF1001 // Internal EF Core API usage. @@ -566,7 +426,7 @@ private static void IncludeReference( private static readonly MethodInfo IncludeCollectionMethodInfo = typeof(CosmosProjectionBindingRemovingExpressionVisitorBase).GetTypeInfo() - .GetDeclaredMethod(nameof(IncludeCollection)); + .GetDeclaredMethod(nameof(IncludeCollection))!; private static void IncludeCollection( #pragma warning disable EF1001 // Internal EF Core API usage. @@ -655,7 +515,7 @@ private static Delegate GenerateFixup( .Compile(); } - private static Delegate GenerateInitialize( + private static Delegate? GenerateInitialize( Type entityType, INavigation navigation) { @@ -695,7 +555,7 @@ private static Expression AddToCollectionNavigation( private static readonly MethodInfo PopulateCollectionMethodInfo = typeof(CosmosProjectionBindingRemovingExpressionVisitorBase).GetTypeInfo() - .GetDeclaredMethod(nameof(PopulateCollection)); + .GetDeclaredMethod(nameof(PopulateCollection))!; private static readonly MethodInfo IsAssignableFromMethodInfo = typeof(IReadOnlyEntityType).GetMethod(nameof(IReadOnlyEntityType.IsAssignableFrom), [typeof(IReadOnlyEntityType)])!; @@ -713,155 +573,22 @@ private static TCollection PopulateCollection( return (TCollection)collection; } + #endregion protected abstract ProjectionExpression GetProjection(ProjectionBindingExpression projectionBindingExpression); - private static Expression CreateReadJTokenExpression(Expression jObjectExpression, string propertyName) - => Call(jObjectExpression, GetItemMethodInfo, Constant(propertyName)); - - private Expression CreateGetValueExpression( - Expression jTokenExpression, - IProperty property, - Type type) + /// + /// Create expression to get a property's value from JObject + /// + private Expression CreateGetValueExpression(IProperty property, Type? type = null) { - if (property.Name == CosmosPartitionKeyInPrimaryKeyConvention.JObjectPropertyName) - { - return _projectionBindings[jTokenExpression]; - } - - var storeName = property.GetJsonPropertyName(); - - if (property.DeclaringType is IEntityType entityType) - { - var ownership = entityType.FindOwnership(); - if (storeName.Length == 0) - { - if (entityType == null - || !entityType.IsDocumentRoot()) - { - if (ownership is { IsUnique: false } && property.IsOrdinalKeyProperty()) - { - var ordinalExpression = _ordinalParameterBindings[jTokenExpression]; - if (ordinalExpression.Type != type) - { - ordinalExpression = Convert(ordinalExpression, type); - } - - return ordinalExpression; - } - - var principalProperty = property.FindFirstPrincipal(); - if (principalProperty != null) - { - Expression ownerJObjectExpression = null; - if (_ownerMappings.TryGetValue(jTokenExpression, out var ownerInfo)) - { - Check.DebugAssert( - principalProperty.DeclaringType.IsAssignableFrom(ownerInfo.StructuralType), - $"{principalProperty.DeclaringType} is not assignable from {ownerInfo.StructuralType}"); - - ownerJObjectExpression = ownerInfo.JObjectExpression; - } - else if (jTokenExpression is ObjectReferenceExpression objectReferenceExpression) - { - ownerJObjectExpression = objectReferenceExpression; - } - else if (jTokenExpression is ObjectAccessExpression objectAccessExpression) - { - ownerJObjectExpression = objectAccessExpression.Object; - } - - if (ownerJObjectExpression != null) - { - return CreateGetValueExpression(ownerJObjectExpression, principalProperty, type); - } - } - } - - return Default(type); - } - - // Workaround for old databases that didn't store the key property - if (ownership is { IsUnique: false } - && !entityType.IsDocumentRoot() - && property.ClrType == typeof(int) - && !property.IsForeignKey() - && property.FindContainingPrimaryKey() is { Properties.Count: > 1 } - && property.GetJsonPropertyName().Length != 0 - && !property.IsShadowProperty()) - { - var readExpression = CreateGetValueExpression( - jTokenExpression, - storeName, - type.MakeNullable(), - property.GetTypeMapping(), - isNonNullableScalar: false); - - var nonNullReadExpression = readExpression; - if (nonNullReadExpression.Type != type) - { - nonNullReadExpression = Convert(nonNullReadExpression, type); - } - - var ordinalExpression = _ordinalParameterBindings[jTokenExpression]; - if (ordinalExpression.Type != type) - { - ordinalExpression = Convert(ordinalExpression, type); - } - - return Condition( - Equal(readExpression, Constant(null, readExpression.Type)), - ordinalExpression, - nonNullReadExpression); - } - } - - return Convert( - CreateGetValueExpression( - jTokenExpression, - storeName, - type.MakeNullable(), - property.GetTypeMapping(), - // special case keys - we check them for null to see if the entity needs to be materialized, so we want to keep the null, rather than non-nullable default - // returning defaults is supposed to help with evolving the schema - so this doesn't concern keys anyway (they shouldn't evolve) - isNonNullableScalar: !property.IsNullable && !property.IsKey()), - type); + Debug.Assert(_concreteStructuralTypeBlock != null, "Property value retrieval can only happen inside a structural type block."); + return CreateGetValueExpression(_concreteStructuralTypeBlock.JObject, property.GetJsonPropertyName(), type ?? property.ClrType, !property.IsNullable && !property.IsKey(), property.GetTypeMapping()); } - private Expression CreateGetValueExpression( - Expression jTokenExpression, - string storeName, - Type type, - CoreTypeMapping typeMapping = null, - bool isNonNullableScalar = false) + private Expression CreateGetValueExpression(ParameterExpression jObject, string? property, Type type, bool isNonNullableScalar, CoreTypeMapping? typeMapping) { - Check.DebugAssert(type.IsNullableType(), "Must read nullable type from JObject."); - - var innerExpression = jTokenExpression switch - { - _ when _projectionBindings.TryGetValue(jTokenExpression, out var innerVariable) - => innerVariable, - - ObjectReferenceExpression - => jTokenParameter, - - ObjectAccessExpression objectAccessExpression - => CreateGetValueExpression( - objectAccessExpression.Object, - ((IAccessExpression)objectAccessExpression.Object).PropertyName, - typeof(JObject)), - - _ => jTokenExpression - }; - - jTokenExpression = storeName == null - ? innerExpression - : CreateReadJTokenExpression( - innerExpression.Type == typeof(JObject) - ? innerExpression - : Convert(innerExpression, typeof(JObject)), storeName); - - Expression valueExpression; + var valueExpression = property != null ? CreateGetJTokenExpression(jObject, property) : jObject; var converter = typeMapping?.Converter; if (converter != null) { @@ -913,12 +640,12 @@ var body replaceExpression, body); - valueExpression = Invoke(Lambda(body, jTokenParameter), jTokenExpression); + valueExpression = Invoke(Lambda(body, jTokenParameter), valueExpression); } else { - valueExpression = ConvertJTokenToType( - jTokenExpression, + valueExpression = CreateSerializeJTokenToTypeExpression( + valueExpression, (isNonNullableScalar ? typeMapping?.ClrType : typeMapping?.ClrType.MakeNullable()) @@ -933,14 +660,30 @@ var body return valueExpression; } - private static Expression ConvertJTokenToType(Expression jTokenExpression, Type type) + /// + /// Create expression to get the JToken for a property from JObject + /// + private Expression CreateGetJTokenExpression(ParameterExpression jObject, IPropertyBase propertyBase) + => CreateGetJTokenExpression(jObject, propertyBase is IProperty p ? p.GetJsonPropertyName() : propertyBase.Name); + + /// + /// Create expression to get the JToken for a property from JObject + /// + private Expression CreateGetJTokenExpression(ParameterExpression jObject, string propertyName) + => Call(jObject, GetItemMethodInfo, + Constant(propertyName)); + + /// + /// Create expression to serialize JToken to given type + /// + private static Expression CreateSerializeJTokenToTypeExpression(Expression jTokenExpression, Type type) => type == typeof(JToken) ? jTokenExpression : Call( ToObjectWithSerializerMethodInfo.MakeGenericMethod(type), jTokenExpression); - private static T SafeToObjectWithSerializer(JToken token) + private static T? SafeToObjectWithSerializer(JToken? token) => token == null || token.Type == JTokenType.Null ? default : token.ToObject(CosmosClientWrapper.Serializer); } } From 35ef5c7531b47f3e46548c5f445f762a6bff5ba6 Mon Sep 17 00:00:00 2001 From: JoasE <32096708+JoasE@users.noreply.github.com> Date: Wed, 10 Dec 2025 10:52:02 +0100 Subject: [PATCH 03/23] WIP --- ...ionBindingRemovingExpressionVisitorBase.cs | 140 +++++++++++++----- 1 file changed, 100 insertions(+), 40 deletions(-) diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs index 3990effde49..3ee2e189916 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs @@ -29,12 +29,13 @@ private static readonly MethodInfo JTokenToObjectWithSerializerMethodInfo private static readonly MethodInfo GetItemMethodInfo = typeof(JToken).GetRuntimeProperties() .Single(pi => pi.Name == "Item" && pi.GetIndexParameters()[0].ParameterType == typeof(object)) - .GetMethod!; + .GetMethod ?? throw new UnreachableException(); private static readonly MethodInfo ToObjectWithSerializerMethodInfo = typeof(CosmosProjectionBindingRemovingExpressionVisitorBase) .GetRuntimeMethods().Single(mi => mi.Name == nameof(SafeToObjectWithSerializer)); + private ParameterExpression? _collectionBlockJArray; private ParameterExpression? _entityTypeBlockJObject; private ConcreteStructuralTypeBlock? _concreteStructuralTypeBlock; private List _pendingIncludes = []; @@ -46,6 +47,25 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) { if (binaryExpression.Left is ParameterExpression parameterExpression) { + if (parameterExpression.Type == typeof(JObject) || + parameterExpression.Type == typeof(JArray)) + { + var projectionExpression = ((UnaryExpression)binaryExpression.Right).Operand; + + if (projectionExpression is UnaryExpression + { + NodeType: ExpressionType.Convert, + Operand: UnaryExpression operand + }) + { + // Unwrap EntityProjectionExpression when the root entity is not projected + // That is, this is handling the projection of a non-root entity type. + projectionExpression = operand.Operand; + } + + return MakeBinary(binaryExpression.NodeType, binaryExpression.Left, Visit(projectionExpression)); + } + // Overwrite any creations of MaterializationContext if (parameterExpression.Type == typeof(MaterializationContext)) { @@ -123,12 +143,11 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp var property = methodCallExpression.Arguments[2].GetConstantValue(); if (methodCallExpression.Arguments[0] is ProjectionBindingExpression projectionBindingExpression) { - //@TODO: When is this needed? var projection = GetProjection(projectionBindingExpression); return CreateGetJTokenExpression(jTokenParameter, projection.Alias); } - return CreateGetValueExpression(property, methodCallExpression.Type); + return CreateGetValueExpression(property, method.ReturnType); } if (method.DeclaringType == typeof(Enumerable) @@ -165,15 +184,29 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp #region Context protected override Expression VisitBlock(BlockExpression blockExpression) { - var jObject = blockExpression.Variables.SingleOrDefault(x => x.Type == typeof(JObject)); - if (jObject != null) + var param = blockExpression.Variables.Count == 1 ? blockExpression.Variables[0] : null; + if (param?.Type == typeof(JObject)) { - return EnterScope(ref _entityTypeBlockJObject, jObject, () => base.VisitBlock(blockExpression)); + return EnterScope(ref _entityTypeBlockJObject, param, () => base.VisitBlock(blockExpression)); + } + + if (param?.Type == typeof(JArray)) + { + return EnterScope(ref _collectionBlockJArray, param, () => base.VisitBlock(blockExpression)); } return base.VisitBlock(blockExpression); } + //protected override Expression VisitLambda(Expression node) + //{ + // if (node.Parameters.FirstOrDefault(x => x.Type == typeof(JObject)) is ParameterExpression jObject) + // { + // return EnterScope(ref _entityTypeBlockJObject, jObject, () => base.VisitLambda(node)); + // } + // return base.VisitLambda(node); + //} + protected override SwitchCase VisitSwitchCase(SwitchCase switchCaseExpression) { if (switchCaseExpression.TestValues.SingleOrDefault() is ConstantExpression constantExpression @@ -223,34 +256,41 @@ protected override Expression VisitExtension(Expression extensionExpression) return CreateGetValueExpression(jTokenParameter, projection.IsValueProjection ? null : projection.Alias, - projectionBindingExpression.Type, + typeof(JObject), false, (projection.Expression as SqlExpression)?.TypeMapping); } - case CollectionShaperExpression collectionShaperExpression: + case ObjectArrayAccessExpression objectArrayAccessExpression: { - ObjectArrayAccessExpression objectArrayAccess; - switch (collectionShaperExpression.Projection) - { - case ProjectionBindingExpression projectionBindingExpression: - var projection = GetProjection(projectionBindingExpression); - objectArrayAccess = (ObjectArrayAccessExpression)projection.Expression; - break; - case ObjectArrayAccessExpression objectArrayProjectionExpression: - objectArrayAccess = objectArrayProjectionExpression; - break; - default: - throw new InvalidOperationException(CoreStrings.TranslationFailed(extensionExpression.Print())); - } + return CreateGetValueExpression( + _entityTypeBlockJObject ?? throw new InvalidOperationException(), + objectArrayAccessExpression.PropertyName, + objectArrayAccessExpression.Type, + false, + null); + } + case EntityProjectionExpression entityProjectionExpression: + { + Debug.Assert(_entityTypeBlockJObject != null, "Entity projection can only be inside an entity type block."); + return CreateGetValueExpression( + _entityTypeBlockJObject ?? throw new InvalidOperationException(), + entityProjectionExpression.PropertyName, + entityProjectionExpression.Type, + false, + null); + } - var jObjectParameter = Parameter(typeof(JObject), jArray.Name + "Object"); - var ordinalParameter = Parameter(typeof(int), jArray.Name + "Ordinal"); + case CollectionShaperExpression collectionShaperExpression: + { + Debug.Assert(collectionShaperExpression.Navigation != null); + Debug.Assert(_collectionBlockJArray != null, "Collection shaper can only be inside a collection block."); - var accessExpression = objectArrayAccess.InnerProjection.Object; + var jObjectParameter = Parameter(typeof(JObject), _collectionBlockJArray.Name + "Object"); + var ordinalParameter = Parameter(typeof(int), _collectionBlockJArray.Name + "Ordinal"); - var innerShaper = (BlockExpression)Visit(collectionShaperExpression.InnerShaper); + var innerShaper = EnterScope(ref _entityTypeBlockJObject, jObjectParameter, () => (BlockExpression)Visit(collectionShaperExpression.InnerShaper)); innerShaper = AddIncludes(innerShaper); @@ -258,7 +298,7 @@ protected override Expression VisitExtension(Expression extensionExpression) EnumerableMethods.SelectWithOrdinal.MakeGenericMethod(typeof(JObject), innerShaper.Type), Call( EnumerableMethods.Cast.MakeGenericMethod(typeof(JObject)), - jArray), + _collectionBlockJArray), Lambda(innerShaper, jObjectParameter, ordinalParameter)); var navigation = collectionShaperExpression.Navigation; @@ -287,22 +327,25 @@ protected override Expression VisitExtension(Expression extensionExpression) { return jObjectBlock!; } - Check.DebugAssert(jObjectBlock != null, "The first include must end up on a valid shaper block"); - // These are the expressions added by JObjectInjectingExpressionVisitor - var jObjectCondition = (ConditionalExpression)jObjectBlock.Expressions[^1]; + var jObjectParameter = jObjectBlock.Variables.Single(); + return EnterScope(ref _entityTypeBlockJObject, jObjectParameter, () => + { + // These are the expressions added by JObjectInjectingExpressionVisitor + var jObjectCondition = (ConditionalExpression)jObjectBlock.Expressions[^1]; - var shaperBlock = (BlockExpression)jObjectCondition.IfFalse; - shaperBlock = AddIncludes(shaperBlock); + var shaperBlock = (BlockExpression)jObjectCondition.IfFalse; + shaperBlock = AddIncludes(shaperBlock); - var jObjectExpressions = new List(jObjectBlock.Expressions); - jObjectExpressions.RemoveAt(jObjectExpressions.Count - 1); + var jObjectExpressions = new List(jObjectBlock.Expressions); + jObjectExpressions.RemoveAt(jObjectExpressions.Count - 1); - jObjectExpressions.Add( - jObjectCondition.Update(jObjectCondition.Test, jObjectCondition.IfTrue, shaperBlock)); + jObjectExpressions.Add( + jObjectCondition.Update(jObjectCondition.Test, jObjectCondition.IfTrue, shaperBlock)); - return jObjectBlock.Update(jObjectBlock.Variables, jObjectExpressions); + return jObjectBlock.Update(jObjectBlock.Variables, jObjectExpressions); + }); } } @@ -457,7 +500,7 @@ private static void IncludeCollection( foreach (var relatedEntity in relatedEntities) { fixup(includingEntity, relatedEntity); - inverseNavigation?.SetIsLoadedWhenNoTracking(relatedEntity); + inverseNavigation?.SetIsLoadedWhenNoTracking(relatedEntity!); } } else @@ -492,7 +535,7 @@ private static Delegate GenerateFixup( Type entityType, Type relatedEntityType, INavigation navigation, - INavigation inverseNavigation) + INavigation? inverseNavigation) { var entityParameter = Parameter(entityType); var relatedEntityParameter = Parameter(relatedEntityType); @@ -573,17 +616,33 @@ private static TCollection PopulateCollection( return (TCollection)collection; } + + private static readonly MethodInfo CollectionAccessorAddMethodInfo + = typeof(IClrCollectionAccessor).GetTypeInfo() + .GetDeclaredMethod(nameof(IClrCollectionAccessor.Add)) ?? throw new UnreachableException(); + + private static readonly MethodInfo CollectionAccessorGetOrCreateMethodInfo + = typeof(IClrCollectionAccessor).GetTypeInfo() + .GetDeclaredMethod(nameof(IClrCollectionAccessor.GetOrCreate)) ?? throw new UnreachableException(); #endregion protected abstract ProjectionExpression GetProjection(ProjectionBindingExpression projectionBindingExpression); + #region Create expression helpers /// /// Create expression to get a property's value from JObject /// private Expression CreateGetValueExpression(IProperty property, Type? type = null) { - Debug.Assert(_concreteStructuralTypeBlock != null, "Property value retrieval can only happen inside a structural type block."); - return CreateGetValueExpression(_concreteStructuralTypeBlock.JObject, property.GetJsonPropertyName(), type ?? property.ClrType, !property.IsNullable && !property.IsKey(), property.GetTypeMapping()); + var currentJObject = _concreteStructuralTypeBlock?.JObject ?? _entityTypeBlockJObject; + Debug.Assert(currentJObject != null, "Property value can only be retrieved inside an structural type block."); + + if (property.Name == CosmosPartitionKeyInPrimaryKeyConvention.JObjectPropertyName) + { + return currentJObject; + } + + return CreateGetValueExpression(currentJObject, property.GetJsonPropertyName(), type ?? property.ClrType, !property.IsNullable && !property.IsKey(), property.GetTypeMapping()); } private Expression CreateGetValueExpression(ParameterExpression jObject, string? property, Type type, bool isNonNullableScalar, CoreTypeMapping? typeMapping) @@ -685,5 +744,6 @@ private static Expression CreateSerializeJTokenToTypeExpression(Expression jToke private static T? SafeToObjectWithSerializer(JToken? token) => token == null || token.Type == JTokenType.Null ? default : token.ToObject(CosmosClientWrapper.Serializer); + #endregion } } From 221d5a72d84b10e91ea11e1df7090446787aeb7e Mon Sep 17 00:00:00 2001 From: JoasE <32096708+JoasE@users.noreply.github.com> Date: Wed, 10 Dec 2025 11:58:15 +0100 Subject: [PATCH 04/23] Move back to old way --- ...ionBindingRemovingExpressionVisitorBase.cs | 638 ++++++++++++------ 1 file changed, 437 insertions(+), 201 deletions(-) diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs index 3ee2e189916..634711b3a3d 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs @@ -1,6 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +#nullable disable + using Microsoft.EntityFrameworkCore.ChangeTracking.Internal; using Microsoft.EntityFrameworkCore.Cosmos.Internal; using Microsoft.EntityFrameworkCore.Cosmos.Metadata.Internal; @@ -13,11 +15,16 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal; public partial class CosmosShapedQueryCompilingExpressionVisitor { - // Removes assignments of MaterializationContext - // Rewrites usages of MaterializationContext to use JObject variable injected by JObjectInjectingExpressionVisitor instead. - private abstract class CosmosProjectionBindingRemovingExpressionVisitorBase(ParameterExpression jTokenParameter, bool trackQueryResults) + private abstract class CosmosProjectionBindingRemovingExpressionVisitorBase( + ParameterExpression jTokenParameter, + bool trackQueryResults) : ExpressionVisitor { + private static readonly MethodInfo GetItemMethodInfo + = typeof(JObject).GetRuntimeProperties() + .Single(pi => pi.Name == "Item" && pi.GetIndexParameters()[0].ParameterType == typeof(string)) + .GetMethod; + private static readonly PropertyInfo JTokenTypePropertyInfo = typeof(JToken).GetRuntimeProperties() .Single(mi => mi.Name == nameof(JToken.Type)); @@ -26,20 +33,47 @@ private static readonly MethodInfo JTokenToObjectWithSerializerMethodInfo = typeof(JToken).GetRuntimeMethods() .Single(mi => mi.Name == nameof(JToken.ToObject) && mi.GetParameters().Length == 1 && mi.IsGenericMethodDefinition); - private static readonly MethodInfo GetItemMethodInfo - = typeof(JToken).GetRuntimeProperties() - .Single(pi => pi.Name == "Item" && pi.GetIndexParameters()[0].ParameterType == typeof(object)) - .GetMethod ?? throw new UnreachableException(); + private static readonly MethodInfo CollectionAccessorAddMethodInfo + = typeof(IClrCollectionAccessor).GetTypeInfo() + .GetDeclaredMethod(nameof(IClrCollectionAccessor.Add)); + + private static readonly MethodInfo CollectionAccessorGetOrCreateMethodInfo + = typeof(IClrCollectionAccessor).GetTypeInfo() + .GetDeclaredMethod(nameof(IClrCollectionAccessor.GetOrCreate)); + + private readonly IDictionary _materializationContextBindings + = new Dictionary(); + + private readonly IDictionary _projectionBindings + = new Dictionary(); + private readonly IDictionary _ownerMappings + = new Dictionary(); + + private readonly Dictionary _instanceTypeBaseMappings = new(); + private readonly Dictionary> _materializationContextCompexPropertyJObjectMappings = new(); + + private readonly IDictionary _ordinalParameterBindings + = new Dictionary(); + + private List _pendingIncludes = []; + private int _currentComplexIndex; private static readonly MethodInfo ToObjectWithSerializerMethodInfo = typeof(CosmosProjectionBindingRemovingExpressionVisitorBase) .GetRuntimeMethods().Single(mi => mi.Name == nameof(SafeToObjectWithSerializer)); - private ParameterExpression? _collectionBlockJArray; - private ParameterExpression? _entityTypeBlockJObject; - private ConcreteStructuralTypeBlock? _concreteStructuralTypeBlock; - private List _pendingIncludes = []; - private int _currentComplexIndex = 1; + protected override SwitchCase VisitSwitchCase(SwitchCase switchCaseExpression) + { + if (switchCaseExpression.TestValues.SingleOrDefault() is ConstantExpression constantExpression + && constantExpression.Value is ITypeBase structuralType) + { + // @TODO: Maybe use visitor? Does this work if tracking is off? + var instanceBlock = (BlockExpression)((BlockExpression)switchCaseExpression.Body).Expressions.First(x => x is BlockExpression b && b.Variables.FirstOrDefault()?.Type == structuralType.ClrType); + _instanceTypeBaseMappings.Add(instanceBlock.Variables.Single(), structuralType); + } + + return base.VisitSwitchCase(switchCaseExpression); + } protected override Expression VisitBinary(BinaryExpression binaryExpression) { @@ -47,9 +81,12 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) { if (binaryExpression.Left is ParameterExpression parameterExpression) { - if (parameterExpression.Type == typeof(JObject) || - parameterExpression.Type == typeof(JArray)) + if (parameterExpression.Type == typeof(JObject) + || parameterExpression.Type == typeof(JArray)) { + string storeName = null; + + // Values injected by JObjectInjectingExpressionVisitor var projectionExpression = ((UnaryExpression)binaryExpression.Right).Operand; if (projectionExpression is UnaryExpression @@ -63,14 +100,101 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) projectionExpression = operand.Operand; } - return MakeBinary(binaryExpression.NodeType, binaryExpression.Left, Visit(projectionExpression)); + switch (projectionExpression) + { + // ProjectionBindingExpression may represent a named token to be obtained from a containing JObject, or + // it may be that the token is not nested in a JObject if the query was generated using the SQL VALUE clause. + case ProjectionBindingExpression projectionBindingExpression: + { + var projection = GetProjection(projectionBindingExpression); + projectionExpression = projection.Expression; + if (!projection.IsValueProjection) + { + storeName = projection.Alias; + } + + break; + } + + case ObjectArrayAccessExpression e: + storeName = e.PropertyName; + break; + + case EntityProjectionExpression e: + storeName = e.PropertyName; + break; + } + + Expression valueExpression; + switch (projectionExpression) + { + case ObjectArrayAccessExpression objectArrayProjectionExpression: + _projectionBindings[objectArrayProjectionExpression] = parameterExpression; + valueExpression = CreateGetValueExpression( + objectArrayProjectionExpression.Object, storeName, parameterExpression.Type); + break; + + case EntityProjectionExpression entityProjectionExpression: + var accessExpression = entityProjectionExpression.Object; + _projectionBindings[accessExpression] = parameterExpression; + + switch (accessExpression) + { + case ObjectReferenceExpression: + valueExpression = CreateGetValueExpression(jTokenParameter, storeName, parameterExpression.Type); + break; + + case ObjectAccessExpression: + // Access to an owned type may be nested inside another owned type, so collect the store names + // and add owner mappings for each. + var storeNames = new List(); + while (accessExpression is ObjectAccessExpression objectAccessExpression) + { + accessExpression = objectAccessExpression.Object; + storeNames.Add(objectAccessExpression.PropertyName); + _ownerMappings[objectAccessExpression] + = (objectAccessExpression.Navigation.DeclaringEntityType, accessExpression); + } + + valueExpression = CreateGetValueExpression(accessExpression, (string)null, typeof(JObject)); + for (var i = storeNames.Count - 1; i >= 0; i--) + { + valueExpression = CreateGetValueExpression(valueExpression, storeNames[i], typeof(JObject)); + } + + break; + default: + throw new InvalidOperationException( + CoreStrings.TranslationFailed(binaryExpression.Print())); + } + + break; + + default: + throw new UnreachableException(); + } + + return MakeBinary(ExpressionType.Assign, binaryExpression.Left, valueExpression); } - // Overwrite any creations of MaterializationContext if (parameterExpression.Type == typeof(MaterializationContext)) { var newExpression = (NewExpression)binaryExpression.Right; - Debug.Assert(newExpression.Constructor != null, "Materialization assignment must always be via constructor"); + + EntityProjectionExpression entityProjectionExpression; + if (newExpression.Arguments[0] is ProjectionBindingExpression projectionBindingExpression) + { + var projection = GetProjection(projectionBindingExpression); + entityProjectionExpression = (EntityProjectionExpression)projection.Expression; + } + else + { + var projection = ((UnaryExpression)((UnaryExpression)newExpression.Arguments[0]).Operand).Operand; + entityProjectionExpression = (EntityProjectionExpression)projection; + } + + _materializationContextBindings[parameterExpression] = entityProjectionExpression.Object; + var updatedExpression = New( newExpression.Constructor, Constant(ValueBuffer.Empty), @@ -79,15 +203,20 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) return MakeBinary(ExpressionType.Assign, binaryExpression.Left, updatedExpression); } } - else if (binaryExpression.Left is MemberExpression memberExpression) + + if (binaryExpression.Left is MemberExpression memberExpression) { - Debug.Assert(_concreteStructuralTypeBlock != null, "Assignments to properties can only happen inside a structural type block."); - var complexProperty = _concreteStructuralTypeBlock.StructuralType.GetComplexProperties().FirstOrDefault(x => x.GetMemberInfo(true, true) == memberExpression.Member); - if (complexProperty != null) + if (memberExpression.Expression is ParameterExpression instanceParameterExpression && _instanceTypeBaseMappings.TryGetValue(instanceParameterExpression, out var structuralType)) { - return CreateComplexPropertyAssignmentBlock(memberExpression, binaryExpression.Right, complexProperty); + var complexProperty = structuralType.GetComplexProperties().FirstOrDefault(x => x.GetMemberInfo(true, true) == memberExpression.Member); + if (complexProperty != null) + { + if (!complexProperty.IsCollection) + { + return CreateComplexPropertyAssignmentBlock(memberExpression, binaryExpression.Right, complexProperty); + } + } } - return memberExpression.Assign(Visit(binaryExpression.Right)); } } @@ -97,29 +226,63 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) private BlockExpression CreateComplexPropertyAssignmentBlock(MemberExpression memberExpression, Expression valueExpression, IComplexProperty complexProperty) { - Debug.Assert(_concreteStructuralTypeBlock != null, "Complex property assignments can only happen inside a structural type block."); + var complexPropertyMaterializationContext = new ComplexPropertyMaterializationContextExtractorExpressionVisitor().Extract(valueExpression, complexProperty); + if (complexPropertyMaterializationContext == null) + { + Debug.Assert(false); + // complex property without properties... + // What if a complex property only has a list in it.... + // We might need to get the materialization context in a different way.. Just create an instance materialization context mapping? + + // Can/should we inject materializers for collections of complex types in InjectMaterializers already? How does it happen in slqserver again? We know it calls it, but when. + } + + Expression parentJObject; + if (complexProperty.DeclaringType is IComplexType parentComplexType) + { + parentJObject = _materializationContextCompexPropertyJObjectMappings[complexPropertyMaterializationContext][parentComplexType.ComplexProperty]; + } + else + { + parentJObject = _projectionBindings[_materializationContextBindings[complexPropertyMaterializationContext]]; + } var complexJObjectVariableExpression = Variable( typeof(JObject), - "complexJObject" + _currentComplexIndex++); + "complexJObject" + ++_currentComplexIndex); var assignComplexJObjectVariableExpression = Assign(complexJObjectVariableExpression, Call( // @TODO: Can we reuse get property value? ToObjectWithSerializerMethodInfo.MakeGenericMethod(typeof(JObject)), - Call(_concreteStructuralTypeBlock.JObject, GetItemMethodInfo, + Call(parentJObject, GetItemMethodInfo, // @TODO: Which jobject........ Constant(complexProperty.Name) ) )); - if (complexProperty.IsNullable) + if (!_materializationContextCompexPropertyJObjectMappings.TryGetValue(complexPropertyMaterializationContext, out var complexPropertyJObjectMappings)) + { + complexPropertyJObjectMappings = new(); + _materializationContextCompexPropertyJObjectMappings[complexPropertyMaterializationContext] = complexPropertyJObjectMappings; + } + + complexPropertyJObjectMappings.Add(complexProperty, complexJObjectVariableExpression); + + BlockExpression materializationBlock; + if (valueExpression is ConditionalExpression condition) { - var condition = (ConditionalExpression)valueExpression; + materializationBlock = (BlockExpression)((UnaryExpression)condition.IfFalse).Operand; valueExpression = Condition( Equal(complexJObjectVariableExpression, Constant(null)), condition.IfTrue, - condition.IfFalse); + materializationBlock); + } + else + { + materializationBlock = (BlockExpression)valueExpression; } - valueExpression = EnterScope(ref _concreteStructuralTypeBlock, new ConcreteStructuralTypeBlock(complexJObjectVariableExpression, complexProperty.ComplexType), - () => Visit(valueExpression)); + var instanceParameter = materializationBlock.Variables.First(v => v.Type == complexProperty.ComplexType.ClrType); + _instanceTypeBaseMappings.Add(instanceParameter, complexProperty.ComplexType); + + valueExpression = Visit(valueExpression); return Block( [complexJObjectVariableExpression], @@ -128,26 +291,76 @@ private BlockExpression CreateComplexPropertyAssignmentBlock(MemberExpression me ); } - /// - /// Overwrites usages of MaterializationContext to get property values from JObject - /// Handles IncludeExpressions to track included entities - /// + private class ComplexPropertyMaterializationContextExtractorExpressionVisitor : ExpressionVisitor + { + private IComplexProperty _complexProperty; + private ParameterExpression _materializationContext; + public ParameterExpression Extract(Expression expression, IComplexProperty complexProperty) + { + _complexProperty = complexProperty; + _materializationContext = null; + Visit(expression); + return _materializationContext; + } + + protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) + { + var method = methodCallExpression.Method; + var genericMethod = method.IsGenericMethod ? method.GetGenericMethodDefinition() : null; + if (genericMethod == EntityFrameworkCore.Infrastructure.ExpressionExtensions.ValueBufferTryReadValueMethod) + { + var property = methodCallExpression.Arguments[2].GetConstantValue(); + + var declaringType = property.DeclaringType; + while (declaringType is IComplexType c) + { + if (c.ComplexProperty == _complexProperty) + { + var param = (methodCallExpression.Arguments[0] as MethodCallExpression)?.Object as ParameterExpression; + if (param.Type == typeof(MaterializationContext)) + { + _materializationContext = param; + return methodCallExpression; + } + } + + declaringType = c.ComplexProperty.DeclaringType; + } + } + return base.VisitMethodCall(methodCallExpression); + } + } + protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) { var method = methodCallExpression.Method; var genericMethod = method.IsGenericMethod ? method.GetGenericMethodDefinition() : null; - - // Use jObject instead of MaterializationContext to get property values if (genericMethod == EntityFrameworkCore.Infrastructure.ExpressionExtensions.ValueBufferTryReadValueMethod) { var property = methodCallExpression.Arguments[2].GetConstantValue(); + Expression innerExpression; if (methodCallExpression.Arguments[0] is ProjectionBindingExpression projectionBindingExpression) { var projection = GetProjection(projectionBindingExpression); - return CreateGetJTokenExpression(jTokenParameter, projection.Alias); + + innerExpression = Convert( + CreateReadJTokenExpression(jTokenParameter, projection.Alias), + typeof(JObject)); + } + else + { + var materializationContext = (ParameterExpression)((MethodCallExpression)methodCallExpression.Arguments[0]).Object; + if (property.DeclaringType is IComplexType complexType) + { + innerExpression = _materializationContextCompexPropertyJObjectMappings[materializationContext][complexType.ComplexProperty]; + } + else + { + innerExpression = _materializationContextBindings[materializationContext]; + } } - return CreateGetValueExpression(property, method.ReturnType); + return CreateGetValueExpression(innerExpression, property, methodCallExpression.Type); } if (method.DeclaringType == typeof(Enumerable) @@ -165,10 +378,6 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp CosmosStrings.NonEmbeddedIncludeNotSupported(includeExpression.Navigation)); } - if (trackQueryResults) - { - - } _pendingIncludes.Add(includeExpression); Visit(includeExpression.EntityExpression); @@ -181,71 +390,6 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp return base.VisitMethodCall(methodCallExpression); } - #region Context - protected override Expression VisitBlock(BlockExpression blockExpression) - { - var param = blockExpression.Variables.Count == 1 ? blockExpression.Variables[0] : null; - if (param?.Type == typeof(JObject)) - { - return EnterScope(ref _entityTypeBlockJObject, param, () => base.VisitBlock(blockExpression)); - } - - if (param?.Type == typeof(JArray)) - { - return EnterScope(ref _collectionBlockJArray, param, () => base.VisitBlock(blockExpression)); - } - - return base.VisitBlock(blockExpression); - } - - //protected override Expression VisitLambda(Expression node) - //{ - // if (node.Parameters.FirstOrDefault(x => x.Type == typeof(JObject)) is ParameterExpression jObject) - // { - // return EnterScope(ref _entityTypeBlockJObject, jObject, () => base.VisitLambda(node)); - // } - // return base.VisitLambda(node); - //} - - protected override SwitchCase VisitSwitchCase(SwitchCase switchCaseExpression) - { - if (switchCaseExpression.TestValues.SingleOrDefault() is ConstantExpression constantExpression - && constantExpression.Value is ITypeBase structuralType) - { - Debug.Assert(_entityTypeBlockJObject != null, "Concrete structural type swith case can not be outside of an entity type block."); - var jObjectVariable = _entityTypeBlockJObject; - return EnterScope(ref _concreteStructuralTypeBlock, new ConcreteStructuralTypeBlock(jObjectVariable, structuralType), () => base.VisitSwitchCase(switchCaseExpression)); - } - - return base.VisitSwitchCase(switchCaseExpression); - } - - private class ConcreteStructuralTypeBlock - { - public ConcreteStructuralTypeBlock(ParameterExpression jObject, ITypeBase structuralType) - { - JObject = jObject; - StructuralType = structuralType; - } - - public ITypeBase StructuralType { get; } - - public ParameterExpression JObject { get; } - } - - private static TReturn EnterScope(ref TScope scope, TScope newValue, Func action) - { - var oldValue = scope; -#pragma warning disable IDE0059 // Unnecessary assignment of a value - scope = newValue; -#pragma warning restore IDE0059 // Unnecessary assignment of a value - var result = action(); - scope = oldValue; - return result; - } - #endregion - - #region Include protected override Expression VisitExtension(Expression extensionExpression) { switch (extensionExpression) @@ -254,43 +398,41 @@ protected override Expression VisitExtension(Expression extensionExpression) { var projection = GetProjection(projectionBindingExpression); - return CreateGetValueExpression(jTokenParameter, + return CreateGetValueExpression( + jTokenParameter, projection.IsValueProjection ? null : projection.Alias, - typeof(JObject), - false, + projectionBindingExpression.Type, (projection.Expression as SqlExpression)?.TypeMapping); } - case ObjectArrayAccessExpression objectArrayAccessExpression: - { - return CreateGetValueExpression( - _entityTypeBlockJObject ?? throw new InvalidOperationException(), - objectArrayAccessExpression.PropertyName, - objectArrayAccessExpression.Type, - false, - null); - } - - case EntityProjectionExpression entityProjectionExpression: - { - Debug.Assert(_entityTypeBlockJObject != null, "Entity projection can only be inside an entity type block."); - return CreateGetValueExpression( - _entityTypeBlockJObject ?? throw new InvalidOperationException(), - entityProjectionExpression.PropertyName, - entityProjectionExpression.Type, - false, - null); - } - case CollectionShaperExpression collectionShaperExpression: { - Debug.Assert(collectionShaperExpression.Navigation != null); - Debug.Assert(_collectionBlockJArray != null, "Collection shaper can only be inside a collection block."); + ObjectArrayAccessExpression objectArrayAccess; + switch (collectionShaperExpression.Projection) + { + case ProjectionBindingExpression projectionBindingExpression: + var projection = GetProjection(projectionBindingExpression); + objectArrayAccess = (ObjectArrayAccessExpression)projection.Expression; + break; + case ObjectArrayAccessExpression objectArrayProjectionExpression: + objectArrayAccess = objectArrayProjectionExpression; + break; + default: + throw new InvalidOperationException(CoreStrings.TranslationFailed(extensionExpression.Print())); + } + + var jArray = _projectionBindings[objectArrayAccess]; + var jObjectParameter = Parameter(typeof(JObject), jArray.Name + "Object"); + var ordinalParameter = Parameter(typeof(int), jArray.Name + "Ordinal"); - var jObjectParameter = Parameter(typeof(JObject), _collectionBlockJArray.Name + "Object"); - var ordinalParameter = Parameter(typeof(int), _collectionBlockJArray.Name + "Ordinal"); + var accessExpression = objectArrayAccess.InnerProjection.Object; + _projectionBindings[accessExpression] = jObjectParameter; + _ownerMappings[accessExpression] = + (objectArrayAccess.Navigation.DeclaringEntityType, objectArrayAccess.Object); + _ordinalParameterBindings[accessExpression] = Add( + ordinalParameter, Constant(1, typeof(int))); - var innerShaper = EnterScope(ref _entityTypeBlockJObject, jObjectParameter, () => (BlockExpression)Visit(collectionShaperExpression.InnerShaper)); + var innerShaper = (BlockExpression)Visit(collectionShaperExpression.InnerShaper); innerShaper = AddIncludes(innerShaper); @@ -298,7 +440,7 @@ protected override Expression VisitExtension(Expression extensionExpression) EnumerableMethods.SelectWithOrdinal.MakeGenericMethod(typeof(JObject), innerShaper.Type), Call( EnumerableMethods.Cast.MakeGenericMethod(typeof(JObject)), - _collectionBlockJArray), + jArray), Lambda(innerShaper, jObjectParameter, ordinalParameter)); var navigation = collectionShaperExpression.Navigation; @@ -325,27 +467,24 @@ protected override Expression VisitExtension(Expression extensionExpression) if (!isFirstInclude) { - return jObjectBlock!; + return jObjectBlock; } + Check.DebugAssert(jObjectBlock != null, "The first include must end up on a valid shaper block"); - var jObjectParameter = jObjectBlock.Variables.Single(); - return EnterScope(ref _entityTypeBlockJObject, jObjectParameter, () => - { - // These are the expressions added by JObjectInjectingExpressionVisitor - var jObjectCondition = (ConditionalExpression)jObjectBlock.Expressions[^1]; + // These are the expressions added by JObjectInjectingExpressionVisitor + var jObjectCondition = (ConditionalExpression)jObjectBlock.Expressions[^1]; - var shaperBlock = (BlockExpression)jObjectCondition.IfFalse; - shaperBlock = AddIncludes(shaperBlock); + var shaperBlock = (BlockExpression)jObjectCondition.IfFalse; + shaperBlock = AddIncludes(shaperBlock); - var jObjectExpressions = new List(jObjectBlock.Expressions); - jObjectExpressions.RemoveAt(jObjectExpressions.Count - 1); + var jObjectExpressions = new List(jObjectBlock.Expressions); + jObjectExpressions.RemoveAt(jObjectExpressions.Count - 1); - jObjectExpressions.Add( - jObjectCondition.Update(jObjectCondition.Test, jObjectCondition.IfTrue, shaperBlock)); + jObjectExpressions.Add( + jObjectCondition.Update(jObjectCondition.Test, jObjectCondition.IfTrue, shaperBlock)); - return jObjectBlock.Update(jObjectBlock.Variables, jObjectExpressions); - }); + return jObjectBlock.Update(jObjectBlock.Variables, jObjectExpressions); } } @@ -424,7 +563,7 @@ private void AddInclude( private static readonly MethodInfo IncludeReferenceMethodInfo = typeof(CosmosProjectionBindingRemovingExpressionVisitorBase).GetTypeInfo() - .GetDeclaredMethod(nameof(IncludeReference))!; + .GetDeclaredMethod(nameof(IncludeReference)); private static void IncludeReference( #pragma warning disable EF1001 // Internal EF Core API usage. @@ -469,7 +608,7 @@ private static void IncludeReference( private static readonly MethodInfo IncludeCollectionMethodInfo = typeof(CosmosProjectionBindingRemovingExpressionVisitorBase).GetTypeInfo() - .GetDeclaredMethod(nameof(IncludeCollection))!; + .GetDeclaredMethod(nameof(IncludeCollection)); private static void IncludeCollection( #pragma warning disable EF1001 // Internal EF Core API usage. @@ -500,7 +639,7 @@ private static void IncludeCollection( foreach (var relatedEntity in relatedEntities) { fixup(includingEntity, relatedEntity); - inverseNavigation?.SetIsLoadedWhenNoTracking(relatedEntity!); + inverseNavigation?.SetIsLoadedWhenNoTracking(relatedEntity); } } else @@ -535,7 +674,7 @@ private static Delegate GenerateFixup( Type entityType, Type relatedEntityType, INavigation navigation, - INavigation? inverseNavigation) + INavigation inverseNavigation) { var entityParameter = Parameter(entityType); var relatedEntityParameter = Parameter(relatedEntityType); @@ -558,7 +697,7 @@ private static Delegate GenerateFixup( .Compile(); } - private static Delegate? GenerateInitialize( + private static Delegate GenerateInitialize( Type entityType, INavigation navigation) { @@ -598,7 +737,7 @@ private static Expression AddToCollectionNavigation( private static readonly MethodInfo PopulateCollectionMethodInfo = typeof(CosmosProjectionBindingRemovingExpressionVisitorBase).GetTypeInfo() - .GetDeclaredMethod(nameof(PopulateCollection))!; + .GetDeclaredMethod(nameof(PopulateCollection)); private static readonly MethodInfo IsAssignableFromMethodInfo = typeof(IReadOnlyEntityType).GetMethod(nameof(IReadOnlyEntityType.IsAssignableFrom), [typeof(IReadOnlyEntityType)])!; @@ -617,37 +756,151 @@ private static TCollection PopulateCollection( return (TCollection)collection; } - private static readonly MethodInfo CollectionAccessorAddMethodInfo - = typeof(IClrCollectionAccessor).GetTypeInfo() - .GetDeclaredMethod(nameof(IClrCollectionAccessor.Add)) ?? throw new UnreachableException(); - - private static readonly MethodInfo CollectionAccessorGetOrCreateMethodInfo - = typeof(IClrCollectionAccessor).GetTypeInfo() - .GetDeclaredMethod(nameof(IClrCollectionAccessor.GetOrCreate)) ?? throw new UnreachableException(); - #endregion - protected abstract ProjectionExpression GetProjection(ProjectionBindingExpression projectionBindingExpression); - #region Create expression helpers - /// - /// Create expression to get a property's value from JObject - /// - private Expression CreateGetValueExpression(IProperty property, Type? type = null) - { - var currentJObject = _concreteStructuralTypeBlock?.JObject ?? _entityTypeBlockJObject; - Debug.Assert(currentJObject != null, "Property value can only be retrieved inside an structural type block."); + private static Expression CreateReadJTokenExpression(Expression jObjectExpression, string propertyName) + => Call(jObjectExpression, GetItemMethodInfo, Constant(propertyName)); + private Expression CreateGetValueExpression( + Expression jTokenExpression, + IProperty property, + Type type) + { if (property.Name == CosmosPartitionKeyInPrimaryKeyConvention.JObjectPropertyName) { - return currentJObject; + return _projectionBindings[jTokenExpression]; } - return CreateGetValueExpression(currentJObject, property.GetJsonPropertyName(), type ?? property.ClrType, !property.IsNullable && !property.IsKey(), property.GetTypeMapping()); + var entityType = property.DeclaringType as IEntityType; + var ownership = entityType?.FindOwnership(); + var storeName = property.GetJsonPropertyName(); + if (storeName.Length == 0) + { + if (entityType == null + || !entityType.IsDocumentRoot()) + { + if (ownership is { IsUnique: false } && property.IsOrdinalKeyProperty()) + { + var ordinalExpression = _ordinalParameterBindings[jTokenExpression]; + if (ordinalExpression.Type != type) + { + ordinalExpression = Convert(ordinalExpression, type); + } + + return ordinalExpression; + } + + var principalProperty = property.FindFirstPrincipal(); + if (principalProperty != null) + { + Expression ownerJObjectExpression = null; + if (_ownerMappings.TryGetValue(jTokenExpression, out var ownerInfo)) + { + Check.DebugAssert( + principalProperty.DeclaringType.IsAssignableFrom(ownerInfo.EntityType), + $"{principalProperty.DeclaringType} is not assignable from {ownerInfo.EntityType}"); + + ownerJObjectExpression = ownerInfo.JObjectExpression; + } + else if (jTokenExpression is ObjectReferenceExpression objectReferenceExpression) + { + ownerJObjectExpression = objectReferenceExpression; + } + else if (jTokenExpression is ObjectAccessExpression objectAccessExpression) + { + ownerJObjectExpression = objectAccessExpression.Object; + } + + if (ownerJObjectExpression != null) + { + return CreateGetValueExpression(ownerJObjectExpression, principalProperty, type); + } + } + } + + return Default(type); + } + + // Workaround for old databases that didn't store the key property + if (ownership is { IsUnique: false } + && !entityType.IsDocumentRoot() + && property.ClrType == typeof(int) + && !property.IsForeignKey() + && property.FindContainingPrimaryKey() is { Properties.Count: > 1 } + && property.GetJsonPropertyName().Length != 0 + && !property.IsShadowProperty()) + { + var readExpression = CreateGetValueExpression( + jTokenExpression, + storeName, + type.MakeNullable(), + property.GetTypeMapping(), + isNonNullableScalar: false); + + var nonNullReadExpression = readExpression; + if (nonNullReadExpression.Type != type) + { + nonNullReadExpression = Convert(nonNullReadExpression, type); + } + + var ordinalExpression = _ordinalParameterBindings[jTokenExpression]; + if (ordinalExpression.Type != type) + { + ordinalExpression = Convert(ordinalExpression, type); + } + + return Condition( + Equal(readExpression, Constant(null, readExpression.Type)), + ordinalExpression, + nonNullReadExpression); + } + + return Convert( + CreateGetValueExpression( + jTokenExpression, + storeName, + type.MakeNullable(), + property.GetTypeMapping(), + // special case keys - we check them for null to see if the entity needs to be materialized, so we want to keep the null, rather than non-nullable default + // returning defaults is supposed to help with evolving the schema - so this doesn't concern keys anyway (they shouldn't evolve) + isNonNullableScalar: !property.IsNullable && !property.IsKey()), + type); } - private Expression CreateGetValueExpression(ParameterExpression jObject, string? property, Type type, bool isNonNullableScalar, CoreTypeMapping? typeMapping) + private Expression CreateGetValueExpression( + Expression jTokenExpression, + string storeName, + Type type, + CoreTypeMapping typeMapping = null, + bool isNonNullableScalar = false) { - var valueExpression = property != null ? CreateGetJTokenExpression(jObject, property) : jObject; + Check.DebugAssert(type.IsNullableType(), "Must read nullable type from JObject."); + + var innerExpression = jTokenExpression switch + { + _ when _projectionBindings.TryGetValue(jTokenExpression, out var innerVariable) + => innerVariable, + + ObjectReferenceExpression + => jTokenParameter, + + ObjectAccessExpression objectAccessExpression + => CreateGetValueExpression( + objectAccessExpression.Object, + ((IAccessExpression)objectAccessExpression.Object).PropertyName, + typeof(JObject)), + + _ => jTokenExpression + }; + + jTokenExpression = storeName == null + ? innerExpression + : CreateReadJTokenExpression( + innerExpression.Type == typeof(JObject) + ? innerExpression + : Convert(innerExpression, typeof(JObject)), storeName); + + Expression valueExpression; var converter = typeMapping?.Converter; if (converter != null) { @@ -699,12 +952,12 @@ var body replaceExpression, body); - valueExpression = Invoke(Lambda(body, jTokenParameter), valueExpression); + valueExpression = Invoke(Lambda(body, jTokenParameter), jTokenExpression); } else { - valueExpression = CreateSerializeJTokenToTypeExpression( - valueExpression, + valueExpression = ConvertJTokenToType( + jTokenExpression, (isNonNullableScalar ? typeMapping?.ClrType : typeMapping?.ClrType.MakeNullable()) @@ -719,31 +972,14 @@ var body return valueExpression; } - /// - /// Create expression to get the JToken for a property from JObject - /// - private Expression CreateGetJTokenExpression(ParameterExpression jObject, IPropertyBase propertyBase) - => CreateGetJTokenExpression(jObject, propertyBase is IProperty p ? p.GetJsonPropertyName() : propertyBase.Name); - - /// - /// Create expression to get the JToken for a property from JObject - /// - private Expression CreateGetJTokenExpression(ParameterExpression jObject, string propertyName) - => Call(jObject, GetItemMethodInfo, - Constant(propertyName)); - - /// - /// Create expression to serialize JToken to given type - /// - private static Expression CreateSerializeJTokenToTypeExpression(Expression jTokenExpression, Type type) + private static Expression ConvertJTokenToType(Expression jTokenExpression, Type type) => type == typeof(JToken) ? jTokenExpression : Call( ToObjectWithSerializerMethodInfo.MakeGenericMethod(type), jTokenExpression); - private static T? SafeToObjectWithSerializer(JToken? token) + private static T SafeToObjectWithSerializer(JToken token) => token == null || token.Type == JTokenType.Null ? default : token.ToObject(CosmosClientWrapper.Serializer); - #endregion } } From 814365072ec410554f605a41f70899d5691ce060 Mon Sep 17 00:00:00 2001 From: JoasE <32096708+JoasE@users.noreply.github.com> Date: Wed, 10 Dec 2025 11:59:14 +0100 Subject: [PATCH 05/23] Remove todo (works) --- ...tor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs index 634711b3a3d..3bdb0a0f1e5 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs @@ -67,7 +67,6 @@ protected override SwitchCase VisitSwitchCase(SwitchCase switchCaseExpression) if (switchCaseExpression.TestValues.SingleOrDefault() is ConstantExpression constantExpression && constantExpression.Value is ITypeBase structuralType) { - // @TODO: Maybe use visitor? Does this work if tracking is off? var instanceBlock = (BlockExpression)((BlockExpression)switchCaseExpression.Body).Expressions.First(x => x is BlockExpression b && b.Variables.FirstOrDefault()?.Type == structuralType.ClrType); _instanceTypeBaseMappings.Add(instanceBlock.Variables.Single(), structuralType); } @@ -252,7 +251,7 @@ private BlockExpression CreateComplexPropertyAssignmentBlock(MemberExpression me "complexJObject" + ++_currentComplexIndex); var assignComplexJObjectVariableExpression = Assign(complexJObjectVariableExpression, Call( // @TODO: Can we reuse get property value? ToObjectWithSerializerMethodInfo.MakeGenericMethod(typeof(JObject)), - Call(parentJObject, GetItemMethodInfo, // @TODO: Which jobject........ + Call(parentJObject, GetItemMethodInfo, Constant(complexProperty.Name) ) )); From b2d51dd645818e86d88916b7201f008f5ba7e293 Mon Sep 17 00:00:00 2001 From: JoasE <32096708+JoasE@users.noreply.github.com> Date: Wed, 10 Dec 2025 15:41:46 +0100 Subject: [PATCH 06/23] WIP --- ...jectionBindingRemovingExpressionVisitor.cs | 3 +- ...ionBindingRemovingExpressionVisitorBase.cs | 173 +++++++++++++++--- ...sitor.JObjectInjectingExpressionVisitor.cs | 36 ++-- ...osShapedQueryCompilingExpressionVisitor.cs | 1 + 4 files changed, 168 insertions(+), 45 deletions(-) diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitor.cs index e02ac1de670..1586ef2b521 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitor.cs @@ -8,10 +8,11 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal; public partial class CosmosShapedQueryCompilingExpressionVisitor { private sealed class CosmosProjectionBindingRemovingExpressionVisitor( + CosmosShapedQueryCompilingExpressionVisitor parentVisitor, SelectExpression selectExpression, ParameterExpression jTokenParameter, bool trackQueryResults) - : CosmosProjectionBindingRemovingExpressionVisitorBase(jTokenParameter, trackQueryResults) + : CosmosProjectionBindingRemovingExpressionVisitorBase(parentVisitor, jTokenParameter, trackQueryResults) { protected override ProjectionExpression GetProjection(ProjectionBindingExpression projectionBindingExpression) => selectExpression.Projection[GetProjectionIndex(projectionBindingExpression)]; diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs index 3bdb0a0f1e5..34aaaf925cb 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs @@ -3,6 +3,10 @@ #nullable disable +using System; +using System.Diagnostics.CodeAnalysis; +using System.Linq.Expressions; +using System.Text.RegularExpressions; using Microsoft.EntityFrameworkCore.ChangeTracking.Internal; using Microsoft.EntityFrameworkCore.Cosmos.Internal; using Microsoft.EntityFrameworkCore.Cosmos.Metadata.Internal; @@ -10,12 +14,14 @@ using Microsoft.EntityFrameworkCore.Query.Internal; using Newtonsoft.Json.Linq; using static System.Linq.Expressions.Expression; +using static System.Net.Mime.MediaTypeNames; namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal; public partial class CosmosShapedQueryCompilingExpressionVisitor { private abstract class CosmosProjectionBindingRemovingExpressionVisitorBase( + CosmosShapedQueryCompilingExpressionVisitor parentVisitor, ParameterExpression jTokenParameter, bool trackQueryResults) : ExpressionVisitor @@ -50,6 +56,8 @@ private readonly IDictionary _projectionBinding private readonly IDictionary _ownerMappings = new Dictionary(); + private readonly Dictionary _entityInstanceMaterializationContextMappings = new(); + private readonly Dictionary _concreteTypeInstanceMaterializationContextMappings = new(); private readonly Dictionary _instanceTypeBaseMappings = new(); private readonly Dictionary> _materializationContextCompexPropertyJObjectMappings = new(); @@ -62,6 +70,20 @@ private static readonly MethodInfo ToObjectWithSerializerMethodInfo = typeof(CosmosProjectionBindingRemovingExpressionVisitorBase) .GetRuntimeMethods().Single(mi => mi.Name == nameof(SafeToObjectWithSerializer)); + protected override Expression VisitBlock(BlockExpression node) + { + var materializationContextParameter = node.Variables + .SingleOrDefault(v => v.Type == typeof(MaterializationContext)); + + if (materializationContextParameter != null) + { + var instanceParameter = node.Variables.First(x => x.Name != null && Regex.Match(x.Name, @"instance\d+").Success); + _entityInstanceMaterializationContextMappings.Add(instanceParameter, materializationContextParameter); + } + + return base.VisitBlock(node); + } + protected override SwitchCase VisitSwitchCase(SwitchCase switchCaseExpression) { if (switchCaseExpression.TestValues.SingleOrDefault() is ConstantExpression constantExpression @@ -185,6 +207,10 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) { var projection = GetProjection(projectionBindingExpression); entityProjectionExpression = (EntityProjectionExpression)projection.Expression; + } + else if () + { + } else { @@ -201,6 +227,15 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) return MakeBinary(ExpressionType.Assign, binaryExpression.Left, updatedExpression); } + + if (_entityInstanceMaterializationContextMappings.TryGetValue(parameterExpression, out var instanceMaterializationContext) && binaryExpression.Right is SwitchExpression switchExpression) + { + var instances = switchExpression.Cases.Select(x => (ParameterExpression)new SwitchCaseReturnValueExtractorExpressionVisitor().Visit(x.Body)); + foreach (var instance in instances) + { + _concreteTypeInstanceMaterializationContextMappings[instance] = instanceMaterializationContext; + } + } } if (binaryExpression.Left is MemberExpression memberExpression) @@ -210,10 +245,20 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) var complexProperty = structuralType.GetComplexProperties().FirstOrDefault(x => x.GetMemberInfo(true, true) == memberExpression.Member); if (complexProperty != null) { - if (!complexProperty.IsCollection) + var materializationContext = _concreteTypeInstanceMaterializationContextMappings[instanceParameterExpression]; + Expression parentJObject; + if (complexProperty.DeclaringType is IComplexType parentComplexType) { - return CreateComplexPropertyAssignmentBlock(memberExpression, binaryExpression.Right, complexProperty); + parentJObject = _materializationContextCompexPropertyJObjectMappings[materializationContext][parentComplexType.ComplexProperty]; } + else + { + parentJObject = _projectionBindings[_materializationContextBindings[materializationContext]]; + } + + return complexProperty.IsCollection + ? CreateComplexCollectionAssignmentBlock(memberExpression, complexProperty, materializationContext, parentJObject) + : CreateComplexPropertyAssignmentBlock(memberExpression, binaryExpression.Right, complexProperty, materializationContext, parentJObject); } } return memberExpression.Assign(Visit(binaryExpression.Right)); @@ -223,70 +268,140 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) return base.VisitBinary(binaryExpression); } - private BlockExpression CreateComplexPropertyAssignmentBlock(MemberExpression memberExpression, Expression valueExpression, IComplexProperty complexProperty) + private class SwitchCaseReturnValueExtractorExpressionVisitor : ExpressionVisitor { - var complexPropertyMaterializationContext = new ComplexPropertyMaterializationContextExtractorExpressionVisitor().Extract(valueExpression, complexProperty); - if (complexPropertyMaterializationContext == null) + public Expression Extract(Expression expression) { - Debug.Assert(false); - // complex property without properties... - // What if a complex property only has a list in it.... - // We might need to get the materialization context in a different way.. Just create an instance materialization context mapping? - - // Can/should we inject materializers for collections of complex types in InjectMaterializers already? How does it happen in slqserver again? We know it calls it, but when. + return Visit(expression); } - Expression parentJObject; - if (complexProperty.DeclaringType is IComplexType parentComplexType) + [return: NotNullIfNotNull("node")] + public override Expression Visit(Expression node) { - parentJObject = _materializationContextCompexPropertyJObjectMappings[complexPropertyMaterializationContext][parentComplexType.ComplexProperty]; + if (node is not BlockExpression) + { + return node; + } + + return base.Visit(node); } - else + + protected override Expression VisitBlock(BlockExpression blockExpression) + => Visit(blockExpression.Expressions.Last()); + } + + private BlockExpression CreateComplexCollectionAssignmentBlock(MemberExpression memberExpression, IComplexProperty complexProperty, ParameterExpression materializationContext, Expression parentJObject) + { + var complexJArrayVariable = Variable( + typeof(JArray), + "complexJArray" + ++_currentComplexIndex); + + var assignJArrayVariable = Assign(complexJArrayVariable, + Call( + ToObjectWithSerializerMethodInfo.MakeGenericMethod(typeof(JArray)), + Call(parentJObject, GetItemMethodInfo, + Constant(complexProperty.Name) + ) + ) + ); + + var tempValueBuffer = new ProjectionBindingExpression(); // VAR1 + var structuralTypeShaperExpression = new StructuralTypeShaperExpression( + complexProperty.ComplexType, + tempValueBuffer, + complexProperty.ClrType.IsNullableType()); // @TODO: Can collection items be null? + // inject Jobject..? + var jObjectParameter = Parameter(typeof(JObject), "complexArrayItem" + _currentComplexIndex); + //var assingment = JObjectInjectingExpressionVisitor.AssignJObject(structuralTypeShaperExpression, jObjectParameter); + + //var injectedJobject = Block( + + //); + + var rawMaterializeExpression = parentVisitor.InjectStructuralTypeMaterializers(structuralTypeShaperExpression); // @TODO: We could also use entityMaterializerSource directly here.. + + var oldJTokenParametr = jTokenParameter; + jTokenParameter = jObjectParameter; + var materializeExpression = Visit(rawMaterializeExpression); + jTokenParameter = oldJTokenParametr; + + // We need to inject a jObject first.. + + var select = Call( + EnumerableMethods.Select.MakeGenericMethod(typeof(JObject), complexProperty.ComplexType.ClrType), + Call( + EnumerableMethods.Cast.MakeGenericMethod(typeof(JObject)), + complexJArrayVariable), + Lambda(materializeExpression, jObjectParameter)); + + Expression populateExpression = + Call( + PopulateCollectionMethodInfo.MakeGenericMethod(complexProperty.ComplexType.ClrType, complexProperty.ClrType), + Constant(complexProperty.GetCollectionAccessor()), + select + ); + + if (complexProperty.IsNullable) { - parentJObject = _projectionBindings[_materializationContextBindings[complexPropertyMaterializationContext]]; + populateExpression = Condition(Equal(complexJArrayVariable, Constant(null)), + Default(complexProperty.ClrType), + populateExpression + ); } - var complexJObjectVariableExpression = Variable( + + return Block( + [complexJArrayVariable], + [ + assignJArrayVariable, + memberExpression.Assign(populateExpression) + ] + ); + } + + private BlockExpression CreateComplexPropertyAssignmentBlock(MemberExpression memberExpression, Expression materializationExpression, IComplexProperty complexProperty, ParameterExpression materializationContext, Expression parentJObject) + { + var complexJObjectVariable = Variable( typeof(JObject), "complexJObject" + ++_currentComplexIndex); - var assignComplexJObjectVariableExpression = Assign(complexJObjectVariableExpression, Call( // @TODO: Can we reuse get property value? + var assignComplexJObjectVariable = Assign(complexJObjectVariable, Call( // @TODO: Can we reuse get property value? ToObjectWithSerializerMethodInfo.MakeGenericMethod(typeof(JObject)), Call(parentJObject, GetItemMethodInfo, Constant(complexProperty.Name) ) )); - if (!_materializationContextCompexPropertyJObjectMappings.TryGetValue(complexPropertyMaterializationContext, out var complexPropertyJObjectMappings)) + if (!_materializationContextCompexPropertyJObjectMappings.TryGetValue(materializationContext, out var complexPropertyJObjectMappings)) { complexPropertyJObjectMappings = new(); - _materializationContextCompexPropertyJObjectMappings[complexPropertyMaterializationContext] = complexPropertyJObjectMappings; + _materializationContextCompexPropertyJObjectMappings[materializationContext] = complexPropertyJObjectMappings; } - complexPropertyJObjectMappings.Add(complexProperty, complexJObjectVariableExpression); + complexPropertyJObjectMappings.Add(complexProperty, complexJObjectVariable); BlockExpression materializationBlock; - if (valueExpression is ConditionalExpression condition) + if (materializationExpression is ConditionalExpression condition) { materializationBlock = (BlockExpression)((UnaryExpression)condition.IfFalse).Operand; - valueExpression = Condition( - Equal(complexJObjectVariableExpression, Constant(null)), + materializationExpression = Condition( + Equal(complexJObjectVariable, Constant(null)), condition.IfTrue, materializationBlock); } else { - materializationBlock = (BlockExpression)valueExpression; + materializationBlock = (BlockExpression)materializationExpression; } var instanceParameter = materializationBlock.Variables.First(v => v.Type == complexProperty.ComplexType.ClrType); _instanceTypeBaseMappings.Add(instanceParameter, complexProperty.ComplexType); - valueExpression = Visit(valueExpression); + materializationExpression = Visit(materializationExpression); return Block( - [complexJObjectVariableExpression], - assignComplexJObjectVariableExpression, - memberExpression.Assign(valueExpression) + [complexJObjectVariable], + assignComplexJObjectVariable, + memberExpression.Assign(materializationExpression) ); } diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.JObjectInjectingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.JObjectInjectingExpressionVisitor.cs index 3bd4365acb5..c06f2caf42b 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.JObjectInjectingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.JObjectInjectingExpressionVisitor.cs @@ -15,22 +15,11 @@ private sealed class JObjectInjectingExpressionVisitor : ExpressionVisitor { private int _currentEntityIndex; - protected override Expression VisitExtension(Expression extensionExpression) + public static List AssignJObject(StructuralTypeShaperExpression shaperExpression, ParameterExpression jObjectVariable) { - switch (extensionExpression) - { - case StructuralTypeShaperExpression shaperExpression: - { - _currentEntityIndex++; - - var valueBufferExpression = shaperExpression.ValueBufferExpression; + var valueBufferExpression = shaperExpression.ValueBufferExpression; - var jObjectVariable = Variable( - typeof(JObject), - "jObject" + _currentEntityIndex); - var variables = new List { jObjectVariable }; - - var expressions = new List + var expressions = new List { Assign( jObjectVariable, @@ -43,9 +32,26 @@ protected override Expression VisitExtension(Expression extensionExpression) shaperExpression) }; + return expressions; + } + + protected override Expression VisitExtension(Expression extensionExpression) + { + switch (extensionExpression) + { + case StructuralTypeShaperExpression shaperExpression: + { + _currentEntityIndex++; + + var jObjectVariable = Variable( + typeof(JObject), + "jObject" + _currentEntityIndex); + + var expressions = AssignJObject(shaperExpression, jObjectVariable); + return Block( shaperExpression.Type, - variables, + [jObjectVariable], expressions); } diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs index f79539eb0b7..072401c9005 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs @@ -67,6 +67,7 @@ protected override Expression VisitShapedQuery(ShapedQueryExpression shapedQuery } shaperBody = new CosmosProjectionBindingRemovingExpressionVisitor( + this, selectExpression, jTokenParameter, QueryCompilationContext.QueryTrackingBehavior == QueryTrackingBehavior.TrackAll) .Visit(shaperBody); From 1a116bc847e917bc114b94fb86cde4f3d5d6b86e Mon Sep 17 00:00:00 2001 From: JoasE <32096708+JoasE@users.noreply.github.com> Date: Wed, 10 Dec 2025 16:53:08 +0100 Subject: [PATCH 07/23] WIP: Firsts things working --- ...ionBindingRemovingExpressionVisitorBase.cs | 88 ++++++++++++------- 1 file changed, 55 insertions(+), 33 deletions(-) diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs index 34aaaf925cb..78d28957423 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs @@ -7,11 +7,14 @@ using System.Diagnostics.CodeAnalysis; using System.Linq.Expressions; using System.Text.RegularExpressions; +using Microsoft.Azure.Cosmos.Linq; using Microsoft.EntityFrameworkCore.ChangeTracking.Internal; using Microsoft.EntityFrameworkCore.Cosmos.Internal; using Microsoft.EntityFrameworkCore.Cosmos.Metadata.Internal; using Microsoft.EntityFrameworkCore.Cosmos.Storage.Internal; +using Microsoft.EntityFrameworkCore.Metadata.Internal; using Microsoft.EntityFrameworkCore.Query.Internal; +using Microsoft.EntityFrameworkCore.Storage; using Newtonsoft.Json.Linq; using static System.Linq.Expressions.Expression; using static System.Net.Mime.MediaTypeNames; @@ -89,7 +92,9 @@ protected override SwitchCase VisitSwitchCase(SwitchCase switchCaseExpression) if (switchCaseExpression.TestValues.SingleOrDefault() is ConstantExpression constantExpression && constantExpression.Value is ITypeBase structuralType) { - var instanceBlock = (BlockExpression)((BlockExpression)switchCaseExpression.Body).Expressions.First(x => x is BlockExpression b && b.Variables.FirstOrDefault()?.Type == structuralType.ClrType); + var instanceBlock = ((BlockExpression)switchCaseExpression.Body).Expressions + .Select(x => x as BlockExpression ?? ((x as ConditionalExpression)?.IfFalse as UnaryExpression)?.Operand as BlockExpression) + .First(x => x?.Variables.FirstOrDefault()?.Type == structuralType.ClrType); _instanceTypeBaseMappings.Add(instanceBlock.Variables.Single(), structuralType); } @@ -202,23 +207,29 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) { var newExpression = (NewExpression)binaryExpression.Right; - EntityProjectionExpression entityProjectionExpression; - if (newExpression.Arguments[0] is ProjectionBindingExpression projectionBindingExpression) + if (newExpression.Arguments[0] is ComplexPropertyValueBufferExpression temp) { - var projection = GetProjection(projectionBindingExpression); - entityProjectionExpression = (EntityProjectionExpression)projection.Expression; - } - else if () - { - + _materializationContextBindings[parameterExpression] = temp; + _projectionBindings[temp] = jTokenParameter; + Debug.Assert(!_materializationContextCompexPropertyJObjectMappings.ContainsKey(parameterExpression), "Should never overwrite"); + _materializationContextCompexPropertyJObjectMappings[parameterExpression] = new() { { temp.ComplexProperty, jTokenParameter } }; } else { - var projection = ((UnaryExpression)((UnaryExpression)newExpression.Arguments[0]).Operand).Operand; - entityProjectionExpression = (EntityProjectionExpression)projection; - } + EntityProjectionExpression entityProjectionExpression; + if (newExpression.Arguments[0] is ProjectionBindingExpression projectionBindingExpression) + { + var projection = GetProjection(projectionBindingExpression); + entityProjectionExpression = (EntityProjectionExpression)projection.Expression; + } + else + { + var projection = ((UnaryExpression)((UnaryExpression)newExpression.Arguments[0]).Operand).Operand; + entityProjectionExpression = (EntityProjectionExpression)projection; + } - _materializationContextBindings[parameterExpression] = entityProjectionExpression.Object; + _materializationContextBindings[parameterExpression] = entityProjectionExpression.Object; + } var updatedExpression = New( newExpression.Constructor, @@ -230,9 +241,10 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) if (_entityInstanceMaterializationContextMappings.TryGetValue(parameterExpression, out var instanceMaterializationContext) && binaryExpression.Right is SwitchExpression switchExpression) { - var instances = switchExpression.Cases.Select(x => (ParameterExpression)new SwitchCaseReturnValueExtractorExpressionVisitor().Visit(x.Body)); + var instances = switchExpression.Cases.Select(x => (ParameterExpression)new SwitchCaseReturnValueExtractorExpressionVisitor(parameterExpression.Type).Extract(x.Body)); foreach (var instance in instances) { + Debug.Assert(instance != null); _concreteTypeInstanceMaterializationContextMappings[instance] = instanceMaterializationContext; } } @@ -268,26 +280,41 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) return base.VisitBinary(binaryExpression); } - private class SwitchCaseReturnValueExtractorExpressionVisitor : ExpressionVisitor + private class SwitchCaseReturnValueExtractorExpressionVisitor(Type clrType) : ExpressionVisitor { + private ParameterExpression _result; + public Expression Extract(Expression expression) { - return Visit(expression); + _result = null; + Visit(expression); + return _result; } - [return: NotNullIfNotNull("node")] - public override Expression Visit(Expression node) + protected override Expression VisitBinary(BinaryExpression node) { - if (node is not BlockExpression) + if (node.NodeType == ExpressionType.Assign && node.Type.IsAssignableTo(clrType)) { + _result = (ParameterExpression)node.Left; return node; } - return base.Visit(node); + return base.VisitBinary(node); } + } + + private class ComplexPropertyValueBufferExpression : Expression + { + public ComplexPropertyValueBufferExpression(IComplexProperty complexProperty) + { + ComplexProperty = complexProperty; + } + + public override Type Type => typeof(ValueBuffer); - protected override Expression VisitBlock(BlockExpression blockExpression) - => Visit(blockExpression.Expressions.Last()); + public override ExpressionType NodeType => ExpressionType.Extension; + + public IComplexProperty ComplexProperty { get; } } private BlockExpression CreateComplexCollectionAssignmentBlock(MemberExpression memberExpression, IComplexProperty complexProperty, ParameterExpression materializationContext, Expression parentJObject) @@ -305,21 +332,15 @@ private BlockExpression CreateComplexCollectionAssignmentBlock(MemberExpression ) ); - var tempValueBuffer = new ProjectionBindingExpression(); // VAR1 + var tempValueBuffer = new ComplexPropertyValueBufferExpression(complexProperty); var structuralTypeShaperExpression = new StructuralTypeShaperExpression( complexProperty.ComplexType, tempValueBuffer, complexProperty.ClrType.IsNullableType()); // @TODO: Can collection items be null? - // inject Jobject..? - var jObjectParameter = Parameter(typeof(JObject), "complexArrayItem" + _currentComplexIndex); - //var assingment = JObjectInjectingExpressionVisitor.AssignJObject(structuralTypeShaperExpression, jObjectParameter); - - //var injectedJobject = Block( - - //); - + var rawMaterializeExpression = parentVisitor.InjectStructuralTypeMaterializers(structuralTypeShaperExpression); // @TODO: We could also use entityMaterializerSource directly here.. + var jObjectParameter = Parameter(typeof(JObject), "complexArrayItem" + _currentComplexIndex); var oldJTokenParametr = jTokenParameter; jTokenParameter = jObjectParameter; var materializeExpression = Visit(rawMaterializeExpression); @@ -382,11 +403,11 @@ private BlockExpression CreateComplexPropertyAssignmentBlock(MemberExpression me BlockExpression materializationBlock; if (materializationExpression is ConditionalExpression condition) { - materializationBlock = (BlockExpression)((UnaryExpression)condition.IfFalse).Operand; + materializationBlock = (condition.IfFalse as BlockExpression ?? (BlockExpression)(condition.IfFalse as UnaryExpression).Operand); materializationExpression = Condition( Equal(complexJObjectVariable, Constant(null)), condition.IfTrue, - materializationBlock); + Convert(materializationBlock, condition.Type)); } else { @@ -395,6 +416,7 @@ private BlockExpression CreateComplexPropertyAssignmentBlock(MemberExpression me var instanceParameter = materializationBlock.Variables.First(v => v.Type == complexProperty.ComplexType.ClrType); _instanceTypeBaseMappings.Add(instanceParameter, complexProperty.ComplexType); + _concreteTypeInstanceMaterializationContextMappings.Add(instanceParameter, materializationContext); materializationExpression = Visit(materializationExpression); From e5dcfe1f07f160da4fee628bed73d734cf44fac2 Mon Sep 17 00:00:00 2001 From: JoasE <32096708+JoasE@users.noreply.github.com> Date: Wed, 10 Dec 2025 19:27:42 +0100 Subject: [PATCH 08/23] Add skip to projection based tests --- .../ComplexPropertiesProjectionCosmosTest.cs | 177 +++++++++++------- 1 file changed, 105 insertions(+), 72 deletions(-) diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/Associations/ComplexProperties/ComplexPropertiesProjectionCosmosTest.cs b/test/EFCore.Cosmos.FunctionalTests/Query/Associations/ComplexProperties/ComplexPropertiesProjectionCosmosTest.cs index fc9d8af5e10..3184278f2c3 100644 --- a/test/EFCore.Cosmos.FunctionalTests/Query/Associations/ComplexProperties/ComplexPropertiesProjectionCosmosTest.cs +++ b/test/EFCore.Cosmos.FunctionalTests/Query/Associations/ComplexProperties/ComplexPropertiesProjectionCosmosTest.cs @@ -1,6 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using Xunit.Sdk; + namespace Microsoft.EntityFrameworkCore.Query.Associations.ComplexProperties; public class ComplexPropertiesProjectionCosmosTest : ComplexPropertiesProjectionTestBase @@ -51,6 +53,7 @@ FROM root c """); } + [ConditionalTheory(Skip = "TODO: Query projection")] public override async Task Select_value_type_property_on_null_associate_throws(QueryTrackingBehavior queryTrackingBehavior) { // When OptionalAssociate is null, the property access on it evaluates to undefined in Cosmos, causing the @@ -67,6 +70,7 @@ FROM root c """); } + [ConditionalTheory(Skip = "TODO: Query projection")] public override async Task Select_nullable_value_type_property_on_null_associate(QueryTrackingBehavior queryTrackingBehavior) { // When OptionalAssociate is null, the property access on it evaluates to undefined in Cosmos, causing the @@ -91,86 +95,95 @@ public override async Task Select_associate(QueryTrackingBehavior queryTrackingB { await base.Select_associate(queryTrackingBehavior); - if (queryTrackingBehavior is not QueryTrackingBehavior.TrackAll) + if (queryTrackingBehavior is QueryTrackingBehavior.TrackAll) { - AssertSql( - """ + throw SkipException.ForSkip("Complex type tracking not supported."); + } + + AssertSql( + """ SELECT VALUE c FROM root c """); - } } public override async Task Select_optional_associate(QueryTrackingBehavior queryTrackingBehavior) { await base.Select_optional_associate(queryTrackingBehavior); - if (queryTrackingBehavior is not QueryTrackingBehavior.TrackAll) + if (queryTrackingBehavior is QueryTrackingBehavior.TrackAll) { - AssertSql( - """ + throw SkipException.ForSkip("Complex type tracking not supported."); + } + + AssertSql( + """ SELECT VALUE c FROM root c """); - } } - public override async Task Select_required_nested_on_required_associate(QueryTrackingBehavior queryTrackingBehavior) { await base.Select_required_nested_on_required_associate(queryTrackingBehavior); - if (queryTrackingBehavior is not QueryTrackingBehavior.TrackAll) + if (queryTrackingBehavior is QueryTrackingBehavior.TrackAll) { - AssertSql( - """ + throw SkipException.ForSkip("Complex type tracking not supported."); + } + + AssertSql( + """ SELECT VALUE c FROM root c """); - } } - public override async Task Select_optional_nested_on_required_associate(QueryTrackingBehavior queryTrackingBehavior) { await base.Select_optional_nested_on_required_associate(queryTrackingBehavior); - if (queryTrackingBehavior is not QueryTrackingBehavior.TrackAll) + if (queryTrackingBehavior is QueryTrackingBehavior.TrackAll) { - AssertSql( - """ + throw SkipException.ForSkip("Complex type tracking not supported."); + } + + AssertSql( + """ SELECT VALUE c FROM root c """); - } } public override async Task Select_required_nested_on_optional_associate(QueryTrackingBehavior queryTrackingBehavior) { - if (queryTrackingBehavior is not QueryTrackingBehavior.TrackAll) + if (queryTrackingBehavior is QueryTrackingBehavior.TrackAll) { - await base.Select_required_nested_on_optional_associate(queryTrackingBehavior); + throw SkipException.ForSkip("Complex type tracking not supported."); + } - AssertSql( - """ + await base.Select_required_nested_on_optional_associate(queryTrackingBehavior); + + AssertSql( + """ SELECT VALUE c FROM root c """); - } } - public override async Task Select_optional_nested_on_optional_associate(QueryTrackingBehavior queryTrackingBehavior) { - if (queryTrackingBehavior is not QueryTrackingBehavior.TrackAll) + if (queryTrackingBehavior is QueryTrackingBehavior.TrackAll) { - await base.Select_optional_nested_on_optional_associate(queryTrackingBehavior); + throw SkipException.ForSkip("Complex type tracking not supported."); + } - if (queryTrackingBehavior is not QueryTrackingBehavior.TrackAll) - { - AssertSql( - """ + await base.Select_optional_nested_on_optional_associate(queryTrackingBehavior); + + if (queryTrackingBehavior is not QueryTrackingBehavior.TrackAll) + { + AssertSql( + """ SELECT VALUE c FROM root c """); - } } } @@ -182,16 +195,19 @@ public override async Task Select_unmapped_associate_scalar_property(QueryTracki { await base.Select_unmapped_associate_scalar_property(queryTrackingBehavior); - if (queryTrackingBehavior is not QueryTrackingBehavior.TrackAll) + if (queryTrackingBehavior is QueryTrackingBehavior.TrackAll) { - AssertSql( - """ + throw SkipException.ForSkip("Complex type tracking not supported."); + } + + AssertSql( + """ SELECT VALUE c FROM root c """); - } } + [ConditionalTheory(Skip = "TODO: Query projection")] public override async Task Select_untranslatable_method_on_associate_scalar_property(QueryTrackingBehavior queryTrackingBehavior) { await base.Select_untranslatable_method_on_associate_scalar_property(queryTrackingBehavior); @@ -211,91 +227,105 @@ public override async Task Select_associate_collection(QueryTrackingBehavior que { await base.Select_associate_collection(queryTrackingBehavior); - if (queryTrackingBehavior is not QueryTrackingBehavior.TrackAll) + if (queryTrackingBehavior is QueryTrackingBehavior.TrackAll) { - AssertSql( - """ + throw SkipException.ForSkip("Complex type tracking not supported."); + } + + AssertSql( + """ SELECT VALUE c FROM root c ORDER BY c["Id"] """); - } } public override async Task Select_nested_collection_on_required_associate(QueryTrackingBehavior queryTrackingBehavior) { - if (queryTrackingBehavior is not QueryTrackingBehavior.TrackAll) + if (queryTrackingBehavior is QueryTrackingBehavior.TrackAll) { - await base.Select_nested_collection_on_required_associate(queryTrackingBehavior); + throw SkipException.ForSkip("Complex type tracking not supported."); + } - AssertSql( - """ + await base.Select_nested_collection_on_required_associate(queryTrackingBehavior); + + AssertSql( + """ SELECT VALUE c FROM root c ORDER BY c["Id"] """); - } } public override async Task Select_nested_collection_on_optional_associate(QueryTrackingBehavior queryTrackingBehavior) { - if (queryTrackingBehavior is not QueryTrackingBehavior.TrackAll) + if (queryTrackingBehavior is QueryTrackingBehavior.TrackAll) { - await base.Select_nested_collection_on_optional_associate(queryTrackingBehavior); + throw SkipException.ForSkip("Complex type tracking not supported."); + } - AssertSql( - """ + await base.Select_nested_collection_on_optional_associate(queryTrackingBehavior); + + AssertSql( + """ SELECT VALUE c FROM root c ORDER BY c["Id"] """); - } } + [ConditionalTheory(Skip = "TODO: Query projection")] public override async Task SelectMany_associate_collection(QueryTrackingBehavior queryTrackingBehavior) { - if (queryTrackingBehavior is not QueryTrackingBehavior.TrackAll) + if (queryTrackingBehavior is QueryTrackingBehavior.TrackAll) { - await base.SelectMany_associate_collection(queryTrackingBehavior); + throw SkipException.ForSkip("Complex type tracking not supported."); + } - AssertSql( - """ + await base.SelectMany_associate_collection(queryTrackingBehavior); + + AssertSql( + """ SELECT VALUE a FROM root c JOIN a IN c["AssociateCollection"] """); - } } + [ConditionalTheory(Skip = "TODO: Query projection")] public override async Task SelectMany_nested_collection_on_required_associate(QueryTrackingBehavior queryTrackingBehavior) { - if (queryTrackingBehavior is not QueryTrackingBehavior.TrackAll) + if (queryTrackingBehavior is QueryTrackingBehavior.TrackAll) { - await base.SelectMany_nested_collection_on_required_associate(queryTrackingBehavior); + throw SkipException.ForSkip("Complex type tracking not supported."); + } - AssertSql( - """ + await base.SelectMany_nested_collection_on_required_associate(queryTrackingBehavior); + + AssertSql( + """ SELECT VALUE n FROM root c JOIN n IN c["RequiredAssociate"]["NestedCollection"] """); - } } + [ConditionalTheory(Skip = "TODO: Query projection")] public override async Task SelectMany_nested_collection_on_optional_associate(QueryTrackingBehavior queryTrackingBehavior) { - if (queryTrackingBehavior is not QueryTrackingBehavior.TrackAll) + if (queryTrackingBehavior is QueryTrackingBehavior.TrackAll) { - // The given key 'n' was not present in the dictionary - await base.SelectMany_nested_collection_on_optional_associate(queryTrackingBehavior); + throw SkipException.ForSkip("Complex type tracking not supported."); + } - AssertSql( - """ + await base.SelectMany_nested_collection_on_optional_associate(queryTrackingBehavior); + + AssertSql( + """ SELECT VALUE n FROM root c JOIN n IN c["OptionalAssociate"]["NestedCollection"] """); - } } #endregion Structural collection properties @@ -319,24 +349,27 @@ FROM root c public override async Task Select_subquery_required_related_FirstOrDefault(QueryTrackingBehavior queryTrackingBehavior) { - if (queryTrackingBehavior is not QueryTrackingBehavior.TrackAll) + if (queryTrackingBehavior is QueryTrackingBehavior.TrackAll) { - await AssertTranslationFailed(() => base.Select_subquery_required_related_FirstOrDefault(queryTrackingBehavior)); + throw SkipException.ForSkip("Complex type tracking not supported."); } + + await AssertTranslationFailed(() => base.Select_subquery_required_related_FirstOrDefault(queryTrackingBehavior)); } public override async Task Select_subquery_optional_related_FirstOrDefault(QueryTrackingBehavior queryTrackingBehavior) { - if (queryTrackingBehavior is not QueryTrackingBehavior.TrackAll) + if (queryTrackingBehavior is QueryTrackingBehavior.TrackAll) { - await AssertTranslationFailed(() => base.Select_subquery_required_related_FirstOrDefault(queryTrackingBehavior)); + throw SkipException.ForSkip("Complex type tracking not supported."); } - } - #endregion Subquery + await AssertTranslationFailed(() => base.Select_subquery_required_related_FirstOrDefault(queryTrackingBehavior)); + } - #region Value types +#endregion Subquery +#region Value types public override async Task Select_root_with_value_types(QueryTrackingBehavior queryTrackingBehavior) { await base.Select_root_with_value_types(queryTrackingBehavior); From 81428b1d2b32bd3949b837338f773e797f4dd82b Mon Sep 17 00:00:00 2001 From: JoasE <32096708+JoasE@users.noreply.github.com> Date: Wed, 10 Dec 2025 20:05:24 +0100 Subject: [PATCH 09/23] Fix constructor binding cases --- ...ionBindingRemovingExpressionVisitorBase.cs | 38 ++++++++++++------- .../ComplexPropertiesProjectionCosmosTest.cs | 2 + 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs index 78d28957423..90dc2717287 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs @@ -3,21 +3,14 @@ #nullable disable -using System; -using System.Diagnostics.CodeAnalysis; -using System.Linq.Expressions; using System.Text.RegularExpressions; -using Microsoft.Azure.Cosmos.Linq; using Microsoft.EntityFrameworkCore.ChangeTracking.Internal; using Microsoft.EntityFrameworkCore.Cosmos.Internal; using Microsoft.EntityFrameworkCore.Cosmos.Metadata.Internal; using Microsoft.EntityFrameworkCore.Cosmos.Storage.Internal; -using Microsoft.EntityFrameworkCore.Metadata.Internal; using Microsoft.EntityFrameworkCore.Query.Internal; -using Microsoft.EntityFrameworkCore.Storage; using Newtonsoft.Json.Linq; using static System.Linq.Expressions.Expression; -using static System.Net.Mime.MediaTypeNames; namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal; @@ -92,10 +85,30 @@ protected override SwitchCase VisitSwitchCase(SwitchCase switchCaseExpression) if (switchCaseExpression.TestValues.SingleOrDefault() is ConstantExpression constantExpression && constantExpression.Value is ITypeBase structuralType) { - var instanceBlock = ((BlockExpression)switchCaseExpression.Body).Expressions - .Select(x => x as BlockExpression ?? ((x as ConditionalExpression)?.IfFalse as UnaryExpression)?.Operand as BlockExpression) - .First(x => x?.Variables.FirstOrDefault()?.Type == structuralType.ClrType); - _instanceTypeBaseMappings.Add(instanceBlock.Variables.Single(), structuralType); + var instanceVariable = ((BlockExpression)switchCaseExpression.Body).Expressions + .Select(node => + { + if (node is UnaryExpression unaryExpression + && unaryExpression.NodeType == ExpressionType.Convert) + { + node = unaryExpression.Operand; + } + + if (node is ConditionalExpression conditionalExpression) + { + node = conditionalExpression.IfFalse; + } + + return node as BlockExpression; + }) + .Select(x => x?.Variables.FirstOrDefault(x => x.Type == structuralType.ClrType)) + .FirstOrDefault(x => x != null); + + // Can be null for constructor bindings, but since complex properties aren't supported, we don't need to set the _instanceTypeBaseMappings + if (instanceVariable != null) + { + _instanceTypeBaseMappings.Add(instanceVariable, structuralType); + } } return base.VisitSwitchCase(switchCaseExpression); @@ -241,10 +254,9 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) if (_entityInstanceMaterializationContextMappings.TryGetValue(parameterExpression, out var instanceMaterializationContext) && binaryExpression.Right is SwitchExpression switchExpression) { - var instances = switchExpression.Cases.Select(x => (ParameterExpression)new SwitchCaseReturnValueExtractorExpressionVisitor(parameterExpression.Type).Extract(x.Body)); + var instances = switchExpression.Cases.Select(x => new SwitchCaseReturnValueExtractorExpressionVisitor(parameterExpression.Type).Extract(x.Body) as ParameterExpression).Where(x => x != null);// Null for constructor binding, but since complex properties aren't supported, we don't need to set the _concreteTypeInstanceMaterializationContextMapping foreach (var instance in instances) { - Debug.Assert(instance != null); _concreteTypeInstanceMaterializationContextMappings[instance] = instanceMaterializationContext; } } diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/Associations/ComplexProperties/ComplexPropertiesProjectionCosmosTest.cs b/test/EFCore.Cosmos.FunctionalTests/Query/Associations/ComplexProperties/ComplexPropertiesProjectionCosmosTest.cs index 3184278f2c3..6f50a954831 100644 --- a/test/EFCore.Cosmos.FunctionalTests/Query/Associations/ComplexProperties/ComplexPropertiesProjectionCosmosTest.cs +++ b/test/EFCore.Cosmos.FunctionalTests/Query/Associations/ComplexProperties/ComplexPropertiesProjectionCosmosTest.cs @@ -26,6 +26,7 @@ FROM root c #region Scalar properties + [ConditionalTheory(Skip = "TODO: Query projection")] public override async Task Select_scalar_property_on_required_associate(QueryTrackingBehavior queryTrackingBehavior) { await base.Select_scalar_property_on_required_associate(queryTrackingBehavior); @@ -37,6 +38,7 @@ FROM root c """); } + [ConditionalTheory(Skip = "TODO: Query projection")] public override async Task Select_property_on_optional_associate(QueryTrackingBehavior queryTrackingBehavior) { // When OptionalAssociate is null, the property access on it evaluates to undefined in Cosmos, causing the From 0e45c8645e92946f4944a9d441abf3fc3fa7e6f7 Mon Sep 17 00:00:00 2001 From: JoasE <32096708+JoasE@users.noreply.github.com> Date: Wed, 10 Dec 2025 20:50:48 +0100 Subject: [PATCH 10/23] Cleanup --- ...ionBindingRemovingExpressionVisitorBase.cs | 121 ++++++------------ 1 file changed, 42 insertions(+), 79 deletions(-) diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs index 90dc2717287..d9dbaaa83fc 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs @@ -97,6 +97,11 @@ protected override SwitchCase VisitSwitchCase(SwitchCase switchCaseExpression) if (node is ConditionalExpression conditionalExpression) { node = conditionalExpression.IfFalse; + if (node is UnaryExpression unaryExpression2 + && unaryExpression2.NodeType == ExpressionType.Convert) + { + node = unaryExpression2.Operand; + } } return node as BlockExpression; @@ -292,43 +297,6 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) return base.VisitBinary(binaryExpression); } - private class SwitchCaseReturnValueExtractorExpressionVisitor(Type clrType) : ExpressionVisitor - { - private ParameterExpression _result; - - public Expression Extract(Expression expression) - { - _result = null; - Visit(expression); - return _result; - } - - protected override Expression VisitBinary(BinaryExpression node) - { - if (node.NodeType == ExpressionType.Assign && node.Type.IsAssignableTo(clrType)) - { - _result = (ParameterExpression)node.Left; - return node; - } - - return base.VisitBinary(node); - } - } - - private class ComplexPropertyValueBufferExpression : Expression - { - public ComplexPropertyValueBufferExpression(IComplexProperty complexProperty) - { - ComplexProperty = complexProperty; - } - - public override Type Type => typeof(ValueBuffer); - - public override ExpressionType NodeType => ExpressionType.Extension; - - public IComplexProperty ComplexProperty { get; } - } - private BlockExpression CreateComplexCollectionAssignmentBlock(MemberExpression memberExpression, IComplexProperty complexProperty, ParameterExpression materializationContext, Expression parentJObject) { var complexJArrayVariable = Variable( @@ -358,8 +326,6 @@ private BlockExpression CreateComplexCollectionAssignmentBlock(MemberExpression var materializeExpression = Visit(rawMaterializeExpression); jTokenParameter = oldJTokenParametr; - // We need to inject a jObject first.. - var select = Call( EnumerableMethods.Select.MakeGenericMethod(typeof(JObject), complexProperty.ComplexType.ClrType), Call( @@ -439,46 +405,6 @@ private BlockExpression CreateComplexPropertyAssignmentBlock(MemberExpression me ); } - private class ComplexPropertyMaterializationContextExtractorExpressionVisitor : ExpressionVisitor - { - private IComplexProperty _complexProperty; - private ParameterExpression _materializationContext; - public ParameterExpression Extract(Expression expression, IComplexProperty complexProperty) - { - _complexProperty = complexProperty; - _materializationContext = null; - Visit(expression); - return _materializationContext; - } - - protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) - { - var method = methodCallExpression.Method; - var genericMethod = method.IsGenericMethod ? method.GetGenericMethodDefinition() : null; - if (genericMethod == EntityFrameworkCore.Infrastructure.ExpressionExtensions.ValueBufferTryReadValueMethod) - { - var property = methodCallExpression.Arguments[2].GetConstantValue(); - - var declaringType = property.DeclaringType; - while (declaringType is IComplexType c) - { - if (c.ComplexProperty == _complexProperty) - { - var param = (methodCallExpression.Arguments[0] as MethodCallExpression)?.Object as ParameterExpression; - if (param.Type == typeof(MaterializationContext)) - { - _materializationContext = param; - return methodCallExpression; - } - } - - declaringType = c.ComplexProperty.DeclaringType; - } - } - return base.VisitMethodCall(methodCallExpression); - } - } - protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) { var method = methodCallExpression.Method; @@ -1129,5 +1055,42 @@ private static Expression ConvertJTokenToType(Expression jTokenExpression, Type private static T SafeToObjectWithSerializer(JToken token) => token == null || token.Type == JTokenType.Null ? default : token.ToObject(CosmosClientWrapper.Serializer); + + private sealed class SwitchCaseReturnValueExtractorExpressionVisitor(Type clrType) : ExpressionVisitor + { + private ParameterExpression _result; + + public Expression Extract(Expression expression) + { + _result = null; + Visit(expression); + return _result; + } + + protected override Expression VisitBinary(BinaryExpression node) + { + if (node.NodeType == ExpressionType.Assign && node.Type.IsAssignableTo(clrType)) + { + _result = (ParameterExpression)node.Left; + return node; + } + + return base.VisitBinary(node); + } + } + + private sealed class ComplexPropertyValueBufferExpression : Expression + { + public ComplexPropertyValueBufferExpression(IComplexProperty complexProperty) + { + ComplexProperty = complexProperty; + } + + public override Type Type => typeof(ValueBuffer); + + public override ExpressionType NodeType => ExpressionType.Extension; + + public IComplexProperty ComplexProperty { get; } + } } } From cdf186197a5c12239ae3536d895db8442b5352ba Mon Sep 17 00:00:00 2001 From: JoasE <32096708+JoasE@users.noreply.github.com> Date: Wed, 17 Dec 2025 15:32:45 +0100 Subject: [PATCH 11/23] Cleanup --- ...sitor.JObjectInjectingExpressionVisitor.cs | 36 ++++++++----------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.JObjectInjectingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.JObjectInjectingExpressionVisitor.cs index c06f2caf42b..3bd4365acb5 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.JObjectInjectingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.JObjectInjectingExpressionVisitor.cs @@ -15,11 +15,22 @@ private sealed class JObjectInjectingExpressionVisitor : ExpressionVisitor { private int _currentEntityIndex; - public static List AssignJObject(StructuralTypeShaperExpression shaperExpression, ParameterExpression jObjectVariable) + protected override Expression VisitExtension(Expression extensionExpression) { - var valueBufferExpression = shaperExpression.ValueBufferExpression; + switch (extensionExpression) + { + case StructuralTypeShaperExpression shaperExpression: + { + _currentEntityIndex++; + + var valueBufferExpression = shaperExpression.ValueBufferExpression; - var expressions = new List + var jObjectVariable = Variable( + typeof(JObject), + "jObject" + _currentEntityIndex); + var variables = new List { jObjectVariable }; + + var expressions = new List { Assign( jObjectVariable, @@ -32,26 +43,9 @@ public static List AssignJObject(StructuralTypeShaperExpression shap shaperExpression) }; - return expressions; - } - - protected override Expression VisitExtension(Expression extensionExpression) - { - switch (extensionExpression) - { - case StructuralTypeShaperExpression shaperExpression: - { - _currentEntityIndex++; - - var jObjectVariable = Variable( - typeof(JObject), - "jObject" + _currentEntityIndex); - - var expressions = AssignJObject(shaperExpression, jObjectVariable); - return Block( shaperExpression.Type, - [jObjectVariable], + variables, expressions); } From b69203bef052795b20176cb0036784e18f4702ea Mon Sep 17 00:00:00 2001 From: JoasE <32096708+JoasE@users.noreply.github.com> Date: Wed, 17 Dec 2025 15:32:51 +0100 Subject: [PATCH 12/23] Remove todo in tests --- .../CosmosComplexTypesTrackingTest.cs | 25 ++++++++----------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/test/EFCore.Cosmos.FunctionalTests/CosmosComplexTypesTrackingTest.cs b/test/EFCore.Cosmos.FunctionalTests/CosmosComplexTypesTrackingTest.cs index 6362ece9712..cc82de50f03 100644 --- a/test/EFCore.Cosmos.FunctionalTests/CosmosComplexTypesTrackingTest.cs +++ b/test/EFCore.Cosmos.FunctionalTests/CosmosComplexTypesTrackingTest.cs @@ -21,11 +21,10 @@ public async Task Can_reorder_complex_collection_elements() var last = pub.Activities.Last(); await context.SaveChangesAsync(); - // TODO: Can be asserted after binding has been implemented. - //await using var assertContext = CreateContext(); - //var dbPub = await assertContext.Set().FirstAsync(x => x.Id == pub.Id); - //Assert.Equivalent(first, dbPub.Activities[0]); - //Assert.Equivalent(last, dbPub.Activities.Last()); + await using var assertContext = CreateContext(); + var dbPub = await assertContext.Set().FirstAsync(x => x.Id == pub.Id); + Assert.Equivalent(first, dbPub.Activities[0]); + Assert.Equivalent(last, dbPub.Activities.Last()); } [ConditionalFact] @@ -39,10 +38,9 @@ public async Task Can_change_complex_collection_element() pub.Activities[0].Name = "Changed123"; await context.SaveChangesAsync(); - // TODO: Can be asserted after binding has been implemented. - //await using var assertContext = CreateContext(); - //var dbPub = await assertContext.Set().FirstAsync(x => x.Id == pub.Id); - //Assert.Equivalent("Changed123", dbPub.Activities[0].Name); + await using var assertContext = CreateContext(); + var dbPub = await assertContext.Set().FirstAsync(x => x.Id == pub.Id); + Assert.Equivalent("Changed123", dbPub.Activities[0].Name); } [ConditionalFact] @@ -56,11 +54,10 @@ public async Task Can_add_complex_collection_element() pub.Activities.Add(new ActivityWithCollection { Name = "NewActivity" }); await context.SaveChangesAsync(); - // TODO: Can be asserted after binding has been implemented. - //await using var assertContext = CreateContext(); - //var dbPub = await assertContext.Set().FirstAsync(x => x.Id == pub.Id); - //Assert.Equivalent("NewActivity", dbPub.Activities.Last().Name); - //Assert.Equivalent(pub.Activities.Count, dbPub.Activities.Count); + await using var assertContext = CreateContext(); + var dbPub = await assertContext.Set().FirstAsync(x => x.Id == pub.Id); + Assert.Equivalent("NewActivity", dbPub.Activities.Last().Name); + Assert.Equivalent(pub.Activities.Count, dbPub.Activities.Count); } [ConditionalFact] From b66d91375481f3f867e43fe1062ecc20f85ca0a2 Mon Sep 17 00:00:00 2001 From: JoasE <32096708+JoasE@users.noreply.github.com> Date: Wed, 17 Dec 2025 15:39:28 +0100 Subject: [PATCH 13/23] Cleanup --- ...tor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs index d9dbaaa83fc..ad6d5795de8 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs @@ -348,7 +348,6 @@ private BlockExpression CreateComplexCollectionAssignmentBlock(MemberExpression ); } - return Block( [complexJArrayVariable], [ @@ -363,7 +362,7 @@ private BlockExpression CreateComplexPropertyAssignmentBlock(MemberExpression me var complexJObjectVariable = Variable( typeof(JObject), "complexJObject" + ++_currentComplexIndex); - var assignComplexJObjectVariable = Assign(complexJObjectVariable, Call( // @TODO: Can we reuse get property value? + var assignComplexJObjectVariable = Assign(complexJObjectVariable, Call( ToObjectWithSerializerMethodInfo.MakeGenericMethod(typeof(JObject)), Call(parentJObject, GetItemMethodInfo, Constant(complexProperty.Name) From 699da560d5b1913c0710134b5f7b1ab0134b987a Mon Sep 17 00:00:00 2001 From: JoasE <32096708+JoasE@users.noreply.github.com> Date: Fri, 19 Dec 2025 12:50:45 +0100 Subject: [PATCH 14/23] Move to CosmosStructuralTypeMaterializerSource --- .../CosmosServiceCollectionExtensions.cs | 2 + ...jectionBindingRemovingExpressionVisitor.cs | 3 +- ...ionBindingRemovingExpressionVisitorBase.cs | 288 ++---------------- ...osShapedQueryCompilingExpressionVisitor.cs | 146 ++++++++- .../CosmosStructuralTypeMaterializerSource.cs | 23 ++ 5 files changed, 190 insertions(+), 272 deletions(-) create mode 100644 src/EFCore.Cosmos/Query/Internal/CosmosStructuralTypeMaterializerSource.cs diff --git a/src/EFCore.Cosmos/Extensions/CosmosServiceCollectionExtensions.cs b/src/EFCore.Cosmos/Extensions/CosmosServiceCollectionExtensions.cs index 09018b0f2a5..93d0f28ec50 100644 --- a/src/EFCore.Cosmos/Extensions/CosmosServiceCollectionExtensions.cs +++ b/src/EFCore.Cosmos/Extensions/CosmosServiceCollectionExtensions.cs @@ -10,6 +10,7 @@ using Microsoft.EntityFrameworkCore.Cosmos.Storage.Internal; using Microsoft.EntityFrameworkCore.Cosmos.ValueGeneration.Internal; using Microsoft.EntityFrameworkCore.Infrastructure.Internal; +using Microsoft.EntityFrameworkCore.Query.Internal; // ReSharper disable once CheckNamespace namespace Microsoft.Extensions.DependencyInjection; @@ -94,6 +95,7 @@ public static IServiceCollection AddCosmos( public static IServiceCollection AddEntityFrameworkCosmos(this IServiceCollection serviceCollection) { var builder = new EntityFrameworkServicesBuilder(serviceCollection) + .TryAdd() .TryAdd() .TryAdd>() .TryAdd() diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitor.cs index 1586ef2b521..e02ac1de670 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitor.cs @@ -8,11 +8,10 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal; public partial class CosmosShapedQueryCompilingExpressionVisitor { private sealed class CosmosProjectionBindingRemovingExpressionVisitor( - CosmosShapedQueryCompilingExpressionVisitor parentVisitor, SelectExpression selectExpression, ParameterExpression jTokenParameter, bool trackQueryResults) - : CosmosProjectionBindingRemovingExpressionVisitorBase(parentVisitor, jTokenParameter, trackQueryResults) + : CosmosProjectionBindingRemovingExpressionVisitorBase(jTokenParameter, trackQueryResults) { protected override ProjectionExpression GetProjection(ProjectionBindingExpression projectionBindingExpression) => selectExpression.Projection[GetProjectionIndex(projectionBindingExpression)]; diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs index ad6d5795de8..7874877ef2d 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs @@ -3,7 +3,6 @@ #nullable disable -using System.Text.RegularExpressions; using Microsoft.EntityFrameworkCore.ChangeTracking.Internal; using Microsoft.EntityFrameworkCore.Cosmos.Internal; using Microsoft.EntityFrameworkCore.Cosmos.Metadata.Internal; @@ -17,14 +16,13 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal; public partial class CosmosShapedQueryCompilingExpressionVisitor { private abstract class CosmosProjectionBindingRemovingExpressionVisitorBase( - CosmosShapedQueryCompilingExpressionVisitor parentVisitor, ParameterExpression jTokenParameter, bool trackQueryResults) : ExpressionVisitor { - private static readonly MethodInfo GetItemMethodInfo - = typeof(JObject).GetRuntimeProperties() - .Single(pi => pi.Name == "Item" && pi.GetIndexParameters()[0].ParameterType == typeof(string)) + public static readonly MethodInfo GetItemMethodInfo + = typeof(JToken).GetRuntimeProperties() + .Single(pi => pi.Name == "Item" && pi.GetIndexParameters()[0].ParameterType == typeof(object)) .GetMethod; private static readonly PropertyInfo JTokenTypePropertyInfo @@ -52,73 +50,16 @@ private readonly IDictionary _projectionBinding private readonly IDictionary _ownerMappings = new Dictionary(); - private readonly Dictionary _entityInstanceMaterializationContextMappings = new(); - private readonly Dictionary _concreteTypeInstanceMaterializationContextMappings = new(); - private readonly Dictionary _instanceTypeBaseMappings = new(); - private readonly Dictionary> _materializationContextCompexPropertyJObjectMappings = new(); - private readonly IDictionary _ordinalParameterBindings = new Dictionary(); - private List _pendingIncludes = []; - private int _currentComplexIndex; - private static readonly MethodInfo ToObjectWithSerializerMethodInfo + private List _pendingIncludes + = []; + + public static readonly MethodInfo ToObjectWithSerializerMethodInfo = typeof(CosmosProjectionBindingRemovingExpressionVisitorBase) .GetRuntimeMethods().Single(mi => mi.Name == nameof(SafeToObjectWithSerializer)); - protected override Expression VisitBlock(BlockExpression node) - { - var materializationContextParameter = node.Variables - .SingleOrDefault(v => v.Type == typeof(MaterializationContext)); - - if (materializationContextParameter != null) - { - var instanceParameter = node.Variables.First(x => x.Name != null && Regex.Match(x.Name, @"instance\d+").Success); - _entityInstanceMaterializationContextMappings.Add(instanceParameter, materializationContextParameter); - } - - return base.VisitBlock(node); - } - - protected override SwitchCase VisitSwitchCase(SwitchCase switchCaseExpression) - { - if (switchCaseExpression.TestValues.SingleOrDefault() is ConstantExpression constantExpression - && constantExpression.Value is ITypeBase structuralType) - { - var instanceVariable = ((BlockExpression)switchCaseExpression.Body).Expressions - .Select(node => - { - if (node is UnaryExpression unaryExpression - && unaryExpression.NodeType == ExpressionType.Convert) - { - node = unaryExpression.Operand; - } - - if (node is ConditionalExpression conditionalExpression) - { - node = conditionalExpression.IfFalse; - if (node is UnaryExpression unaryExpression2 - && unaryExpression2.NodeType == ExpressionType.Convert) - { - node = unaryExpression2.Operand; - } - } - - return node as BlockExpression; - }) - .Select(x => x?.Variables.FirstOrDefault(x => x.Type == structuralType.ClrType)) - .FirstOrDefault(x => x != null); - - // Can be null for constructor bindings, but since complex properties aren't supported, we don't need to set the _instanceTypeBaseMappings - if (instanceVariable != null) - { - _instanceTypeBaseMappings.Add(instanceVariable, structuralType); - } - } - - return base.VisitSwitchCase(switchCaseExpression); - } - protected override Expression VisitBinary(BinaryExpression binaryExpression) { if (binaryExpression.NodeType == ExpressionType.Assign) @@ -131,18 +72,7 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) string storeName = null; // Values injected by JObjectInjectingExpressionVisitor - var projectionExpression = ((UnaryExpression)binaryExpression.Right).Operand; - - if (projectionExpression is UnaryExpression - { - NodeType: ExpressionType.Convert, - Operand: UnaryExpression operand - }) - { - // Unwrap EntityProjectionExpression when the root entity is not projected - // That is, this is handling the projection of a non-root entity type. - projectionExpression = operand.Operand; - } + var projectionExpression = binaryExpression.Right.UnwrapTypeConversion(out _); switch (projectionExpression) { @@ -213,7 +143,10 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) } break; - + case MethodCallExpression jObjectMethodCallExpression + when jObjectMethodCallExpression.Method.IsGenericMethod && jObjectMethodCallExpression.Method.GetGenericMethodDefinition() == ToObjectWithSerializerMethodInfo: + // jobject already uses ToObjectWithSerializerMethodInfo. This can happen because code was generated for complex properties that already leverages jobject correctly. + return binaryExpression; default: throw new UnreachableException(); } @@ -225,12 +158,10 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) { var newExpression = (NewExpression)binaryExpression.Right; - if (newExpression.Arguments[0] is ComplexPropertyValueBufferExpression temp) + if (newExpression.Arguments[0] is ComplexPropertyBindingExpression complexPropertyBindingExpression) { - _materializationContextBindings[parameterExpression] = temp; - _projectionBindings[temp] = jTokenParameter; - Debug.Assert(!_materializationContextCompexPropertyJObjectMappings.ContainsKey(parameterExpression), "Should never overwrite"); - _materializationContextCompexPropertyJObjectMappings[parameterExpression] = new() { { temp.ComplexProperty, jTokenParameter } }; + _materializationContextBindings[parameterExpression] = complexPropertyBindingExpression; + _projectionBindings[complexPropertyBindingExpression] = complexPropertyBindingExpression.JObjectParameter; } else { @@ -256,40 +187,10 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) return MakeBinary(ExpressionType.Assign, binaryExpression.Left, updatedExpression); } - - if (_entityInstanceMaterializationContextMappings.TryGetValue(parameterExpression, out var instanceMaterializationContext) && binaryExpression.Right is SwitchExpression switchExpression) - { - var instances = switchExpression.Cases.Select(x => new SwitchCaseReturnValueExtractorExpressionVisitor(parameterExpression.Type).Extract(x.Body) as ParameterExpression).Where(x => x != null);// Null for constructor binding, but since complex properties aren't supported, we don't need to set the _concreteTypeInstanceMaterializationContextMapping - foreach (var instance in instances) - { - _concreteTypeInstanceMaterializationContextMappings[instance] = instanceMaterializationContext; - } - } } - if (binaryExpression.Left is MemberExpression memberExpression) + if (binaryExpression.Left is MemberExpression { Member: FieldInfo { IsInitOnly: true } } memberExpression) { - if (memberExpression.Expression is ParameterExpression instanceParameterExpression && _instanceTypeBaseMappings.TryGetValue(instanceParameterExpression, out var structuralType)) - { - var complexProperty = structuralType.GetComplexProperties().FirstOrDefault(x => x.GetMemberInfo(true, true) == memberExpression.Member); - if (complexProperty != null) - { - var materializationContext = _concreteTypeInstanceMaterializationContextMappings[instanceParameterExpression]; - Expression parentJObject; - if (complexProperty.DeclaringType is IComplexType parentComplexType) - { - parentJObject = _materializationContextCompexPropertyJObjectMappings[materializationContext][parentComplexType.ComplexProperty]; - } - else - { - parentJObject = _projectionBindings[_materializationContextBindings[materializationContext]]; - } - - return complexProperty.IsCollection - ? CreateComplexCollectionAssignmentBlock(memberExpression, complexProperty, materializationContext, parentJObject) - : CreateComplexPropertyAssignmentBlock(memberExpression, binaryExpression.Right, complexProperty, materializationContext, parentJObject); - } - } return memberExpression.Assign(Visit(binaryExpression.Right)); } } @@ -297,113 +198,6 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) return base.VisitBinary(binaryExpression); } - private BlockExpression CreateComplexCollectionAssignmentBlock(MemberExpression memberExpression, IComplexProperty complexProperty, ParameterExpression materializationContext, Expression parentJObject) - { - var complexJArrayVariable = Variable( - typeof(JArray), - "complexJArray" + ++_currentComplexIndex); - - var assignJArrayVariable = Assign(complexJArrayVariable, - Call( - ToObjectWithSerializerMethodInfo.MakeGenericMethod(typeof(JArray)), - Call(parentJObject, GetItemMethodInfo, - Constant(complexProperty.Name) - ) - ) - ); - - var tempValueBuffer = new ComplexPropertyValueBufferExpression(complexProperty); - var structuralTypeShaperExpression = new StructuralTypeShaperExpression( - complexProperty.ComplexType, - tempValueBuffer, - complexProperty.ClrType.IsNullableType()); // @TODO: Can collection items be null? - - var rawMaterializeExpression = parentVisitor.InjectStructuralTypeMaterializers(structuralTypeShaperExpression); // @TODO: We could also use entityMaterializerSource directly here.. - - var jObjectParameter = Parameter(typeof(JObject), "complexArrayItem" + _currentComplexIndex); - var oldJTokenParametr = jTokenParameter; - jTokenParameter = jObjectParameter; - var materializeExpression = Visit(rawMaterializeExpression); - jTokenParameter = oldJTokenParametr; - - var select = Call( - EnumerableMethods.Select.MakeGenericMethod(typeof(JObject), complexProperty.ComplexType.ClrType), - Call( - EnumerableMethods.Cast.MakeGenericMethod(typeof(JObject)), - complexJArrayVariable), - Lambda(materializeExpression, jObjectParameter)); - - Expression populateExpression = - Call( - PopulateCollectionMethodInfo.MakeGenericMethod(complexProperty.ComplexType.ClrType, complexProperty.ClrType), - Constant(complexProperty.GetCollectionAccessor()), - select - ); - - if (complexProperty.IsNullable) - { - populateExpression = Condition(Equal(complexJArrayVariable, Constant(null)), - Default(complexProperty.ClrType), - populateExpression - ); - } - - return Block( - [complexJArrayVariable], - [ - assignJArrayVariable, - memberExpression.Assign(populateExpression) - ] - ); - } - - private BlockExpression CreateComplexPropertyAssignmentBlock(MemberExpression memberExpression, Expression materializationExpression, IComplexProperty complexProperty, ParameterExpression materializationContext, Expression parentJObject) - { - var complexJObjectVariable = Variable( - typeof(JObject), - "complexJObject" + ++_currentComplexIndex); - var assignComplexJObjectVariable = Assign(complexJObjectVariable, Call( - ToObjectWithSerializerMethodInfo.MakeGenericMethod(typeof(JObject)), - Call(parentJObject, GetItemMethodInfo, - Constant(complexProperty.Name) - ) - )); - - if (!_materializationContextCompexPropertyJObjectMappings.TryGetValue(materializationContext, out var complexPropertyJObjectMappings)) - { - complexPropertyJObjectMappings = new(); - _materializationContextCompexPropertyJObjectMappings[materializationContext] = complexPropertyJObjectMappings; - } - - complexPropertyJObjectMappings.Add(complexProperty, complexJObjectVariable); - - BlockExpression materializationBlock; - if (materializationExpression is ConditionalExpression condition) - { - materializationBlock = (condition.IfFalse as BlockExpression ?? (BlockExpression)(condition.IfFalse as UnaryExpression).Operand); - materializationExpression = Condition( - Equal(complexJObjectVariable, Constant(null)), - condition.IfTrue, - Convert(materializationBlock, condition.Type)); - } - else - { - materializationBlock = (BlockExpression)materializationExpression; - } - - var instanceParameter = materializationBlock.Variables.First(v => v.Type == complexProperty.ComplexType.ClrType); - _instanceTypeBaseMappings.Add(instanceParameter, complexProperty.ComplexType); - _concreteTypeInstanceMaterializationContextMappings.Add(instanceParameter, materializationContext); - - materializationExpression = Visit(materializationExpression); - - return Block( - [complexJObjectVariable], - assignComplexJObjectVariable, - memberExpression.Assign(materializationExpression) - ); - } - protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) { var method = methodCallExpression.Method; @@ -422,15 +216,8 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp } else { - var materializationContext = (ParameterExpression)((MethodCallExpression)methodCallExpression.Arguments[0]).Object; - if (property.DeclaringType is IComplexType complexType) - { - innerExpression = _materializationContextCompexPropertyJObjectMappings[materializationContext][complexType.ComplexProperty]; - } - else - { - innerExpression = _materializationContextBindings[materializationContext]; - } + innerExpression = _materializationContextBindings[ + (ParameterExpression)((MethodCallExpression)methodCallExpression.Arguments[0]).Object]; } return CreateGetValueExpression(innerExpression, property, methodCallExpression.Type); @@ -808,7 +595,7 @@ private static Expression AddToCollectionNavigation( relatedEntity, Constant(true)); - private static readonly MethodInfo PopulateCollectionMethodInfo + public static readonly MethodInfo PopulateCollectionMethodInfo = typeof(CosmosProjectionBindingRemovingExpressionVisitorBase).GetTypeInfo() .GetDeclaredMethod(nameof(PopulateCollection)); @@ -1054,42 +841,5 @@ private static Expression ConvertJTokenToType(Expression jTokenExpression, Type private static T SafeToObjectWithSerializer(JToken token) => token == null || token.Type == JTokenType.Null ? default : token.ToObject(CosmosClientWrapper.Serializer); - - private sealed class SwitchCaseReturnValueExtractorExpressionVisitor(Type clrType) : ExpressionVisitor - { - private ParameterExpression _result; - - public Expression Extract(Expression expression) - { - _result = null; - Visit(expression); - return _result; - } - - protected override Expression VisitBinary(BinaryExpression node) - { - if (node.NodeType == ExpressionType.Assign && node.Type.IsAssignableTo(clrType)) - { - _result = (ParameterExpression)node.Left; - return node; - } - - return base.VisitBinary(node); - } - } - - private sealed class ComplexPropertyValueBufferExpression : Expression - { - public ComplexPropertyValueBufferExpression(IComplexProperty complexProperty) - { - ComplexProperty = complexProperty; - } - - public override Type Type => typeof(ValueBuffer); - - public override ExpressionType NodeType => ExpressionType.Extension; - - public IComplexProperty ComplexProperty { get; } - } } } diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs index 072401c9005..b80eea5bc11 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs @@ -22,6 +22,7 @@ public partial class CosmosShapedQueryCompilingExpressionVisitor( IQuerySqlGeneratorFactory querySqlGeneratorFactory) : ShapedQueryCompilingExpressionVisitor(dependencies, cosmosQueryCompilationContext) { + private ParameterExpression _parentJObject; private readonly Type _contextType = cosmosQueryCompilationContext.ContextType; private readonly bool _threadSafetyChecksEnabled = dependencies.CoreSingletonOptions.AreThreadSafetyChecksEnabled; @@ -39,6 +40,7 @@ protected override Expression VisitShapedQuery(ShapedQueryExpression shapedQuery } var jTokenParameter = Parameter(typeof(JToken), "jToken"); + _parentJObject = jTokenParameter; var shaperBody = shapedQueryExpression.ShaperExpression; @@ -67,7 +69,6 @@ protected override Expression VisitShapedQuery(ShapedQueryExpression shapedQuery } shaperBody = new CosmosProjectionBindingRemovingExpressionVisitor( - this, selectExpression, jTokenParameter, QueryCompilationContext.QueryTrackingBehavior == QueryTrackingBehavior.TrackAll) .Visit(shaperBody); @@ -171,4 +172,147 @@ private static PartitionKey GeneratePartitionKey( return builder.Build(); } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override void AddStructuralTypeInitialization(StructuralTypeShaperExpression shaper, ParameterExpression instanceVariable, List variables, List expressions) + { + foreach (var complexProperty in shaper.StructuralType.GetComplexProperties()) + { + var member = MakeMemberAccess(instanceVariable, complexProperty.GetMemberInfo(true, true)); + if (complexProperty.IsCollection) + { + expressions.Add(CreateComplexCollectionAssignmentBlock(member, complexProperty)); + } + else + { + expressions.Add(CreateComplexPropertyAssignmentBlock(member, complexProperty)); + } + } + } + + private int _currentComplexIndex; + + private BlockExpression CreateComplexPropertyAssignmentBlock(MemberExpression memberExpression, IComplexProperty complexProperty) + { + var jObjectVariable = Parameter(typeof(JObject), "complexJObject" + ++_currentComplexIndex); + var assignJObjectVariable = Assign(jObjectVariable, + Call( + CosmosProjectionBindingRemovingExpressionVisitorBase.ToObjectWithSerializerMethodInfo.MakeGenericMethod(typeof(JObject)), + Call(_parentJObject, CosmosProjectionBindingRemovingExpressionVisitorBase.GetItemMethodInfo, + Constant(complexProperty.Name) + ) + ) + ); + + var materializeExpression = CreateComplexTypeMaterializeExperssion(complexProperty, jObjectVariable); + if (complexProperty.IsNullable) + { + materializeExpression = Condition(Equal(jObjectVariable, Constant(null)), + Default(complexProperty.ClrType.MakeNullable()), + ConvertChecked(materializeExpression, complexProperty.ClrType.MakeNullable()) + ); + } + + return Block( + [jObjectVariable], + [ + assignJObjectVariable, + memberExpression.Assign(materializeExpression) + ] + ); + } + + private BlockExpression CreateComplexCollectionAssignmentBlock(MemberExpression memberExpression, IComplexProperty complexProperty) + { + var complexJArrayVariable = Variable( + typeof(JArray), + "complexJArray" + ++_currentComplexIndex); + + var assignJArrayVariable = Assign(complexJArrayVariable, + Call( + CosmosProjectionBindingRemovingExpressionVisitorBase.ToObjectWithSerializerMethodInfo.MakeGenericMethod(typeof(JArray)), + Call(_parentJObject, CosmosProjectionBindingRemovingExpressionVisitorBase.GetItemMethodInfo, + Constant(complexProperty.Name) + ) + ) + ); + var jObjectParameter = Parameter(typeof(JObject), "complexJObject" + _currentComplexIndex); + var materializeExpression = CreateComplexTypeMaterializeExperssion(complexProperty, jObjectParameter); + + var select = Call( + EnumerableMethods.Select.MakeGenericMethod(typeof(JObject), complexProperty.ComplexType.ClrType), + Call( + EnumerableMethods.Cast.MakeGenericMethod(typeof(JObject)), + complexJArrayVariable), + Lambda(materializeExpression, jObjectParameter)); + + Expression populateExpression = + Call( + CosmosProjectionBindingRemovingExpressionVisitorBase.PopulateCollectionMethodInfo.MakeGenericMethod(complexProperty.ComplexType.ClrType, complexProperty.ClrType), + Constant(complexProperty.GetCollectionAccessor()), + select + ); + + if (complexProperty.IsNullable) + { + populateExpression = Condition(Equal(complexJArrayVariable, Constant(null)), + Default(complexProperty.ClrType.MakeNullable()), + ConvertChecked(populateExpression, complexProperty.ClrType.MakeNullable()) + ); + } + + return Block( + [complexJArrayVariable], + [ + assignJArrayVariable, + memberExpression.Assign(populateExpression) + ] + ); + } + + + private Expression CreateComplexTypeMaterializeExperssion(IComplexProperty complexProperty, ParameterExpression jObjectParameter) + { + var tempValueBuffer = new ComplexPropertyBindingExpression(complexProperty, jObjectParameter); + var structuralTypeShaperExpression = new StructuralTypeShaperExpression( + complexProperty.ComplexType, + tempValueBuffer, + false); + + var oldParentJObject = _parentJObject; + _parentJObject = jObjectParameter; + var materializeExpression = InjectStructuralTypeMaterializers(structuralTypeShaperExpression); + _parentJObject = oldParentJObject; + + if (complexProperty.ComplexType.ClrType.IsNullableType()) // @TODO: Can collection items be null? + { + materializeExpression = Condition(Equal(jObjectParameter, Constant(null)), + Default(complexProperty.ComplexType.ClrType), + materializeExpression + ); + } + + return materializeExpression; + } + + private class ComplexPropertyBindingExpression : Expression + { + public ComplexPropertyBindingExpression(IComplexProperty complexProperty, ParameterExpression jObjectParameter) + { + ComplexProperty = complexProperty; + JObjectParameter = jObjectParameter; + } + + public override Type Type => typeof(ValueBuffer); + + public override ExpressionType NodeType => ExpressionType.Extension; + + public IComplexProperty ComplexProperty { get; } + public ParameterExpression JObjectParameter { get; } + } } diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosStructuralTypeMaterializerSource.cs b/src/EFCore.Cosmos/Query/Internal/CosmosStructuralTypeMaterializerSource.cs new file mode 100644 index 00000000000..ab262b22477 --- /dev/null +++ b/src/EFCore.Cosmos/Query/Internal/CosmosStructuralTypeMaterializerSource.cs @@ -0,0 +1,23 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.EntityFrameworkCore.Query.Internal; + +#pragma warning disable EF1001 // StructuralTypeMaterializerSource is pubternal + +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +public class CosmosStructuralTypeMaterializerSource(StructuralTypeMaterializerSourceDependencies dependencies) + : StructuralTypeMaterializerSource(dependencies) +{ + /// + /// Complex properties are not handled in the initial materialization expression, + /// So we can more easily generate the necessary nested materialization expressions later in CosmosShapedQueryCompilingExpressionVisitor. + /// + protected override bool ReadComplexTypeDirectly(IComplexType complexType) + => false; +} From 900f117af9f2e529ace5de341ddd36adec83a544 Mon Sep 17 00:00:00 2001 From: JoasE <32096708+JoasE@users.noreply.github.com> Date: Fri, 19 Dec 2025 13:11:48 +0100 Subject: [PATCH 15/23] Fix formatting and typo --- .../Internal/CosmosProjectionBindingExpressionVisitor.cs | 1 - .../CosmosShapedQueryCompilingExpressionVisitor.cs | 9 ++++----- .../ComplexPropertiesProjectionCosmosTest.cs | 8 +++++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosProjectionBindingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosProjectionBindingExpressionVisitor.cs index 7dbe2f5d703..6bd4fa65a7d 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosProjectionBindingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosProjectionBindingExpressionVisitor.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Diagnostics.CodeAnalysis; -using System.Linq.Expressions; using JetBrains.Annotations; using Microsoft.EntityFrameworkCore.Cosmos.Internal; using Microsoft.EntityFrameworkCore.Cosmos.Metadata.Internal; diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs index b80eea5bc11..07000581dab 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs @@ -209,7 +209,7 @@ private BlockExpression CreateComplexPropertyAssignmentBlock(MemberExpression me ) ); - var materializeExpression = CreateComplexTypeMaterializeExperssion(complexProperty, jObjectVariable); + var materializeExpression = CreateComplexTypeMaterializeExpression(complexProperty, jObjectVariable); if (complexProperty.IsNullable) { materializeExpression = Condition(Equal(jObjectVariable, Constant(null)), @@ -242,7 +242,7 @@ private BlockExpression CreateComplexCollectionAssignmentBlock(MemberExpression ) ); var jObjectParameter = Parameter(typeof(JObject), "complexJObject" + _currentComplexIndex); - var materializeExpression = CreateComplexTypeMaterializeExperssion(complexProperty, jObjectParameter); + var materializeExpression = CreateComplexTypeMaterializeExpression(complexProperty, jObjectParameter); var select = Call( EnumerableMethods.Select.MakeGenericMethod(typeof(JObject), complexProperty.ComplexType.ClrType), @@ -275,8 +275,7 @@ private BlockExpression CreateComplexCollectionAssignmentBlock(MemberExpression ); } - - private Expression CreateComplexTypeMaterializeExperssion(IComplexProperty complexProperty, ParameterExpression jObjectParameter) + private Expression CreateComplexTypeMaterializeExpression(IComplexProperty complexProperty, ParameterExpression jObjectParameter) { var tempValueBuffer = new ComplexPropertyBindingExpression(complexProperty, jObjectParameter); var structuralTypeShaperExpression = new StructuralTypeShaperExpression( @@ -289,7 +288,7 @@ private Expression CreateComplexTypeMaterializeExperssion(IComplexProperty compl var materializeExpression = InjectStructuralTypeMaterializers(structuralTypeShaperExpression); _parentJObject = oldParentJObject; - if (complexProperty.ComplexType.ClrType.IsNullableType()) // @TODO: Can collection items be null? + if (complexProperty.ComplexType.ClrType.IsNullableType()) { materializeExpression = Condition(Equal(jObjectParameter, Constant(null)), Default(complexProperty.ComplexType.ClrType), diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/Associations/ComplexProperties/ComplexPropertiesProjectionCosmosTest.cs b/test/EFCore.Cosmos.FunctionalTests/Query/Associations/ComplexProperties/ComplexPropertiesProjectionCosmosTest.cs index 6f50a954831..7a8d3cdee64 100644 --- a/test/EFCore.Cosmos.FunctionalTests/Query/Associations/ComplexProperties/ComplexPropertiesProjectionCosmosTest.cs +++ b/test/EFCore.Cosmos.FunctionalTests/Query/Associations/ComplexProperties/ComplexPropertiesProjectionCosmosTest.cs @@ -124,6 +124,7 @@ SELECT VALUE c FROM root c """); } + public override async Task Select_required_nested_on_required_associate(QueryTrackingBehavior queryTrackingBehavior) { await base.Select_required_nested_on_required_associate(queryTrackingBehavior); @@ -139,6 +140,7 @@ SELECT VALUE c FROM root c """); } + public override async Task Select_optional_nested_on_required_associate(QueryTrackingBehavior queryTrackingBehavior) { await base.Select_optional_nested_on_required_associate(queryTrackingBehavior); @@ -170,6 +172,7 @@ SELECT VALUE c FROM root c """); } + public override async Task Select_optional_nested_on_optional_associate(QueryTrackingBehavior queryTrackingBehavior) { if (queryTrackingBehavior is QueryTrackingBehavior.TrackAll) @@ -369,9 +372,9 @@ public override async Task Select_subquery_optional_related_FirstOrDefault(Query await AssertTranslationFailed(() => base.Select_subquery_required_related_FirstOrDefault(queryTrackingBehavior)); } -#endregion Subquery + #endregion Subquery -#region Value types + #region Value types public override async Task Select_root_with_value_types(QueryTrackingBehavior queryTrackingBehavior) { await base.Select_root_with_value_types(queryTrackingBehavior); @@ -395,7 +398,6 @@ ORDER BY c["Id"] """); } - public override async Task Select_nullable_value_type(QueryTrackingBehavior queryTrackingBehavior) { await base.Select_nullable_value_type(queryTrackingBehavior); From 004d37a4b9b0140fd2f6d69e79e1c57716ba5015 Mon Sep 17 00:00:00 2001 From: JoasE <32096708+JoasE@users.noreply.github.com> Date: Fri, 19 Dec 2025 17:32:16 +0100 Subject: [PATCH 16/23] Fix typo and formatting --- ...or.CosmosProjectionBindingRemovingExpressionVisitorBase.cs | 2 +- .../Internal/CosmosShapedQueryCompilingExpressionVisitor.cs | 2 +- .../Query/Internal/CosmosStructuralTypeMaterializerSource.cs | 2 +- .../ComplexPropertiesProjectionCosmosTest.cs | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs index 7874877ef2d..941d5130014 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs @@ -145,7 +145,7 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) break; case MethodCallExpression jObjectMethodCallExpression when jObjectMethodCallExpression.Method.IsGenericMethod && jObjectMethodCallExpression.Method.GetGenericMethodDefinition() == ToObjectWithSerializerMethodInfo: - // jobject already uses ToObjectWithSerializerMethodInfo. This can happen because code was generated for complex properties that already leverages jobject correctly. + // JObject assignment already uses ToObjectWithSerializerMethodInfo. This can happen because code was generated for complex properties that already leverages JObject correctly. return binaryExpression; default: throw new UnreachableException(); diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs index 07000581dab..b42649adb84 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs @@ -299,7 +299,7 @@ private Expression CreateComplexTypeMaterializeExpression(IComplexProperty compl return materializeExpression; } - private class ComplexPropertyBindingExpression : Expression + private sealed class ComplexPropertyBindingExpression : Expression { public ComplexPropertyBindingExpression(IComplexProperty complexProperty, ParameterExpression jObjectParameter) { diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosStructuralTypeMaterializerSource.cs b/src/EFCore.Cosmos/Query/Internal/CosmosStructuralTypeMaterializerSource.cs index ab262b22477..3eed32b4e39 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosStructuralTypeMaterializerSource.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosStructuralTypeMaterializerSource.cs @@ -16,7 +16,7 @@ public class CosmosStructuralTypeMaterializerSource(StructuralTypeMaterializerSo { /// /// Complex properties are not handled in the initial materialization expression, - /// So we can more easily generate the necessary nested materialization expressions later in CosmosShapedQueryCompilingExpressionVisitor. + /// so we can more easily generate the necessary nested materialization expressions later in CosmosShapedQueryCompilingExpressionVisitor. /// protected override bool ReadComplexTypeDirectly(IComplexType complexType) => false; diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/Associations/ComplexProperties/ComplexPropertiesProjectionCosmosTest.cs b/test/EFCore.Cosmos.FunctionalTests/Query/Associations/ComplexProperties/ComplexPropertiesProjectionCosmosTest.cs index 7a8d3cdee64..d9131e861f6 100644 --- a/test/EFCore.Cosmos.FunctionalTests/Query/Associations/ComplexProperties/ComplexPropertiesProjectionCosmosTest.cs +++ b/test/EFCore.Cosmos.FunctionalTests/Query/Associations/ComplexProperties/ComplexPropertiesProjectionCosmosTest.cs @@ -366,10 +366,10 @@ public override async Task Select_subquery_optional_related_FirstOrDefault(Query { if (queryTrackingBehavior is QueryTrackingBehavior.TrackAll) { - throw SkipException.ForSkip("Complex type tracking not supported."); + throw SkipException.ForSkip("Complex type tracking not supported."); } - await AssertTranslationFailed(() => base.Select_subquery_required_related_FirstOrDefault(queryTrackingBehavior)); + await AssertTranslationFailed(() => base.Select_subquery_optional_related_FirstOrDefault(queryTrackingBehavior)); } #endregion Subquery From a8eeba9ec9452e2530a837680735f56db66d7f35 Mon Sep 17 00:00:00 2001 From: JoasE <32096708+JoasE@users.noreply.github.com> Date: Wed, 7 Jan 2026 10:59:56 +0100 Subject: [PATCH 17/23] Conform to code style --- ...osmosProjectionBindingExpressionVisitor.cs | 20 ++++++------ ...ionBindingRemovingExpressionVisitorBase.cs | 4 +-- ...osShapedQueryCompilingExpressionVisitor.cs | 32 ++++++------------- 3 files changed, 21 insertions(+), 35 deletions(-) diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosProjectionBindingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosProjectionBindingExpressionVisitor.cs index 6bd4fa65a7d..e64cdeb04a8 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosProjectionBindingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosProjectionBindingExpressionVisitor.cs @@ -350,20 +350,20 @@ UnaryExpression unaryExpression throw new InvalidOperationException(CoreStrings.TranslationFailed(memberExpression.Print())); } - Expression NullSafeUpdate(Expression? expression) + Expression NullSafeUpdate(Expression? innerExpression) { - if (expression is null) + if (innerExpression is null) { - return memberExpression.Update(expression); + return memberExpression.Update(innerExpression); } - var expressionValue = Expression.Parameter(expression.Type); - var assignment = Expression.Assign(expressionValue, expression); + var expressionValue = Expression.Parameter(innerExpression.Type); + var assignment = Expression.Assign(expressionValue, innerExpression); - if (expression.Type.IsNullableType() == true + if (innerExpression.Type.IsNullableType() && !memberExpression.Type.IsNullableType() && memberExpression.Expression is MemberExpression innerMember - && innerMember.Type.IsNullableValueType() == true + && innerMember.Type.IsNullableValueType() && memberExpression.Member.Name == nameof(Nullable<>.Value)) { var nullCheck = Expression.Not( @@ -381,7 +381,7 @@ Expression NullSafeUpdate(Expression? expression) Expression updatedMemberExpression = memberExpression.Update(MatchTypes(expressionValue, memberExpression.Expression!.Type)); - if (expression.Type.IsNullableType() == true) + if (innerExpression.Type.IsNullableType()) { var nullableReturnType = memberExpression.Type.MakeNullable(); @@ -391,7 +391,7 @@ Expression NullSafeUpdate(Expression? expression) } Expression nullCheck; - if (expression.Type.IsNullableValueType()) + if (innerExpression.Type.IsNullableValueType()) { // For Nullable, use HasValue property instead of equality comparison // to avoid issues with value types that don't define the == operator @@ -400,7 +400,7 @@ Expression NullSafeUpdate(Expression? expression) } else { - nullCheck = Expression.Equal(expressionValue, Expression.Default(expression.Type)); + nullCheck = Expression.Equal(expressionValue, Expression.Default(innerExpression.Type)); } updatedMemberExpression = Expression.Condition( diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs index 941d5130014..aafd01a66e8 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs @@ -143,8 +143,8 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) } break; - case MethodCallExpression jObjectMethodCallExpression - when jObjectMethodCallExpression.Method.IsGenericMethod && jObjectMethodCallExpression.Method.GetGenericMethodDefinition() == ToObjectWithSerializerMethodInfo: + case MethodCallExpression { Method.IsGenericMethod: true } jObjectMethodCallExpression + when jObjectMethodCallExpression.Method.GetGenericMethodDefinition() == ToObjectWithSerializerMethodInfo: // JObject assignment already uses ToObjectWithSerializerMethodInfo. This can happen because code was generated for complex properties that already leverages JObject correctly. return binaryExpression; default: diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs index b42649adb84..92c354c7844 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs @@ -22,6 +22,7 @@ public partial class CosmosShapedQueryCompilingExpressionVisitor( IQuerySqlGeneratorFactory querySqlGeneratorFactory) : ShapedQueryCompilingExpressionVisitor(dependencies, cosmosQueryCompilationContext) { + private int _currentComplexIndex; private ParameterExpression _parentJObject; private readonly Type _contextType = cosmosQueryCompilationContext.ContextType; private readonly bool _threadSafetyChecksEnabled = dependencies.CoreSingletonOptions.AreThreadSafetyChecksEnabled; @@ -184,19 +185,12 @@ public override void AddStructuralTypeInitialization(StructuralTypeShaperExpress foreach (var complexProperty in shaper.StructuralType.GetComplexProperties()) { var member = MakeMemberAccess(instanceVariable, complexProperty.GetMemberInfo(true, true)); - if (complexProperty.IsCollection) - { - expressions.Add(CreateComplexCollectionAssignmentBlock(member, complexProperty)); - } - else - { - expressions.Add(CreateComplexPropertyAssignmentBlock(member, complexProperty)); - } + expressions.Add(complexProperty.IsCollection + ? CreateComplexCollectionAssignmentBlock(member, complexProperty) + : CreateComplexPropertyAssignmentBlock(member, complexProperty)); } } - private int _currentComplexIndex; - private BlockExpression CreateComplexPropertyAssignmentBlock(MemberExpression memberExpression, IComplexProperty complexProperty) { var jObjectVariable = Parameter(typeof(JObject), "complexJObject" + ++_currentComplexIndex); @@ -237,10 +231,8 @@ private BlockExpression CreateComplexCollectionAssignmentBlock(MemberExpression Call( CosmosProjectionBindingRemovingExpressionVisitorBase.ToObjectWithSerializerMethodInfo.MakeGenericMethod(typeof(JArray)), Call(_parentJObject, CosmosProjectionBindingRemovingExpressionVisitorBase.GetItemMethodInfo, - Constant(complexProperty.Name) - ) - ) - ); + Constant(complexProperty.Name)))); + var jObjectParameter = Parameter(typeof(JObject), "complexJObject" + _currentComplexIndex); var materializeExpression = CreateComplexTypeMaterializeExpression(complexProperty, jObjectParameter); @@ -299,19 +291,13 @@ private Expression CreateComplexTypeMaterializeExpression(IComplexProperty compl return materializeExpression; } - private sealed class ComplexPropertyBindingExpression : Expression + private sealed class ComplexPropertyBindingExpression(IComplexProperty complexProperty, ParameterExpression jObjectParameter) : Expression { - public ComplexPropertyBindingExpression(IComplexProperty complexProperty, ParameterExpression jObjectParameter) - { - ComplexProperty = complexProperty; - JObjectParameter = jObjectParameter; - } - public override Type Type => typeof(ValueBuffer); public override ExpressionType NodeType => ExpressionType.Extension; - public IComplexProperty ComplexProperty { get; } - public ParameterExpression JObjectParameter { get; } + public IComplexProperty ComplexProperty { get; } = complexProperty; + public ParameterExpression JObjectParameter { get; } = jObjectParameter; } } From 3698238eccf40c57ad9b84389dac710bda6f29ea Mon Sep 17 00:00:00 2001 From: JoasE <32096708+JoasE@users.noreply.github.com> Date: Wed, 7 Jan 2026 11:21:44 +0100 Subject: [PATCH 18/23] Add comment to explain check --- .../Internal/CosmosProjectionBindingExpressionVisitor.cs | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosProjectionBindingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosProjectionBindingExpressionVisitor.cs index e64cdeb04a8..936b034cbe1 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosProjectionBindingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosProjectionBindingExpressionVisitor.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; using System.Diagnostics.CodeAnalysis; using JetBrains.Annotations; using Microsoft.EntityFrameworkCore.Cosmos.Internal; @@ -360,12 +361,16 @@ Expression NullSafeUpdate(Expression? innerExpression) var expressionValue = Expression.Parameter(innerExpression.Type); var assignment = Expression.Assign(expressionValue, innerExpression); + // Special case for when query is projecting 'nullable.Value' where 'nullable' is of type Nullable + // In this case we return default(T) when 'nullable' is null if (innerExpression.Type.IsNullableType() && !memberExpression.Type.IsNullableType() - && memberExpression.Expression is MemberExpression innerMember - && innerMember.Type.IsNullableValueType() + && memberExpression.Expression is MemberExpression outerMember + && outerMember.Type.IsNullableValueType() && memberExpression.Member.Name == nameof(Nullable<>.Value)) { + // Use HasValue property instead of equality comparison + // to avoid issues with value types that don't define the == operator var nullCheck = Expression.Not( Expression.Property(expressionValue, nameof(Nullable<>.HasValue))); var conditionalExpression = Expression.Condition( From 1702ac9f0f456a6f33e331aee530acdd99b04e59 Mon Sep 17 00:00:00 2001 From: JoasE <32096708+JoasE@users.noreply.github.com> Date: Fri, 9 Jan 2026 16:04:35 +0100 Subject: [PATCH 19/23] Conform code style don't put closing parentheses on their own lines --- ...smosShapedQueryCompilingExpressionVisitor.cs | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs index 92c354c7844..b727e42d12f 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs @@ -198,18 +198,14 @@ private BlockExpression CreateComplexPropertyAssignmentBlock(MemberExpression me Call( CosmosProjectionBindingRemovingExpressionVisitorBase.ToObjectWithSerializerMethodInfo.MakeGenericMethod(typeof(JObject)), Call(_parentJObject, CosmosProjectionBindingRemovingExpressionVisitorBase.GetItemMethodInfo, - Constant(complexProperty.Name) - ) - ) - ); + Constant(complexProperty.Name)))); var materializeExpression = CreateComplexTypeMaterializeExpression(complexProperty, jObjectVariable); if (complexProperty.IsNullable) { materializeExpression = Condition(Equal(jObjectVariable, Constant(null)), Default(complexProperty.ClrType.MakeNullable()), - ConvertChecked(materializeExpression, complexProperty.ClrType.MakeNullable()) - ); + ConvertChecked(materializeExpression, complexProperty.ClrType.MakeNullable())); } return Block( @@ -247,15 +243,13 @@ private BlockExpression CreateComplexCollectionAssignmentBlock(MemberExpression Call( CosmosProjectionBindingRemovingExpressionVisitorBase.PopulateCollectionMethodInfo.MakeGenericMethod(complexProperty.ComplexType.ClrType, complexProperty.ClrType), Constant(complexProperty.GetCollectionAccessor()), - select - ); + select); if (complexProperty.IsNullable) { populateExpression = Condition(Equal(complexJArrayVariable, Constant(null)), Default(complexProperty.ClrType.MakeNullable()), - ConvertChecked(populateExpression, complexProperty.ClrType.MakeNullable()) - ); + ConvertChecked(populateExpression, complexProperty.ClrType.MakeNullable())); } return Block( @@ -284,8 +278,7 @@ private Expression CreateComplexTypeMaterializeExpression(IComplexProperty compl { materializeExpression = Condition(Equal(jObjectParameter, Constant(null)), Default(complexProperty.ComplexType.ClrType), - materializeExpression - ); + materializeExpression); } return materializeExpression; From 83ea558857a7ec4395a4ba6a18f1c8eead1c698c Mon Sep 17 00:00:00 2001 From: JoasE <32096708+JoasE@users.noreply.github.com> Date: Tue, 27 Jan 2026 20:27:37 +0100 Subject: [PATCH 20/23] Remove unused NoSyncTest --- .../ComplexProperties/ComplexPropertiesCosmosFixture.cs | 6 ------ 1 file changed, 6 deletions(-) diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/Associations/ComplexProperties/ComplexPropertiesCosmosFixture.cs b/test/EFCore.Cosmos.FunctionalTests/Query/Associations/ComplexProperties/ComplexPropertiesCosmosFixture.cs index dbfbd64ea0a..bedf8b77cf2 100644 --- a/test/EFCore.Cosmos.FunctionalTests/Query/Associations/ComplexProperties/ComplexPropertiesCosmosFixture.cs +++ b/test/EFCore.Cosmos.FunctionalTests/Query/Associations/ComplexProperties/ComplexPropertiesCosmosFixture.cs @@ -15,12 +15,6 @@ public override DbContextOptionsBuilder AddOptions(DbContextOptionsBuilder build => base.AddOptions(builder) .ConfigureWarnings(w => w.Ignore(CosmosEventId.NoPartitionKeyDefined).Ignore(CoreEventId.MappedEntityTypeIgnoredWarning)); - public Task NoSyncTest(bool async, Func testCode) - => CosmosTestHelpers.Instance.NoSyncTest(async, testCode); - - public void NoSyncTest(Action testCode) - => CosmosTestHelpers.Instance.NoSyncTest(testCode); - protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext context) { base.OnModelCreating(modelBuilder, context); From 46c1dfaba9c580272ae5136955b72b98153cb673 Mon Sep 17 00:00:00 2001 From: JoasE <32096708+JoasE@users.noreply.github.com> Date: Tue, 27 Jan 2026 20:42:40 +0100 Subject: [PATCH 21/23] Remove unneeded HasNoDiscriminator --- .../ComplexProperties/ComplexPropertiesCosmosFixture.cs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/Associations/ComplexProperties/ComplexPropertiesCosmosFixture.cs b/test/EFCore.Cosmos.FunctionalTests/Query/Associations/ComplexProperties/ComplexPropertiesCosmosFixture.cs index bedf8b77cf2..575c2eaecf1 100644 --- a/test/EFCore.Cosmos.FunctionalTests/Query/Associations/ComplexProperties/ComplexPropertiesCosmosFixture.cs +++ b/test/EFCore.Cosmos.FunctionalTests/Query/Associations/ComplexProperties/ComplexPropertiesCosmosFixture.cs @@ -22,11 +22,9 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con modelBuilder.Ignore(); modelBuilder.Entity() - .ToContainer("RootEntities") - .HasNoDiscriminator(); + .ToContainer("RootEntities"); modelBuilder.Entity() - .ToContainer("ValueRootEntities") - .HasNoDiscriminator(); + .ToContainer("ValueRootEntities"); } } From 302f2de572b577739f78e5c009503f63934a0550 Mon Sep 17 00:00:00 2001 From: JoasE <32096708+JoasE@users.noreply.github.com> Date: Tue, 27 Jan 2026 20:47:50 +0100 Subject: [PATCH 22/23] Remove skips for tracking queries --- .../ComplexPropertiesProjectionCosmosTest.cs | 94 ++----------------- 1 file changed, 6 insertions(+), 88 deletions(-) diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/Associations/ComplexProperties/ComplexPropertiesProjectionCosmosTest.cs b/test/EFCore.Cosmos.FunctionalTests/Query/Associations/ComplexProperties/ComplexPropertiesProjectionCosmosTest.cs index d9131e861f6..07a6105b37a 100644 --- a/test/EFCore.Cosmos.FunctionalTests/Query/Associations/ComplexProperties/ComplexPropertiesProjectionCosmosTest.cs +++ b/test/EFCore.Cosmos.FunctionalTests/Query/Associations/ComplexProperties/ComplexPropertiesProjectionCosmosTest.cs @@ -97,11 +97,6 @@ public override async Task Select_associate(QueryTrackingBehavior queryTrackingB { await base.Select_associate(queryTrackingBehavior); - if (queryTrackingBehavior is QueryTrackingBehavior.TrackAll) - { - throw SkipException.ForSkip("Complex type tracking not supported."); - } - AssertSql( """ SELECT VALUE c @@ -113,11 +108,6 @@ public override async Task Select_optional_associate(QueryTrackingBehavior query { await base.Select_optional_associate(queryTrackingBehavior); - if (queryTrackingBehavior is QueryTrackingBehavior.TrackAll) - { - throw SkipException.ForSkip("Complex type tracking not supported."); - } - AssertSql( """ SELECT VALUE c @@ -129,11 +119,6 @@ public override async Task Select_required_nested_on_required_associate(QueryTra { await base.Select_required_nested_on_required_associate(queryTrackingBehavior); - if (queryTrackingBehavior is QueryTrackingBehavior.TrackAll) - { - throw SkipException.ForSkip("Complex type tracking not supported."); - } - AssertSql( """ SELECT VALUE c @@ -145,11 +130,6 @@ public override async Task Select_optional_nested_on_required_associate(QueryTra { await base.Select_optional_nested_on_required_associate(queryTrackingBehavior); - if (queryTrackingBehavior is QueryTrackingBehavior.TrackAll) - { - throw SkipException.ForSkip("Complex type tracking not supported."); - } - AssertSql( """ SELECT VALUE c @@ -159,11 +139,6 @@ FROM root c public override async Task Select_required_nested_on_optional_associate(QueryTrackingBehavior queryTrackingBehavior) { - if (queryTrackingBehavior is QueryTrackingBehavior.TrackAll) - { - throw SkipException.ForSkip("Complex type tracking not supported."); - } - await base.Select_required_nested_on_optional_associate(queryTrackingBehavior); AssertSql( @@ -175,21 +150,13 @@ FROM root c public override async Task Select_optional_nested_on_optional_associate(QueryTrackingBehavior queryTrackingBehavior) { - if (queryTrackingBehavior is QueryTrackingBehavior.TrackAll) - { - throw SkipException.ForSkip("Complex type tracking not supported."); - } - await base.Select_optional_nested_on_optional_associate(queryTrackingBehavior); - if (queryTrackingBehavior is not QueryTrackingBehavior.TrackAll) - { - AssertSql( - """ + AssertSql( + """ SELECT VALUE c FROM root c """); - } } public override Task Select_required_associate_via_optional_navigation(QueryTrackingBehavior queryTrackingBehavior) @@ -200,11 +167,6 @@ public override async Task Select_unmapped_associate_scalar_property(QueryTracki { await base.Select_unmapped_associate_scalar_property(queryTrackingBehavior); - if (queryTrackingBehavior is QueryTrackingBehavior.TrackAll) - { - throw SkipException.ForSkip("Complex type tracking not supported."); - } - AssertSql( """ SELECT VALUE c @@ -232,11 +194,6 @@ public override async Task Select_associate_collection(QueryTrackingBehavior que { await base.Select_associate_collection(queryTrackingBehavior); - if (queryTrackingBehavior is QueryTrackingBehavior.TrackAll) - { - throw SkipException.ForSkip("Complex type tracking not supported."); - } - AssertSql( """ SELECT VALUE c @@ -247,11 +204,6 @@ ORDER BY c["Id"] public override async Task Select_nested_collection_on_required_associate(QueryTrackingBehavior queryTrackingBehavior) { - if (queryTrackingBehavior is QueryTrackingBehavior.TrackAll) - { - throw SkipException.ForSkip("Complex type tracking not supported."); - } - await base.Select_nested_collection_on_required_associate(queryTrackingBehavior); AssertSql( @@ -264,11 +216,6 @@ ORDER BY c["Id"] public override async Task Select_nested_collection_on_optional_associate(QueryTrackingBehavior queryTrackingBehavior) { - if (queryTrackingBehavior is QueryTrackingBehavior.TrackAll) - { - throw SkipException.ForSkip("Complex type tracking not supported."); - } - await base.Select_nested_collection_on_optional_associate(queryTrackingBehavior); AssertSql( @@ -282,11 +229,6 @@ ORDER BY c["Id"] [ConditionalTheory(Skip = "TODO: Query projection")] public override async Task SelectMany_associate_collection(QueryTrackingBehavior queryTrackingBehavior) { - if (queryTrackingBehavior is QueryTrackingBehavior.TrackAll) - { - throw SkipException.ForSkip("Complex type tracking not supported."); - } - await base.SelectMany_associate_collection(queryTrackingBehavior); AssertSql( @@ -300,11 +242,6 @@ JOIN a IN c["AssociateCollection"] [ConditionalTheory(Skip = "TODO: Query projection")] public override async Task SelectMany_nested_collection_on_required_associate(QueryTrackingBehavior queryTrackingBehavior) { - if (queryTrackingBehavior is QueryTrackingBehavior.TrackAll) - { - throw SkipException.ForSkip("Complex type tracking not supported."); - } - await base.SelectMany_nested_collection_on_required_associate(queryTrackingBehavior); AssertSql( @@ -318,11 +255,6 @@ JOIN n IN c["RequiredAssociate"]["NestedCollection"] [ConditionalTheory(Skip = "TODO: Query projection")] public override async Task SelectMany_nested_collection_on_optional_associate(QueryTrackingBehavior queryTrackingBehavior) { - if (queryTrackingBehavior is QueryTrackingBehavior.TrackAll) - { - throw SkipException.ForSkip("Complex type tracking not supported."); - } - await base.SelectMany_nested_collection_on_optional_associate(queryTrackingBehavior); AssertSql( @@ -352,25 +284,11 @@ FROM root c #region Subquery - public override async Task Select_subquery_required_related_FirstOrDefault(QueryTrackingBehavior queryTrackingBehavior) - { - if (queryTrackingBehavior is QueryTrackingBehavior.TrackAll) - { - throw SkipException.ForSkip("Complex type tracking not supported."); - } - - await AssertTranslationFailed(() => base.Select_subquery_required_related_FirstOrDefault(queryTrackingBehavior)); - } + public override Task Select_subquery_required_related_FirstOrDefault(QueryTrackingBehavior queryTrackingBehavior) + => AssertTranslationFailed(() => base.Select_subquery_required_related_FirstOrDefault(queryTrackingBehavior)); - public override async Task Select_subquery_optional_related_FirstOrDefault(QueryTrackingBehavior queryTrackingBehavior) - { - if (queryTrackingBehavior is QueryTrackingBehavior.TrackAll) - { - throw SkipException.ForSkip("Complex type tracking not supported."); - } - - await AssertTranslationFailed(() => base.Select_subquery_optional_related_FirstOrDefault(queryTrackingBehavior)); - } + public override Task Select_subquery_optional_related_FirstOrDefault(QueryTrackingBehavior queryTrackingBehavior) + => AssertTranslationFailed(() => base.Select_subquery_optional_related_FirstOrDefault(queryTrackingBehavior)); #endregion Subquery From 722a9c076a177403d464a612583b63a17c787ee6 Mon Sep 17 00:00:00 2001 From: JoasE <32096708+JoasE@users.noreply.github.com> Date: Tue, 27 Jan 2026 21:07:00 +0100 Subject: [PATCH 23/23] Fix check for NullSafeUpdate --- .../Internal/CosmosProjectionBindingExpressionVisitor.cs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosProjectionBindingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosProjectionBindingExpressionVisitor.cs index 936b034cbe1..8c53bec67ee 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosProjectionBindingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosProjectionBindingExpressionVisitor.cs @@ -363,11 +363,11 @@ Expression NullSafeUpdate(Expression? innerExpression) // Special case for when query is projecting 'nullable.Value' where 'nullable' is of type Nullable // In this case we return default(T) when 'nullable' is null + var member = memberExpression.Member; if (innerExpression.Type.IsNullableType() && !memberExpression.Type.IsNullableType() - && memberExpression.Expression is MemberExpression outerMember - && outerMember.Type.IsNullableValueType() - && memberExpression.Member.Name == nameof(Nullable<>.Value)) + && member is { Name: nameof(Nullable<>.Value), DeclaringType.IsGenericType: true } + && member.DeclaringType.GetGenericTypeDefinition() == typeof(Nullable<>)) { // Use HasValue property instead of equality comparison // to avoid issues with value types that don't define the == operator