From 27c673ddfc5fcf7bbc24bb23003a73fd7c41dde7 Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Mon, 5 Jan 2026 20:43:10 +0100 Subject: [PATCH] Clean up NpgsqlSqlTranslatingExpressionVisitor Especially around NativeAOT-friendly patterns. --- .../NpgsqlSqlTranslatingExpressionVisitor.cs | 467 +++++++++--------- 1 file changed, 231 insertions(+), 236 deletions(-) diff --git a/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs b/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs index 7211b966c..cae5c074c 100644 --- a/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs @@ -17,66 +17,29 @@ namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.Internal; /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class NpgsqlSqlTranslatingExpressionVisitor : RelationalSqlTranslatingExpressionVisitor +public class NpgsqlSqlTranslatingExpressionVisitor( + RelationalSqlTranslatingExpressionVisitorDependencies dependencies, + QueryCompilationContext queryCompilationContext, + QueryableMethodTranslatingExpressionVisitor queryableMethodTranslatingExpressionVisitor) + : RelationalSqlTranslatingExpressionVisitor(dependencies, queryCompilationContext, queryableMethodTranslatingExpressionVisitor) { - private readonly QueryCompilationContext _queryCompilationContext; - private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory; - private readonly IRelationalTypeMappingSource _typeMappingSource; - private readonly NpgsqlJsonPocoTranslator _jsonPocoTranslator; + private readonly QueryCompilationContext _queryCompilationContext = queryCompilationContext; + private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory = (NpgsqlSqlExpressionFactory)dependencies.SqlExpressionFactory; + private readonly IRelationalTypeMappingSource _typeMappingSource = dependencies.TypeMappingSource; + private readonly NpgsqlJsonPocoTranslator _jsonPocoTranslator + = ((NpgsqlMemberTranslatorProvider)dependencies.MemberTranslatorProvider).JsonPocoTranslator; - private readonly RelationalTypeMapping _timestampMapping; - private readonly RelationalTypeMapping _timestampTzMapping; + private readonly RelationalTypeMapping _timestampMapping = dependencies.TypeMappingSource.FindMapping("timestamp without time zone")!; + private readonly RelationalTypeMapping _timestampTzMapping = dependencies.TypeMappingSource.FindMapping("timestamp with time zone")!; private static Type? _nodaTimePeriodType; - private static readonly ConstructorInfo DateTimeCtor1 = - typeof(DateTime).GetConstructor([typeof(int), typeof(int), typeof(int)])!; - - private static readonly ConstructorInfo DateTimeCtor2 = - typeof(DateTime).GetConstructor([typeof(int), typeof(int), typeof(int), typeof(int), typeof(int), typeof(int)])!; - - private static readonly ConstructorInfo DateTimeCtor3 = - typeof(DateTime).GetConstructor( - [typeof(int), typeof(int), typeof(int), typeof(int), typeof(int), typeof(int), typeof(DateTimeKind)])!; - - private static readonly ConstructorInfo DateOnlyCtor = - typeof(DateOnly).GetConstructor([typeof(int), typeof(int), typeof(int)])!; - - private static readonly MethodInfo StringStartsWithMethod - = typeof(string).GetRuntimeMethod(nameof(string.StartsWith), [typeof(string)])!; - - private static readonly MethodInfo StringEndsWithMethod - = typeof(string).GetRuntimeMethod(nameof(string.EndsWith), [typeof(string)])!; - - private static readonly MethodInfo StringContainsMethod - = typeof(string).GetRuntimeMethod(nameof(string.Contains), [typeof(string)])!; - private static readonly MethodInfo EscapeLikePatternParameterMethod = typeof(NpgsqlSqlTranslatingExpressionVisitor).GetTypeInfo().GetDeclaredMethod(nameof(ConstructLikePatternParameter))!; // Note: This is the PostgreSQL default and does not need to be explicitly specified private const char LikeEscapeChar = '\\'; - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public NpgsqlSqlTranslatingExpressionVisitor( - RelationalSqlTranslatingExpressionVisitorDependencies dependencies, - QueryCompilationContext queryCompilationContext, - QueryableMethodTranslatingExpressionVisitor queryableMethodTranslatingExpressionVisitor) - : base(dependencies, queryCompilationContext, queryableMethodTranslatingExpressionVisitor) - { - _queryCompilationContext = queryCompilationContext; - _sqlExpressionFactory = (NpgsqlSqlExpressionFactory)dependencies.SqlExpressionFactory; - _jsonPocoTranslator = ((NpgsqlMemberTranslatorProvider)Dependencies.MemberTranslatorProvider).JsonPocoTranslator; - _typeMappingSource = dependencies.TypeMappingSource; - _timestampMapping = _typeMappingSource.FindMapping("timestamp without time zone")!; - _timestampTzMapping = _typeMappingSource.FindMapping("timestamp with time zone")!; - } - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -144,6 +107,51 @@ bool TryTranslateToNullIf(SqlExpression conditionalResult, [NotNullWhen(true)] o /// protected override Expression VisitUnary(UnaryExpression unaryExpression) { + switch (unaryExpression.NodeType) + { + // We have row value comparison methods such as EF.Functions.GreaterThan, which accept two ValueTuples/Tuples. + // Since they accept ITuple parameters, the arguments have a Convert node casting up from the concrete argument to ITuple; + // this node causes translation failure in RelationalSqlTranslatingExpressionVisitor, so unwrap here. + case ExpressionType.Convert + when unaryExpression.Type == typeof(ITuple) && unaryExpression.Operand.Type.IsAssignableTo(typeof(ITuple)): + return Visit(unaryExpression.Operand); + + // We map both IPAddress and NpgsqlInet to PG inet, and translate many methods accepting NpgsqlInet, so ignore casts from + // IPAddress to NpgsqlInet. + // On the PostgreSQL side, cidr is also implicitly convertible to inet, and at the ADO.NET level NpgsqlCidr has a similar + // implicit conversion operator to NpgsqlInet. So remove that cast as well. + case ExpressionType.Convert + when unaryExpression.Type == typeof(NpgsqlInet) + && (unaryExpression.Operand.Type == typeof(IPAddress) + || unaryExpression.Operand.Type == typeof(IPNetwork) +#pragma warning disable CS0618 // NpgsqlCidr is obsolete, replaced by .NET IPNetwork + || unaryExpression.Operand.Type == typeof(NpgsqlCidr)): +#pragma warning restore CS0618 + return Visit(unaryExpression.Operand); + } + + // We map both IPAddress and NpgsqlInet to PG inet, and translate many methods accepting NpgsqlInet, so ignore casts from + // IPAddress to NpgsqlInet. + // On the PostgreSQL side, cidr is also implicitly convertible to inet, and at the ADO.NET level NpgsqlCidr has a similar + // implicit conversion operator to NpgsqlInet. So remove that cast as well. Note that we do this before calling base.VisitUnary + // since otherwise the translation succeeds, but has an unneeded explicit cast to inet. + if (unaryExpression.NodeType is ExpressionType.Convert + && unaryExpression.Type == typeof(NpgsqlInet) + && (unaryExpression.Operand.Type == typeof(IPAddress) + || unaryExpression.Operand.Type == typeof(IPNetwork) +#pragma warning disable CS0618 // NpgsqlCidr is obsolete, replaced by .NET IPNetwork + || unaryExpression.Operand.Type == typeof(NpgsqlCidr))) +#pragma warning restore CS0618 + { + return Visit(unaryExpression.Operand); + } + + if (base.VisitUnary(unaryExpression) is var translation + && translation != QueryCompilationContext.NotTranslatedExpression) + { + return translation; + } + switch (unaryExpression.NodeType) { case ExpressionType.ArrayLength: @@ -166,37 +174,18 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression) } // Attempt to translate Length on a JSON POCO array - if (_jsonPocoTranslator.TranslateArrayLength(sqlOperand) is SqlExpression translation) + if (_jsonPocoTranslator.TranslateArrayLength(sqlOperand) is SqlExpression translation2) { - return translation; + return translation2; } // Note that Length over PG arrays (not within JSON) gets translated by QueryableMethodTranslatingEV, since arrays are // primitive collections break; - - // We have row value comparison methods such as EF.Functions.GreaterThan, which accept two ValueTuples/Tuples. - // Since they accept ITuple parameters, the arguments have a Convert node casting up from the concrete argument to ITuple; - // this node causes translation failure in RelationalSqlTranslatingExpressionVisitor, so unwrap here. - case ExpressionType.Convert - when unaryExpression.Type == typeof(ITuple) && unaryExpression.Operand.Type.IsAssignableTo(typeof(ITuple)): - return Visit(unaryExpression.Operand); - - // We map both IPAddress and NpgsqlInet to PG inet, and translate many methods accepting NpgsqlInet, so ignore casts from - // IPAddress to NpgsqlInet. - // On the PostgreSQL side, cidr is also implicitly convertible to inet, and at the ADO.NET level NpgsqlCidr has a similar - // implicit conversion operator to NpgsqlInet. So remove that cast as well. - case ExpressionType.Convert - when unaryExpression.Type == typeof(NpgsqlInet) - && (unaryExpression.Operand.Type == typeof(IPAddress) - || unaryExpression.Operand.Type == typeof(IPNetwork) -#pragma warning disable CS0618 // NpgsqlCidr is obsolete, replaced by .NET IPNetwork - || unaryExpression.Operand.Type == typeof(NpgsqlCidr)): -#pragma warning restore CS0618 - return Visit(unaryExpression.Operand); } - return base.VisitUnary(unaryExpression); + // return base.VisitUnary(unaryExpression); + return QueryCompilationContext.NotTranslatedExpression; } /// @@ -207,12 +196,13 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression) /// protected override Expression VisitNewArray(NewArrayExpression newArrayExpression) { - if (base.VisitNewArray(newArrayExpression) is SqlExpression visitedNewArrayExpression) + if (base.VisitNewArray(newArrayExpression) is var translation + && translation != QueryCompilationContext.NotTranslatedExpression) { - return visitedNewArrayExpression; + return translation; } - if (newArrayExpression.NodeType == ExpressionType.NewArrayInit) + if (newArrayExpression.NodeType is ExpressionType.NewArrayInit) { var visitedExpressions = new SqlExpression[newArrayExpression.Expressions.Count]; for (var i = 0; i < newArrayExpression.Expressions.Count; i++) @@ -241,62 +231,53 @@ protected override Expression VisitNewArray(NewArrayExpression newArrayExpressio /// protected override Expression VisitBinary(BinaryExpression binaryExpression) { - switch (binaryExpression.NodeType) + Debug.Assert( + binaryExpression.NodeType != ExpressionType.ArrayIndex, + "During preprocessing, ArrayIndex and List[] get normalized to ElementAt; see NpgsqlArrayTranslator. Should never see ArrayIndex."); + + // We pattern match date subtraction before calling base.VisitBinary() as that would produce the default a - b translation, but that + // yields the number of days as an integer; this is incompatible with the .NET side, where LocalDate - LocalDate = Period, and Period + // is mapped to PG interval. So we override subtraction to add make_interval. + if (binaryExpression.NodeType is ExpressionType.Subtract + && binaryExpression.Left.Type.UnwrapNullableType().FullName == "NodaTime.LocalDate" + && binaryExpression.Right.Type.UnwrapNullableType().FullName == "NodaTime.LocalDate") { - case ExpressionType.Subtract - when binaryExpression.Left.Type.UnwrapNullableType().FullName == "NodaTime.LocalDate" - && binaryExpression.Right.Type.UnwrapNullableType().FullName == "NodaTime.LocalDate": + if (TranslationFailed(binaryExpression.Left, Visit(TryRemoveImplicitConvert(binaryExpression.Left)), out var sqlLeft) + || TranslationFailed(binaryExpression.Right, Visit(TryRemoveImplicitConvert(binaryExpression.Right)), out var sqlRight)) { - if (TranslationFailed(binaryExpression.Left, Visit(TryRemoveImplicitConvert(binaryExpression.Left)), out var sqlLeft) - || TranslationFailed(binaryExpression.Right, Visit(TryRemoveImplicitConvert(binaryExpression.Right)), out var sqlRight)) - { - return QueryCompilationContext.NotTranslatedExpression; - } + return QueryCompilationContext.NotTranslatedExpression; + } - var subtraction = _sqlExpressionFactory.MakeBinary( - ExpressionType.Subtract, sqlLeft!, sqlRight!, _typeMappingSource.FindMapping(typeof(int)))!; + var subtraction = _sqlExpressionFactory.MakeBinary( + ExpressionType.Subtract, sqlLeft!, sqlRight!, _typeMappingSource.FindMapping(typeof(int)))!; - return PgFunctionExpression.CreateWithNamedArguments( - "make_interval", - [subtraction], - ["days"], - nullable: true, - argumentsPropagateNullability: TrueArrays[1], - builtIn: true, - _nodaTimePeriodType ??= binaryExpression.Left.Type.Assembly.GetType("NodaTime.Period")!, - typeMapping: null); - - // Note: many other date/time arithmetic operators are fully supported as-is by PostgreSQL - see NpgsqlSqlExpressionFactory - } + return PgFunctionExpression.CreateWithNamedArguments( + "make_interval", + [subtraction], + ["days"], + nullable: true, + argumentsPropagateNullability: TrueArrays[1], + builtIn: true, + _nodaTimePeriodType ??= binaryExpression.Left.Type.Assembly.GetType("NodaTime.Period")!, + typeMapping: null); - case ExpressionType.ArrayIndex: - { - // During preprocessing, ArrayIndex and List[] get normalized to ElementAt; see NpgsqlArrayTranslator - Check.DebugFail( - "During preprocessing, ArrayIndex and List[] get normalized to ElementAt; see NpgsqlArrayTranslator. " - + "Should never see ArrayIndex."); - break; - } + // Note: many other date/time arithmetic operators are fully supported as-is by PostgreSQL - see NpgsqlSqlExpressionFactory } - var translation = base.VisitBinary(binaryExpression); - - switch (translation) + return base.VisitBinary(binaryExpression) switch { // Optimize (x - c) - (y - c) to x - y. // This is particularly useful for DateOnly.DayNumber - DateOnly.DayNumber, which is the way to express DateOnly subtraction // (the subtraction operator isn't defined over DateOnly in .NET). The translation of x.DayNumber is x - DATE '0001-01-01', // so the below is a useful simplification. // TODO: As this is a generic mathematical simplification, we should move it to a generic optimization phase in EF Core. - case SqlBinaryExpression + SqlBinaryExpression { OperatorType: ExpressionType.Subtract, Left: SqlBinaryExpression { OperatorType: ExpressionType.Subtract, Left: var left1, Right: var right1 }, Right: SqlBinaryExpression { OperatorType: ExpressionType.Subtract, Left: var left2, Right: var right2 } - } originalBinary when right1.Equals(right2): - { - return new SqlBinaryExpression(ExpressionType.Subtract, left1, left2, originalBinary.Type, originalBinary.TypeMapping); - } + } originalBinary when right1.Equals(right2) + => new SqlBinaryExpression(ExpressionType.Subtract, left1, left2, originalBinary.Type, originalBinary.TypeMapping), // A somewhat hacky workaround for #2942. // When an optional owned JSON entity is compared to null, we get WHERE (x -> y) IS NULL. @@ -304,43 +285,41 @@ when binaryExpression.Left.Type.UnwrapNullableType().FullName == "NodaTime.Local // further JSON operations may need to be composed. However, when the value extracted is a JSON null, a non-NULL jsonb value is // returned, and comparing that to relational NULL returns false. // Pattern-match this and force the use of ->> by changing the mapping to be a scalar rather than an entity type. - case SqlBinaryExpression + SqlBinaryExpression { OperatorType: ExpressionType.Equal or ExpressionType.NotEqual, Left: JsonScalarExpression { TypeMapping: NpgsqlStructuralJsonTypeMapping } operand, Right: SqlConstantExpression { Value: null } - } binary: - { - return binary.Update( + } binary + => binary.Update( new JsonScalarExpression( operand.Json, operand.Path, operand.Type, _typeMappingSource.FindMapping("text"), operand.IsNullable), - binary.Right); - } - case SqlBinaryExpression + binary.Right), + + SqlBinaryExpression { OperatorType: ExpressionType.Equal or ExpressionType.NotEqual, Left: SqlConstantExpression { Value: null }, Right: JsonScalarExpression { TypeMapping: NpgsqlStructuralJsonTypeMapping } operand - } binary: - { - return binary.Update( + } binary + => binary.Update( binary.Left, new JsonScalarExpression( - operand.Json, operand.Path, operand.Type, _typeMappingSource.FindMapping("text"), operand.IsNullable)); - } + operand.Json, operand.Path, operand.Type, _typeMappingSource.FindMapping("text"), operand.IsNullable)), + // Unfortunately EF isn't consistent in its representation of X IS NULL in the SQL tree - sometimes it's a SqlUnaryExpression with Equals, // sometimes it's an X = NULL SqlBinaryExpression that later gets transformed to SqlUnaryExpression, in SqlNullabilityProcessor. We recognize // both of these here. - case SqlUnaryExpression + SqlUnaryExpression { Operand: JsonScalarExpression { TypeMapping: NpgsqlStructuralJsonTypeMapping } operand - } unary: - return unary.Update( + } unary + => unary.Update( new JsonScalarExpression( - operand.Json, operand.Path, operand.Type, _typeMappingSource.FindMapping("text"), operand.IsNullable)); - } + operand.Json, operand.Path, operand.Type, _typeMappingSource.FindMapping("text"), operand.IsNullable)), - return translation; + var translation => translation + }; } /// @@ -354,7 +333,9 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp var method = methodCallExpression.Method; // Pattern-match: cube.LowerLeft[index] or cube.UpperRight[index] - // This appears as: get_Item method call on a MemberExpression of LowerLeft/UpperRight + // This appears as: get_Item method call on a MemberExpression of LowerLeft/UpperRight. + // We match this before calling base.VisitMethodCall, because we throw an informative InvalidOperationException when + // we see LowerLeft/UpperRight, and only support get_Item over those. if (methodCallExpression is { Method.Name: "get_Item", @@ -390,43 +371,56 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp typeof(double)); } - // Pattern-match: cube.ToSubset(indexes) - if (method.Name == nameof(NpgsqlCube.ToSubset) - && method.DeclaringType == typeof(NpgsqlCube) - && methodCallExpression.Object is not null) + if (base.VisitMethodCall(methodCallExpression) is var translation + && translation != QueryCompilationContext.NotTranslatedExpression) { - // Translate cube instance and indexes array - if (Visit(methodCallExpression.Object) is not SqlExpression sqlCubeInstance - || Visit(methodCallExpression.Arguments[0]) is not SqlExpression sqlIndexes) - { - return QueryCompilationContext.NotTranslatedExpression; - } - - return TranslateCubeToSubset(sqlCubeInstance, sqlIndexes) ?? QueryCompilationContext.NotTranslatedExpression; + return translation; } - if (method == StringStartsWithMethod - && TryTranslateStartsEndsWithContains( - methodCallExpression.Object!, methodCallExpression.Arguments[0], StartsEndsWithContains.StartsWith, out var translation1)) - { - return translation1; - } + var declaringType = method.DeclaringType; + var @object = methodCallExpression.Object; + var arguments = methodCallExpression.Arguments; - if (method == StringEndsWithMethod - && TryTranslateStartsEndsWithContains( - methodCallExpression.Object!, methodCallExpression.Arguments[0], StartsEndsWithContains.EndsWith, out var translation2)) + switch (method.Name) { - return translation2; - } + // Pattern-match: cube.ToSubset(indexes) + case nameof(NpgsqlCube.ToSubset) + when declaringType == typeof(NpgsqlCube) + && @object is not null + && arguments is [Expression indexes]: + { + // Translate cube instance and indexes array + return Visit(@object) is SqlExpression sqlCubeInstance && Visit(indexes) is SqlExpression sqlIndexes + ? TranslateCubeToSubset(sqlCubeInstance, sqlIndexes) ?? QueryCompilationContext.NotTranslatedExpression + : QueryCompilationContext.NotTranslatedExpression; + } - if (method == StringContainsMethod - && TryTranslateStartsEndsWithContains( - methodCallExpression.Object!, methodCallExpression.Arguments[0], StartsEndsWithContains.Contains, out var translation3)) - { - return translation3; + // https://learn.microsoft.com/dotnet/api/system.string.startswith#system-string-startswith(system-string) + // https://learn.microsoft.com/dotnet/api/system.string.startswith#system-string-startswith(system-char) + // https://learn.microsoft.com/dotnet/api/system.string.endswith#system-string-endswith(system-string) + // https://learn.microsoft.com/dotnet/api/system.string.endswith#system-string-endswith(system-char) + // https://learn.microsoft.com/dotnet/api/system.string.contains#system-string-contains(system-string) + // https://learn.microsoft.com/dotnet/api/system.string.contains#system-string-contains(system-char) + case nameof(string.StartsWith) or nameof(string.EndsWith) or nameof(string.Contains) + when declaringType == typeof(string) + && @object is not null + && arguments is [Expression value] + && value.Type == typeof(string): + { + return TranslateStartsEndsWithContains( + @object, + value, + method.Name switch + { + nameof(string.StartsWith) => StartsEndsWithContains.StartsWith, + nameof(string.EndsWith) => StartsEndsWithContains.EndsWith, + nameof(string.Contains) => StartsEndsWithContains.Contains, + _ => throw new UnreachableException() + }); + } } - return base.VisitMethodCall(methodCallExpression); + return QueryCompilationContext.NotTranslatedExpression; } private SqlExpression? TranslateCubeToSubset(SqlExpression cubeExpression, SqlExpression indexesExpression) @@ -536,77 +530,93 @@ protected override Expression VisitNew(NewExpression newExpression) : QueryCompilationContext.NotTranslatedExpression; } + var constructor = newExpression.Constructor; + // Translate new DateTime(...) -> make_timestamp/make_date - if (newExpression.Constructor?.DeclaringType == typeof(DateTime)) + if (constructor?.DeclaringType == typeof(DateTime)) { - if (newExpression.Constructor == DateTimeCtor1) + switch (newExpression.Arguments) { - return TryTranslateArguments(out var sqlArguments) - ? _sqlExpressionFactory.Function( - "make_date", sqlArguments, nullable: true, TrueArrays[3], typeof(DateTime), _timestampMapping) - : QueryCompilationContext.NotTranslatedExpression; - } - - if (newExpression.Constructor == DateTimeCtor2) - { - if (!TryTranslateArguments(out var sqlArguments)) + // https://learn.microsoft.com/dotnet/api/system.datetime.-ctor#system-datetime-ctor(system-int32-system-int32-system-int32) + case [var year, var month, var day] + when year.Type == typeof(int) && month.Type == typeof(int) && day.Type == typeof(int): { - return QueryCompilationContext.NotTranslatedExpression; + return TryTranslateArguments(out var sqlArguments) + ? _sqlExpressionFactory.Function( + "make_date", sqlArguments, nullable: true, TrueArrays[3], typeof(DateTime), _timestampMapping) + : QueryCompilationContext.NotTranslatedExpression; } - // DateTime's second component is an int, but PostgreSQL's MAKE_TIMESTAMP accepts a double precision - sqlArguments[5] = _sqlExpressionFactory.Convert(sqlArguments[5], typeof(double)); + // https://learn.microsoft.com/dotnet/api/system.datetime.-ctor#system-datetime-ctor(system-int32-system-int32-system-int32-system-int32-system-int32-system-int32) + case [var year, var month, var day, var hour, var minute, var second] + when year.Type == typeof(int) && month.Type == typeof(int) && day.Type == typeof(int) + && hour.Type == typeof(int) && minute.Type == typeof(int) && second.Type == typeof(int): + { + if (!TryTranslateArguments(out var sqlArguments)) + { + return QueryCompilationContext.NotTranslatedExpression; + } - return _sqlExpressionFactory.Function( - "make_timestamp", sqlArguments, nullable: true, TrueArrays[6], typeof(DateTime), _timestampMapping); - } + // DateTime's second component is an int, but PostgreSQL's MAKE_TIMESTAMP accepts a double precision + sqlArguments[5] = _sqlExpressionFactory.Convert(sqlArguments[5], typeof(double)); - if (newExpression.Constructor == DateTimeCtor3 - && newExpression.Arguments[6] is ConstantExpression { Value : DateTimeKind kind }) - { - if (!TryTranslateArguments(out var sqlArguments)) - { - return QueryCompilationContext.NotTranslatedExpression; + return _sqlExpressionFactory.Function( + "make_timestamp", sqlArguments, nullable: true, TrueArrays[6], typeof(DateTime), _timestampMapping); } - // DateTime's second component is an int, but PostgreSQL's make_timestamp/make_timestamptz accepts a double precision. - // Also chop off the last Kind argument which does not get sent to PostgreSQL - var rewrittenArguments = new List + // https://learn.microsoft.com/dotnet/api/system.datetime.-ctor#system-datetime-ctor(system-int32-system-int32-system-int32-system-int32-system-int32-system-int32-system-datetimekind) + case [var year, var month, var day, var hour, var minute, var second, ConstantExpression { Value: DateTimeKind kind }] + when year.Type == typeof(int) && month.Type == typeof(int) && day.Type == typeof(int) + && hour.Type == typeof(int) && minute.Type == typeof(int) && second.Type == typeof(int): { - sqlArguments[0], - sqlArguments[1], - sqlArguments[2], - sqlArguments[3], - sqlArguments[4], - _sqlExpressionFactory.Convert(sqlArguments[5], typeof(double)) - }; + if (!TryTranslateArguments(out var sqlArguments)) + { + return QueryCompilationContext.NotTranslatedExpression; + } + + // DateTime's second component is an int, but PostgreSQL's make_timestamp/make_timestamptz accepts a double precision. + // Also chop off the last Kind argument which does not get sent to PostgreSQL + var rewrittenArguments = new List + { + sqlArguments[0], + sqlArguments[1], + sqlArguments[2], + sqlArguments[3], + sqlArguments[4], + _sqlExpressionFactory.Convert(sqlArguments[5], typeof(double)) + }; + + if (kind == DateTimeKind.Utc) + { + rewrittenArguments.Add(_sqlExpressionFactory.Constant("UTC")); + } - if (kind == DateTimeKind.Utc) - { - rewrittenArguments.Add(_sqlExpressionFactory.Constant("UTC")); + return _sqlExpressionFactory.Function( + kind == DateTimeKind.Utc ? "make_timestamptz" : "make_timestamp", + rewrittenArguments, + nullable: true, + TrueArrays[rewrittenArguments.Count], + typeof(DateTime), + kind == DateTimeKind.Utc ? _timestampTzMapping : _timestampMapping); } - - return _sqlExpressionFactory.Function( - kind == DateTimeKind.Utc ? "make_timestamptz" : "make_timestamp", - rewrittenArguments, - nullable: true, - TrueArrays[rewrittenArguments.Count], - typeof(DateTime), - kind == DateTimeKind.Utc ? _timestampTzMapping : _timestampMapping); } } // Translate new DateOnly(...) -> make_date - if (newExpression.Constructor == DateOnlyCtor) + if (constructor?.DeclaringType == typeof(DateOnly)) { - return TryTranslateArguments(out var sqlArguments) - ? _sqlExpressionFactory.Function( - "make_date", sqlArguments, nullable: true, TrueArrays[3], typeof(DateOnly)) - : QueryCompilationContext.NotTranslatedExpression; + if (newExpression.Arguments is [var year, var month, var day] + && year.Type == typeof(int) && month.Type == typeof(int) && day.Type == typeof(int)) + { + return TryTranslateArguments(out var sqlArguments) + ? _sqlExpressionFactory.Function( + "make_date", sqlArguments, nullable: true, TrueArrays[3], typeof(DateOnly)) + : QueryCompilationContext.NotTranslatedExpression; + } } // Translate new NpgsqlCube(...) -> cube(...) - if (newExpression.Constructor?.DeclaringType == typeof(NpgsqlCube)) + if (constructor?.DeclaringType == typeof(NpgsqlCube)) { if (!TryTranslateArguments(out var sqlArguments)) { @@ -614,7 +624,7 @@ protected override Expression VisitNew(NewExpression newExpression) } var cubeTypeMapping = _typeMappingSource.FindMapping(typeof(NpgsqlCube)); - var cubeParameters = newExpression.Constructor.GetParameters(); + var cubeParameters = constructor.GetParameters(); // Distinguish constructor overloads by parameter patterns switch (cubeParameters) @@ -671,17 +681,11 @@ bool TryTranslateArguments(out SqlExpression[] sqlArguments) #region StartsWith/EndsWith/Contains - private bool TryTranslateStartsEndsWithContains( - Expression instance, - Expression pattern, - StartsEndsWithContains methodType, - [NotNullWhen(true)] out SqlExpression? translation) + private Expression TranslateStartsEndsWithContains(Expression instance, Expression pattern, StartsEndsWithContains methodType) { - if (Visit(instance) is not SqlExpression translatedInstance - || Visit(pattern) is not SqlExpression translatedPattern) + if (Visit(instance) is not SqlExpression translatedInstance || Visit(pattern) is not SqlExpression translatedPattern) { - translation = null; - return false; + return QueryCompilationContext.NotTranslatedExpression;; } var stringTypeMapping = ExpressionExtensions.InferTypeMapping(translatedInstance, translatedPattern); @@ -695,7 +699,7 @@ private bool TryTranslateStartsEndsWithContains( { // The pattern is constant. Aside from null and empty string, we escape all special characters (%, _, \) and send a // simple LIKE - translation = patternConstant.Value switch + return patternConstant.Value switch { null => _sqlExpressionFactory.Like( translatedInstance, @@ -721,8 +725,6 @@ private bool TryTranslateStartsEndsWithContains( _ => throw new UnreachableException() }; - - return true; } case SqlParameterExpression patternParameter: @@ -741,11 +743,9 @@ private bool TryTranslateStartsEndsWithContains( _queryCompilationContext.RegisterRuntimeParameter( $"{patternParameter.Name}_{methodType.ToString().ToLower(CultureInfo.InvariantCulture)}", lambda); - translation = _sqlExpressionFactory.Like( + return _sqlExpressionFactory.Like( translatedInstance, new SqlParameterExpression(escapedPatternParameter.Name!, escapedPatternParameter.Type, stringTypeMapping)); - - return true; } default: @@ -757,7 +757,7 @@ private bool TryTranslateStartsEndsWithContains( // WHERE instance IS NOT NULL AND pattern IS NOT NULL AND LEFT(instance, LEN(pattern)) = pattern // This is less efficient than LIKE (i.e. StartsWith does an index scan instead of seek), but we have no choice. case StartsEndsWithContains.StartsWith or StartsEndsWithContains.EndsWith: - translation = + var translation = _sqlExpressionFactory.Function( methodType is StartsEndsWithContains.StartsWith ? "left" : "right", [ @@ -776,19 +776,17 @@ private bool TryTranslateStartsEndsWithContains( // We compensate for the case where both the instance and the pattern are null (null.StartsWith(null)); a simple // equality would yield true in that case, but we want false. - translation = + return _sqlExpressionFactory.AndAlso( _sqlExpressionFactory.IsNotNull(translatedInstance), _sqlExpressionFactory.AndAlso( _sqlExpressionFactory.IsNotNull(translatedPattern), _sqlExpressionFactory.Equal(translation, translatedPattern))); - break; - // For Contains, just use strpos and check if the result is greater than 0. Note that strpos returns 1 when the pattern // is an empty string, just like .NET Contains (so no need to compensate) case StartsEndsWithContains.Contains: - translation = + return _sqlExpressionFactory.AndAlso( _sqlExpressionFactory.IsNotNull(translatedInstance), _sqlExpressionFactory.AndAlso( @@ -798,13 +796,10 @@ private bool TryTranslateStartsEndsWithContains( "strpos", [translatedInstance, translatedPattern], nullable: true, argumentsPropagateNullability: [true, true], typeof(int)), _sqlExpressionFactory.Constant(0)))); - break; default: throw new UnreachableException(); } - - return true; } }