From 203e17ed95d3cf12ad535a802af8c0c819046923 Mon Sep 17 00:00:00 2001 From: Roman Artiukhin Date: Mon, 7 Feb 2022 18:14:25 +0200 Subject: [PATCH] Fix casting properties from object type in LINQ --- .../NHSpecificTest/GH3005/FixtureByCode.cs | 71 +++++++++++++++++++ .../NHSpecificTest/GH3005/Entity.cs | 11 +++ .../NHSpecificTest/GH3005/FixtureByCode.cs | 59 +++++++++++++++ .../Visitors/HqlGeneratorExpressionVisitor.cs | 26 ++++--- 4 files changed, 157 insertions(+), 10 deletions(-) create mode 100644 src/NHibernate.Test/Async/NHSpecificTest/GH3005/FixtureByCode.cs create mode 100644 src/NHibernate.Test/NHSpecificTest/GH3005/Entity.cs create mode 100644 src/NHibernate.Test/NHSpecificTest/GH3005/FixtureByCode.cs diff --git a/src/NHibernate.Test/Async/NHSpecificTest/GH3005/FixtureByCode.cs b/src/NHibernate.Test/Async/NHSpecificTest/GH3005/FixtureByCode.cs new file mode 100644 index 00000000000..dd268f6036b --- /dev/null +++ b/src/NHibernate.Test/Async/NHSpecificTest/GH3005/FixtureByCode.cs @@ -0,0 +1,71 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by AsyncGenerator. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + + +using System; +using System.Linq; +using NHibernate.Cfg.MappingSchema; +using NHibernate.Mapping.ByCode; +using NUnit.Framework; +using NHibernate.Linq; + +namespace NHibernate.Test.NHSpecificTest.GH3005 +{ + using System.Threading.Tasks; + [TestFixture] + public class ByCodeFixtureAsync : TestCaseMappingByCode + { + protected override HbmMapping GetMappings() + { + var mapper = new ModelMapper(); + mapper.Class(rc => + { + rc.Id(x => x.Id, m => m.Generator(Generators.GuidComb)); + rc.Property(x => x.Name); + rc.Property(x => x.Duration); + }); + + return mapper.CompileMappingForAllExplicitlyAddedEntities(); + } + + protected override void OnSetUp() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + var e1 = new Entity { Name = "Bob", Duration = TimeSpan.FromMinutes(1) }; + session.Save(e1); + + transaction.Commit(); + } + } + + protected override void OnTearDown() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + session.CreateQuery("delete from System.Object").ExecuteUpdate(); + + transaction.Commit(); + } + } + + [Test] + public async Task CanCastFromObjectAsync() + { + using (var session = OpenSession()) + { + var result = await (session.Query().Select(x => (TimeSpan)(object)x.Duration).FirstOrDefaultAsync()); + + Assert.That(result, Is.EqualTo(TimeSpan.FromMinutes(1))); + } + } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/GH3005/Entity.cs b/src/NHibernate.Test/NHSpecificTest/GH3005/Entity.cs new file mode 100644 index 00000000000..eddd37a5ff9 --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/GH3005/Entity.cs @@ -0,0 +1,11 @@ +using System; + +namespace NHibernate.Test.NHSpecificTest.GH3005 +{ + class Entity + { + public virtual Guid Id { get; set; } + public virtual string Name { get; set; } + public virtual TimeSpan Duration { get; set; } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/GH3005/FixtureByCode.cs b/src/NHibernate.Test/NHSpecificTest/GH3005/FixtureByCode.cs new file mode 100644 index 00000000000..cfb455d750e --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/GH3005/FixtureByCode.cs @@ -0,0 +1,59 @@ +using System; +using System.Linq; +using NHibernate.Cfg.MappingSchema; +using NHibernate.Mapping.ByCode; +using NUnit.Framework; + +namespace NHibernate.Test.NHSpecificTest.GH3005 +{ + [TestFixture] + public class ByCodeFixture : TestCaseMappingByCode + { + protected override HbmMapping GetMappings() + { + var mapper = new ModelMapper(); + mapper.Class(rc => + { + rc.Id(x => x.Id, m => m.Generator(Generators.GuidComb)); + rc.Property(x => x.Name); + rc.Property(x => x.Duration); + }); + + return mapper.CompileMappingForAllExplicitlyAddedEntities(); + } + + protected override void OnSetUp() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + var e1 = new Entity { Name = "Bob", Duration = TimeSpan.FromMinutes(1) }; + session.Save(e1); + + transaction.Commit(); + } + } + + protected override void OnTearDown() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + session.CreateQuery("delete from System.Object").ExecuteUpdate(); + + transaction.Commit(); + } + } + + [Test] + public void CanCastFromObject() + { + using (var session = OpenSession()) + { + var result = session.Query().Select(x => (TimeSpan)(object)x.Duration).FirstOrDefault(); + + Assert.That(result, Is.EqualTo(TimeSpan.FromMinutes(1))); + } + } + } +} diff --git a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs index 8366d42f6ff..2712e7a0f38 100644 --- a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs @@ -245,9 +245,11 @@ protected HqlTreeNode VisitNhAverage(NhAverageExpression expression) // otherwise the result may be incorrect. In SQL Server avg always returns int // when the argument is int. var hqlExpression = VisitExpression(expression.Expression).AsExpression(); - hqlExpression = IsCastRequired(expression.Expression, expression.Type, out _) - ? (HqlExpression) _hqlTreeBuilder.Cast(hqlExpression, expression.Type) - : _hqlTreeBuilder.TransparentCast(hqlExpression, expression.Type); + hqlExpression = IsCastRequired(expression.Expression, expression.Type, out var needTransparentCast) + ? _hqlTreeBuilder.Cast(hqlExpression, expression.Type) + : needTransparentCast + ? _hqlTreeBuilder.TransparentCast(hqlExpression, expression.Type) + : hqlExpression; // In Oracle the avg function can return a number with up to 40 digits which cannot be retrieved from the data reader due to the lack of such // numeric type in .NET. In order to avoid that we have to add a cast to trim the number so that it can be converted into a .NET numeric type. @@ -532,10 +534,10 @@ protected HqlTreeNode VisitUnaryExpression(UnaryExpression expression) castType = expression.Type; } - return IsCastRequired(expression.Operand, castType, out var existType) && castable + return IsCastRequired(expression.Operand, castType, out var needTransparentCast) && castable ? _hqlTreeBuilder.Cast(VisitExpression(expression.Operand).AsExpression(), castType) // Make a transparent cast when an IType exists, so that it can be used to retrieve the value from the data reader - : existType && HqlIdent.SupportsType(castType) + : needTransparentCast ? _hqlTreeBuilder.TransparentCast(VisitExpression(expression.Operand).AsExpression(), castType) : VisitExpression(expression.Operand); } @@ -643,12 +645,16 @@ protected HqlTreeNode VisitNewArrayExpression(NewArrayExpression expression) return _hqlTreeBuilder.ExpressionSubTreeHolder(expressionSubTree); } - private bool IsCastRequired(Expression expression, System.Type toType, out bool existType) + private bool IsCastRequired(Expression expression, System.Type toType, out bool needTransparentCast) { - existType = false; - return toType != typeof(object) && - expression.Type.UnwrapIfNullable() != toType.UnwrapIfNullable() && - IsCastRequired(ExpressionsHelper.GetType(_parameters, expression), TypeFactory.GetDefaultTypeFor(toType), out existType); + needTransparentCast = + toType != typeof(object) + && expression.Type != typeof(object) + && expression.Type != toType + && HqlIdent.SupportsType(toType) + && expression.Type.UnwrapIfNullable() != toType.UnwrapIfNullable(); + + return needTransparentCast && IsCastRequired(ExpressionsHelper.GetType(_parameters, expression), TypeFactory.GetDefaultTypeFor(toType), out needTransparentCast); } private bool IsCastRequired(IType type, IType toType, out bool existType)