diff --git a/src/NHibernate.DomainModel/Northwind/Entities/DynamicUser.cs b/src/NHibernate.DomainModel/Northwind/Entities/DynamicUser.cs new file mode 100644 index 00000000000..c8c9fb3cf03 --- /dev/null +++ b/src/NHibernate.DomainModel/Northwind/Entities/DynamicUser.cs @@ -0,0 +1,18 @@ +using System.Collections; + +namespace NHibernate.DomainModel.Northwind.Entities +{ + public class DynamicUser : IEnumerable + { + public virtual int Id { get; set; } + + public virtual dynamic Properties { get; set; } + + public virtual IDictionary Settings { get; set; } + + public virtual IEnumerator GetEnumerator() + { + throw new System.NotImplementedException(); + } + } +} diff --git a/src/NHibernate.DomainModel/Northwind/Entities/Northwind.cs b/src/NHibernate.DomainModel/Northwind/Entities/Northwind.cs index c4cbda23f26..4551ce0e9d8 100755 --- a/src/NHibernate.DomainModel/Northwind/Entities/Northwind.cs +++ b/src/NHibernate.DomainModel/Northwind/Entities/Northwind.cs @@ -69,6 +69,11 @@ public IQueryable Users get { return _session.Query(); } } + public IQueryable DynamicUsers + { + get { return _session.Query(); } + } + public IQueryable PatientRecords { get { return _session.Query(); } diff --git a/src/NHibernate.DomainModel/Northwind/Entities/User.cs b/src/NHibernate.DomainModel/Northwind/Entities/User.cs index c23e667be9b..14096dac912 100644 --- a/src/NHibernate.DomainModel/Northwind/Entities/User.cs +++ b/src/NHibernate.DomainModel/Northwind/Entities/User.cs @@ -48,10 +48,16 @@ public class User : IUser, IEntity public virtual FeatureSet Features { get; set; } + public virtual User NotMappedUser => this; + public virtual EnumStoredAsString Enum1 { get; set; } + public virtual EnumStoredAsString? NullableEnum1 { get; set; } + public virtual EnumStoredAsInt32 Enum2 { get; set; } + public virtual EnumStoredAsInt32? NullableEnum2 { get; set; } + public virtual IUser CreatedBy { get; set; } public virtual IUser ModifiedBy { get; set; } diff --git a/src/NHibernate.DomainModel/Northwind/Mappings/DynamicUser.hbm.xml b/src/NHibernate.DomainModel/Northwind/Mappings/DynamicUser.hbm.xml new file mode 100644 index 00000000000..1b6775b29c6 --- /dev/null +++ b/src/NHibernate.DomainModel/Northwind/Mappings/DynamicUser.hbm.xml @@ -0,0 +1,30 @@ + + + + + select * from Users + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/NHibernate.DomainModel/Northwind/Mappings/User.hbm.xml b/src/NHibernate.DomainModel/Northwind/Mappings/User.hbm.xml index 2764cb70898..f249de9574e 100644 --- a/src/NHibernate.DomainModel/Northwind/Mappings/User.hbm.xml +++ b/src/NHibernate.DomainModel/Northwind/Mappings/User.hbm.xml @@ -24,8 +24,14 @@ + + + + + diff --git a/src/NHibernate.Test/Async/Futures/LinqToFutureValueFixture.cs b/src/NHibernate.Test/Async/Futures/LinqToFutureValueFixture.cs index c56737d4feb..ed0439502b1 100644 --- a/src/NHibernate.Test/Async/Futures/LinqToFutureValueFixture.cs +++ b/src/NHibernate.Test/Async/Futures/LinqToFutureValueFixture.cs @@ -70,7 +70,7 @@ public void ToFutureValueWithSumOnEmptySetThrowsAsync() .Select(x => x.Id) .ToFutureValue(x => x.Sum()); - Assert.That(() => personsSum.GetValueAsync(), Throws.InnerException.TypeOf().Or.InnerException.TypeOf()); + Assert.That(() => personsSum.GetValueAsync(), Throws.TypeOf().Or.InnerException.TypeOf()); } } diff --git a/src/NHibernate.Test/Async/Linq/ConstantTest.cs b/src/NHibernate.Test/Async/Linq/ConstantTest.cs index 215aa6e8dbd..565c181ae62 100644 --- a/src/NHibernate.Test/Async/Linq/ConstantTest.cs +++ b/src/NHibernate.Test/Async/Linq/ConstantTest.cs @@ -299,7 +299,7 @@ public async Task DmlPlansAreCachedAsync() } [Test] - public async Task PlansWithNonParameterizedConstantsAreNotCachedAsync() + public async Task PlansWithNonParameterizedConstantsAreCachedAsync() { var queryPlanCacheType = typeof(QueryPlanCache); @@ -314,12 +314,12 @@ public async Task PlansWithNonParameterizedConstantsAreNotCachedAsync() select new { c.CustomerId, c.ContactName, Constant = 1 }).FirstAsync()); Assert.That( cache, - Has.Count.EqualTo(0), - "Query plan should not be cached."); + Has.Count.EqualTo(1), + "Query plan should be cached."); } [Test] - public async Task PlansWithNonParameterizedConstantsAreNotCachedForExpandedQueryAsync() + public async Task PlansWithNonParameterizedConstantsAreCachedForExpandedQueryAsync() { var queryPlanCacheType = typeof(QueryPlanCache); @@ -335,8 +335,8 @@ public async Task PlansWithNonParameterizedConstantsAreNotCachedForExpandedQuery Assert.That( cache, - Has.Count.EqualTo(0), - "Query plan should not be cached."); + Has.Count.EqualTo(2), // The second one is for the expanded expression that has two parameters + "Query plan should be cached."); } //GH-2298 - Different Update queries - same query cache plan diff --git a/src/NHibernate.Test/Async/Linq/EnumTests.cs b/src/NHibernate.Test/Async/Linq/EnumTests.cs index 622a806ed30..6e9355d294c 100644 --- a/src/NHibernate.Test/Async/Linq/EnumTests.cs +++ b/src/NHibernate.Test/Async/Linq/EnumTests.cs @@ -61,5 +61,42 @@ public async Task CanQueryOnEnumStoredAsString_Small_1Async() Assert.AreEqual(expectedCount, query.Count); } + + [Test] + public async Task ConditionalNavigationPropertyAsync() + { + EnumStoredAsString? type = null; + await (db.Users.Where(o => o.Enum1 == EnumStoredAsString.Large).ToListAsync()); + await (db.Users.Where(o => EnumStoredAsString.Large != o.Enum1).ToListAsync()); + await (db.Users.Where(o => (o.NullableEnum1 ?? EnumStoredAsString.Large) == EnumStoredAsString.Medium).ToListAsync()); + await (db.Users.Where(o => ((o.NullableEnum1 ?? type) ?? o.Enum1) == EnumStoredAsString.Medium).ToListAsync()); + + await (db.Users.Where(o => (o.NullableEnum1.HasValue ? o.Enum1 : EnumStoredAsString.Unspecified) == EnumStoredAsString.Medium).ToListAsync()); + await (db.Users.Where(o => (o.Enum1 != EnumStoredAsString.Large + ? (o.NullableEnum1.HasValue ? o.Enum1 : EnumStoredAsString.Unspecified) + : EnumStoredAsString.Small) == EnumStoredAsString.Medium).ToListAsync()); + + await (db.Users.Where(o => (o.Enum1 == EnumStoredAsString.Large ? o.Role : o.Role).Name == "test").ToListAsync()); + } + + [Test] + public async Task CanQueryComplexExpressionOnEnumStoredAsStringAsync() + { + var type = EnumStoredAsString.Unspecified; + var query = await ((from user in db.Users + where (user.NullableEnum1 == EnumStoredAsString.Large + ? EnumStoredAsString.Medium + : user.NullableEnum1 ?? user.Enum1 + ) == type + select new + { + user, + simple = user.Enum1, + condition = user.Enum1 == EnumStoredAsString.Large ? EnumStoredAsString.Medium : user.Enum1, + coalesce = user.NullableEnum1 ?? EnumStoredAsString.Medium + }).ToListAsync()); + + Assert.That(query.Count, Is.EqualTo(0)); + } } } diff --git a/src/NHibernate.Test/Async/Linq/ParameterTests.cs b/src/NHibernate.Test/Async/Linq/ParameterTests.cs index 4fbebe3e78b..2fc9b92862c 100644 --- a/src/NHibernate.Test/Async/Linq/ParameterTests.cs +++ b/src/NHibernate.Test/Async/Linq/ParameterTests.cs @@ -88,6 +88,34 @@ public async Task UsingTwoEntityParametersAsync() 2)); } + [Test] + public async Task UsingEntityEnumerableParameterTwiceAsync() + { + if (!Dialect.SupportsSubSelects) + { + Assert.Ignore(); + } + + var enumerable = await (db.DynamicUsers.FirstAsync()); + await (AssertTotalParametersAsync( + db.DynamicUsers.Where(o => o == enumerable && o != enumerable), + 1)); + } + + [Test] + public async Task UsingEntityEnumerableListParameterTwiceAsync() + { + if (!Dialect.SupportsSubSelects) + { + Assert.Ignore(); + } + + var enumerable = new[] {await (db.DynamicUsers.FirstAsync())}; + await (AssertTotalParametersAsync( + db.DynamicUsers.Where(o => enumerable.Contains(o) && enumerable.Contains(o)), + 1)); + } + [Test] public async Task UsingValueTypeParameterTwiceAsync() { @@ -322,7 +350,7 @@ public async Task UsingTwoParametersInDMLDeleteAsync() { // In case of arrays linqParameterNumber and parameterNumber will be different Assert.That( - GetLinqExpression(query).ParameterValuesByName.Count, + GetLinqExpression(query).NamedParameters.Count, Is.EqualTo(linqParameterNumber ?? parameterNumber), "Linq expression has different number of parameters"); diff --git a/src/NHibernate.Test/Async/Linq/QueryPlanTests.cs b/src/NHibernate.Test/Async/Linq/QueryPlanTests.cs new file mode 100644 index 00000000000..dd3cf906c87 --- /dev/null +++ b/src/NHibernate.Test/Async/Linq/QueryPlanTests.cs @@ -0,0 +1,175 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by AsyncGenerator. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + + +using System.Linq; +using NHibernate.Dialect; +using NHibernate.Linq; +using NSubstitute; +using NSubstitute.Extensions; +using NUnit.Framework; + +namespace NHibernate.Test.Linq +{ + using System.Threading.Tasks; + [TestFixture] + public class QueryPlanTestsAsync : LinqTestCase + { + [Test] + public async Task SelectConstantShouldBeCachedAsync() + { + ClearQueryPlanCache(); + + var c1 = await (db.Customers.Select(o => new {o.CustomerId, Constant = "constant"}).FirstAsync()); + var c2 = await (db.Customers.Select(o => new {o.CustomerId, Constant = "constant2"}).FirstAsync()); + var constant = "constant3"; + var c3 = await (db.Customers.Select(o => new {o.CustomerId, Constant = constant}).FirstAsync()); + constant = "constant4"; + var c4 = await (db.Customers.Select(o => new {o.CustomerId, Constant = constant}).FirstAsync()); + + var queryCache = GetQueryPlanCache(); + Assert.That(queryCache.Count, Is.EqualTo(1)); + + Assert.That(c1.Constant, Is.EqualTo("constant")); + Assert.That(c2.Constant, Is.EqualTo("constant2")); + Assert.That(c3.Constant, Is.EqualTo("constant3")); + Assert.That(c4.Constant, Is.EqualTo("constant4")); + } + + [Test] + public async Task GroupByConstantShouldBeCachedAsync() + { + ClearQueryPlanCache(); + + var c1 = await (db.Customers.GroupBy(o => new {o.CustomerId, Constant = "constant"}).Select(o => o.Key).FirstAsync()); + var c2 = await (db.Customers.GroupBy(o => new {o.CustomerId, Constant = "constant2"}).Select(o => o.Key).FirstAsync()); + var constant = "constant3"; + var c3 = await (db.Customers.GroupBy(o => new {o.CustomerId, Constant = constant}).Select(o => o.Key).FirstAsync()); + constant = "constant4"; + var c4 = await (db.Customers.GroupBy(o => new {o.CustomerId, Constant = constant}).Select(o => o.Key).FirstAsync()); + + var queryCache = GetQueryPlanCache(); + Assert.That(queryCache.Count, Is.EqualTo(1)); + + Assert.That(c1.Constant, Is.EqualTo("constant")); + Assert.That(c2.Constant, Is.EqualTo("constant2")); + Assert.That(c3.Constant, Is.EqualTo("constant3")); + Assert.That(c4.Constant, Is.EqualTo("constant4")); + } + + [Test] + public async Task WithLockShouldBeCachedAsync() + { + ClearQueryPlanCache(); + // Limit to a few dialects where we know the "nowait" keyword is used to make life easier. + Assume.That(Dialect is MsSql2000Dialect || Dialect is Oracle8iDialect || Dialect is PostgreSQL81Dialect); + + await (db.Customers.WithLock(LockMode.Upgrade).ToListAsync()); + await (db.Customers.WithLock(LockMode.UpgradeNoWait).ToListAsync()); + var lockMode = LockMode.None; + await (db.Customers.WithLock(lockMode).ToListAsync()); + lockMode = LockMode.Read; + await (db.Customers.WithLock(lockMode).ToListAsync()); + + var queryCache = GetQueryPlanCache(); + Assert.That(queryCache.Count, Is.EqualTo(4)); + } + + [TestCase(true)] + [TestCase(false)] + public async Task SkipShouldBeCachedAsync(bool supportsVariableLimit) + { + if (!Dialect.SupportsLimit || (supportsVariableLimit && !Dialect.SupportsVariableLimit)) + { + Assert.Ignore(); + } + + ClearQueryPlanCache(); + using (var substitute = SubstituteDialect()) + { + substitute.Value.Configure().SupportsVariableLimit.Returns(supportsVariableLimit); + + var c1 = await (db.Customers.Skip(1).ToListAsync()); + var c2 = await (db.Customers.Skip(2).ToListAsync()); + var skip = 3; + var c3 = await (db.Customers.Skip(skip).ToListAsync()); + skip = 4; + var c4 = await (db.Customers.Skip(skip).ToListAsync()); + + var queryCache = GetQueryPlanCache(); + Assert.That(c1.Count, Is.Not.EqualTo(c2.Count)); + Assert.That(c2.Count, Is.Not.EqualTo(c3.Count)); + Assert.That(c3.Count, Is.Not.EqualTo(c4.Count)); + Assert.That(queryCache.Count, Is.EqualTo(supportsVariableLimit ? 1 : 4)); + } + } + + [TestCase(true)] + [TestCase(false)] + public async Task TakeShouldBeCachedAsync(bool supportsVariableLimit) + { + if (!Dialect.SupportsLimit || (supportsVariableLimit && !Dialect.SupportsVariableLimit)) + { + Assert.Ignore(); + } + + ClearQueryPlanCache(); + using (var substitute = SubstituteDialect()) + { + substitute.Value.Configure().SupportsVariableLimit.Returns(supportsVariableLimit); + + var c1 = await (db.Customers.Take(1).ToListAsync()); + var c2 = await (db.Customers.Take(2).ToListAsync()); + var skip = 3; + var c3 = await (db.Customers.Take(skip).ToListAsync()); + skip = 4; + var c4 = await (db.Customers.Take(skip).ToListAsync()); + + var queryCache = GetQueryPlanCache(); + Assert.That(c1.Count, Is.EqualTo(1)); + Assert.That(c2.Count, Is.EqualTo(2)); + Assert.That(c3.Count, Is.EqualTo(3)); + Assert.That(c4.Count, Is.EqualTo(4)); + Assert.That(queryCache.Count, Is.EqualTo(supportsVariableLimit ? 1 : 4)); + } + } + + [Test] + public async Task TrimFunctionShouldNotBeCachedAsync() + { + ClearQueryPlanCache(); + + await (db.Customers.Select(o => new {CustomerId = o.CustomerId.Trim('-')}).FirstAsync()); + await (db.Customers.Select(o => new {CustomerId = o.CustomerId.Trim('+')}).FirstAsync()); + + var queryCache = GetQueryPlanCache(); + Assert.That(queryCache.Count, Is.EqualTo(0)); + } + + [Test] + public async Task SubstringFunctionShouldBeCachedAsync() + { + ClearQueryPlanCache(); + + var queryCache = GetQueryPlanCache(); + var c1 = await (db.Customers.Select(o => new {Name = o.ContactName.Substring(1)}).FirstAsync()); + var c2 = await (db.Customers.Select(o => new {Name = o.ContactName.Substring(2)}).FirstAsync()); + + Assert.That(c1.Name, Is.Not.EqualTo(c2.Name)); + Assert.That(queryCache.Count, Is.EqualTo(1)); + + ClearQueryPlanCache(); + c1 = await (db.Customers.Select(o => new { Name = o.ContactName.Substring(1, 2) }).FirstAsync()); + c2 = await (db.Customers.Select(o => new { Name = o.ContactName.Substring(2, 1) }).FirstAsync()); + + Assert.That(c1.Name, Is.Not.EqualTo(c2.Name)); + Assert.That(queryCache.Count, Is.EqualTo(1)); + } + } +} diff --git a/src/NHibernate.Test/Async/NHSpecificTest/NH3850/MainFixture.cs b/src/NHibernate.Test/Async/NHSpecificTest/NH3850/MainFixture.cs index ecbd2e774d1..166898951f5 100644 --- a/src/NHibernate.Test/Async/NHSpecificTest/NH3850/MainFixture.cs +++ b/src/NHibernate.Test/Async/NHSpecificTest/NH3850/MainFixture.cs @@ -930,7 +930,7 @@ public async Task LongCountObjectAsync() "Non nullable decimal max has failed"); var futureNonNullableDec = dcQuery.ToFutureValue(qdc => qdc.Max(dc => dc.NonNullableDecimal)); Assert.That(() => futureNonNullableDec.GetValueAsync(cancellationToken), - Throws.TargetInvocationException.And.InnerException.InstanceOf(), + Throws.InstanceOf(), "Future non nullable decimal max has failed"); } } @@ -1002,7 +1002,7 @@ public async Task LongCountObjectAsync() "Non nullable decimal min has failed"); var futureNonNullableDec = dcQuery.ToFutureValue(qdc => qdc.Min(dc => dc.NonNullableDecimal)); Assert.That(() => futureNonNullableDec.GetValueAsync(cancellationToken), - Throws.TargetInvocationException.And.InnerException.InstanceOf(), + Throws.InstanceOf(), "Future non nullable decimal min has failed"); } } @@ -1017,7 +1017,7 @@ public void SingleOrDefaultBBaseAsync() var query = session.Query(); Assert.That(() => query.SingleOrDefaultAsync(), Throws.InvalidOperationException); var futureQuery = query.ToFutureValue(qdc => qdc.SingleOrDefault()); - Assert.That(() => futureQuery.GetValueAsync(), Throws.TargetInvocationException.And.InnerException.TypeOf(), "Future"); + Assert.That(() => futureQuery.GetValueAsync(), Throws.InstanceOf(), "Future"); } } @@ -1050,7 +1050,7 @@ public void SingleOrDefaultCBaseAsync() var query = session.Query(); Assert.That(() => query.SingleOrDefaultAsync(), Throws.InvalidOperationException); var futureQuery = query.ToFutureValue(qdc => qdc.SingleOrDefault()); - Assert.That(() => futureQuery.GetValueAsync(), Throws.TargetInvocationException.And.InnerException.TypeOf(), "Future"); + Assert.That(() => futureQuery.GetValueAsync(), Throws.InstanceOf(), "Future"); } } @@ -1083,7 +1083,7 @@ public void SingleOrDefaultEAsync() var query = session.Query(); Assert.That(() => query.SingleOrDefaultAsync(), Throws.InvalidOperationException); var futureQuery = query.ToFutureValue(qdc => qdc.SingleOrDefault()); - Assert.That(() => futureQuery.GetValueAsync(), Throws.TargetInvocationException.And.InnerException.TypeOf(), "Future"); + Assert.That(() => futureQuery.GetValueAsync(), Throws.InstanceOf(), "Future"); } } @@ -1146,7 +1146,7 @@ public void SingleOrDefaultGBaseAsync() var query = session.Query(); Assert.That(() => query.SingleOrDefaultAsync(), Throws.InvalidOperationException); var futureQuery = query.ToFutureValue(qdc => qdc.SingleOrDefault()); - Assert.That(() => futureQuery.GetValueAsync(), Throws.TargetInvocationException.And.InnerException.TypeOf(), "Future"); + Assert.That(() => futureQuery.GetValueAsync(), Throws.InstanceOf(), "Future"); } } @@ -1159,7 +1159,7 @@ public void SingleOrDefaultGBaseWithNameAsync() var query = session.Query(); Assert.That(() => query.SingleOrDefaultAsync(dc => dc.Name == SearchName1), Throws.InvalidOperationException); var futureQuery = query.ToFutureValue(qdc => qdc.SingleOrDefault(dc => dc.Name == SearchName1)); - Assert.That(() => futureQuery.GetValueAsync(), Throws.TargetInvocationException.And.InnerException.TypeOf(), "Future"); + Assert.That(() => futureQuery.GetValueAsync(), Throws.InstanceOf(), "Future"); } } @@ -1172,7 +1172,7 @@ public void SingleOrDefaultObjectAsync() var query = session.Query(); Assert.That(() => query.SingleOrDefaultAsync(), Throws.InvalidOperationException); var futureQuery = query.ToFutureValue(qdc => qdc.SingleOrDefault()); - Assert.That(() => futureQuery.GetValueAsync(), Throws.TargetInvocationException.And.InnerException.TypeOf(), "Future"); + Assert.That(() => futureQuery.GetValueAsync(), Throws.InstanceOf(), "Future"); } } @@ -1276,7 +1276,7 @@ public async Task SumObjectAsync() "Non nullable decimal sum has failed"); var futureNonNullableDec = dcQuery.ToFutureValue(qdc => qdc.Sum(dc => dc.NonNullableDecimal)); Assert.That(() => futureNonNullableDec.GetValueAsync(cancellationToken), - Throws.TargetInvocationException.And.InnerException.InstanceOf(), + Throws.InstanceOf(), "Future non nullable decimal sum has failed"); } } diff --git a/src/NHibernate.Test/Futures/LinqToFutureValueFixture.cs b/src/NHibernate.Test/Futures/LinqToFutureValueFixture.cs index 244dcea7274..19fc8c020bd 100644 --- a/src/NHibernate.Test/Futures/LinqToFutureValueFixture.cs +++ b/src/NHibernate.Test/Futures/LinqToFutureValueFixture.cs @@ -59,7 +59,7 @@ public void ToFutureValueWithSumOnEmptySetThrows() .Select(x => x.Id) .ToFutureValue(x => x.Sum()); - Assert.That(() => personsSum.Value, Throws.InnerException.TypeOf().Or.InnerException.TypeOf()); + Assert.That(() => personsSum.Value, Throws.TypeOf().Or.InnerException.TypeOf()); } } diff --git a/src/NHibernate.Test/Linq/ConstantTest.cs b/src/NHibernate.Test/Linq/ConstantTest.cs index 6b693ddbc4a..96fc00c864b 100644 --- a/src/NHibernate.Test/Linq/ConstantTest.cs +++ b/src/NHibernate.Test/Linq/ConstantTest.cs @@ -217,12 +217,12 @@ public void ConstantInWhereDoesNotCauseManyKeys() select c); var preTransformParameters = new PreTransformationParameters(QueryMode.Select, Sfi); var preTransformResult = NhRelinqQueryParser.PreTransform(q1.Expression, preTransformParameters); - var expression = ExpressionParameterVisitor.Visit(preTransformResult, out var parameters1); - var k1 = ExpressionKeyVisitor.Visit(expression, parameters1); + var parameters1 = ExpressionParameterVisitor.Visit(preTransformResult); + var k1 = ExpressionKeyVisitor.Visit(preTransformResult.Expression, parameters1, Sfi); var preTransformResult2 = NhRelinqQueryParser.PreTransform(q2.Expression, preTransformParameters); - var expression2 = ExpressionParameterVisitor.Visit(preTransformResult2, out var parameters2); - var k2 = ExpressionKeyVisitor.Visit(expression2, parameters2); + var parameters2 = ExpressionParameterVisitor.Visit(preTransformResult2); + var k2 = ExpressionKeyVisitor.Visit(preTransformResult2.Expression, parameters2, Sfi); Assert.That(parameters1, Has.Count.GreaterThan(0), "parameters1"); Assert.That(parameters2, Has.Count.GreaterThan(0), "parameters2"); @@ -324,7 +324,7 @@ public void DmlPlansAreCached() } [Test] - public void PlansWithNonParameterizedConstantsAreNotCached() + public void PlansWithNonParameterizedConstantsAreCached() { var queryPlanCacheType = typeof(QueryPlanCache); @@ -339,12 +339,12 @@ public void PlansWithNonParameterizedConstantsAreNotCached() select new { c.CustomerId, c.ContactName, Constant = 1 }).First(); Assert.That( cache, - Has.Count.EqualTo(0), - "Query plan should not be cached."); + Has.Count.EqualTo(1), + "Query plan should be cached."); } [Test] - public void PlansWithNonParameterizedConstantsAreNotCachedForExpandedQuery() + public void PlansWithNonParameterizedConstantsAreCachedForExpandedQuery() { var queryPlanCacheType = typeof(QueryPlanCache); @@ -360,8 +360,8 @@ public void PlansWithNonParameterizedConstantsAreNotCachedForExpandedQuery() Assert.That( cache, - Has.Count.EqualTo(0), - "Query plan should not be cached."); + Has.Count.EqualTo(2), // The second one is for the expanded expression that has two parameters + "Query plan should be cached."); } //GH-2298 - Different Update queries - same query cache plan diff --git a/src/NHibernate.Test/Linq/EnumTests.cs b/src/NHibernate.Test/Linq/EnumTests.cs index 4050c7ddb97..aeea060b51e 100644 --- a/src/NHibernate.Test/Linq/EnumTests.cs +++ b/src/NHibernate.Test/Linq/EnumTests.cs @@ -48,5 +48,42 @@ public void CanQueryOnEnumStoredAsString(EnumStoredAsString type, int expectedCo Assert.AreEqual(expectedCount, query.Count); } + + [Test] + public void ConditionalNavigationProperty() + { + EnumStoredAsString? type = null; + db.Users.Where(o => o.Enum1 == EnumStoredAsString.Large).ToList(); + db.Users.Where(o => EnumStoredAsString.Large != o.Enum1).ToList(); + db.Users.Where(o => (o.NullableEnum1 ?? EnumStoredAsString.Large) == EnumStoredAsString.Medium).ToList(); + db.Users.Where(o => ((o.NullableEnum1 ?? type) ?? o.Enum1) == EnumStoredAsString.Medium).ToList(); + + db.Users.Where(o => (o.NullableEnum1.HasValue ? o.Enum1 : EnumStoredAsString.Unspecified) == EnumStoredAsString.Medium).ToList(); + db.Users.Where(o => (o.Enum1 != EnumStoredAsString.Large + ? (o.NullableEnum1.HasValue ? o.Enum1 : EnumStoredAsString.Unspecified) + : EnumStoredAsString.Small) == EnumStoredAsString.Medium).ToList(); + + db.Users.Where(o => (o.Enum1 == EnumStoredAsString.Large ? o.Role : o.Role).Name == "test").ToList(); + } + + [Test] + public void CanQueryComplexExpressionOnEnumStoredAsString() + { + var type = EnumStoredAsString.Unspecified; + var query = (from user in db.Users + where (user.NullableEnum1 == EnumStoredAsString.Large + ? EnumStoredAsString.Medium + : user.NullableEnum1 ?? user.Enum1 + ) == type + select new + { + user, + simple = user.Enum1, + condition = user.Enum1 == EnumStoredAsString.Large ? EnumStoredAsString.Medium : user.Enum1, + coalesce = user.NullableEnum1 ?? EnumStoredAsString.Medium + }).ToList(); + + Assert.That(query.Count, Is.EqualTo(0)); + } } } diff --git a/src/NHibernate.Test/Linq/LinqTestCase.cs b/src/NHibernate.Test/Linq/LinqTestCase.cs index e047732d7ad..daf14b9cd18 100755 --- a/src/NHibernate.Test/Linq/LinqTestCase.cs +++ b/src/NHibernate.Test/Linq/LinqTestCase.cs @@ -34,7 +34,8 @@ protected override string[] Mappings "Northwind.Mappings.User.hbm.xml", "Northwind.Mappings.TimeSheet.hbm.xml", "Northwind.Mappings.Animal.hbm.xml", - "Northwind.Mappings.Patient.hbm.xml" + "Northwind.Mappings.Patient.hbm.xml", + "Northwind.Mappings.DynamicUser.hbm.xml" }; } } diff --git a/src/NHibernate.Test/Linq/ParameterTests.cs b/src/NHibernate.Test/Linq/ParameterTests.cs index 920fa565129..190036b2b2e 100644 --- a/src/NHibernate.Test/Linq/ParameterTests.cs +++ b/src/NHibernate.Test/Linq/ParameterTests.cs @@ -76,6 +76,34 @@ public void UsingTwoEntityParameters() 2); } + [Test] + public void UsingEntityEnumerableParameterTwice() + { + if (!Dialect.SupportsSubSelects) + { + Assert.Ignore(); + } + + var enumerable = db.DynamicUsers.First(); + AssertTotalParameters( + db.DynamicUsers.Where(o => o == enumerable && o != enumerable), + 1); + } + + [Test] + public void UsingEntityEnumerableListParameterTwice() + { + if (!Dialect.SupportsSubSelects) + { + Assert.Ignore(); + } + + var enumerable = new[] {db.DynamicUsers.First()}; + AssertTotalParameters( + db.DynamicUsers.Where(o => enumerable.Contains(o) && enumerable.Contains(o)), + 1); + } + [Test] public void UsingValueTypeParameterTwice() { @@ -383,7 +411,7 @@ private void AssertTotalParameters(IQueryable query, int parameterNumber, { // In case of arrays linqParameterNumber and parameterNumber will be different Assert.That( - GetLinqExpression(query).ParameterValuesByName.Count, + GetLinqExpression(query).NamedParameters.Count, Is.EqualTo(linqParameterNumber ?? parameterNumber), "Linq expression has different number of parameters"); diff --git a/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs b/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs new file mode 100644 index 00000000000..2cb87bc50b2 --- /dev/null +++ b/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs @@ -0,0 +1,444 @@ +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Linq; +using System.Linq.Dynamic.Core; +using System.Linq.Expressions; +using NHibernate.DomainModel.Northwind.Entities; +using NHibernate.Engine.Query; +using NHibernate.Linq; +using NHibernate.Linq.Visitors; +using NHibernate.Type; +using NUnit.Framework; +using Remotion.Linq.Clauses; + +namespace NHibernate.Test.Linq +{ + public class ParameterTypeLocatorTests : LinqTestCase + { + [Test] + public void AddIntegerTest() + { + AssertResults( + new Dictionary> + { + {"2.1", o => o is DoubleType}, + {"5", o => o is Int32Type}, + }, + db.Users.Where(o => o.Id + 5 > 2.1), + db.Users.Where(o => 2.1 < 5 + o.Id) + ); + } + + [Test] + public void AddDecimalTest() + { + AssertResults( + new Dictionary> + { + {"2.1", o => o is DecimalType}, + {"5.2", o => o is DecimalType}, + }, + db.Users.Where(o => o.Id + 5.2m > 2.1m), + db.Users.Where(o => 2.1m < 5.2m + o.Id) + ); + } + + [Test] + public void SubtractFloatTest() + { + AssertResults( + new Dictionary> + { + {"2.1", o => o is DoubleType}, + {"5.2", o => o is SingleType}, + }, + db.Users.Where(o => o.Id - 5.2f > 2.1), + db.Users.Where(o => 2.1 < 5.2f - o.Id) + ); + } + + [Test] + public void GreaterThanTest() + { + AssertResults( + new Dictionary> + { + {"2.1", o => o is Int32Type} + }, + db.Users.Where(o => o.Id > 2.1), + db.Users.Where(o => 2.1 > o.Id) + ); + } + + [Test] + public void EqualStringEnumTest() + { + AssertResults( + new Dictionary> + { + {"3", o => o is EnumStoredAsStringType} + }, + db.Users.Where(o => o.Enum1 == EnumStoredAsString.Large), + db.Users.Where(o => EnumStoredAsString.Large == o.Enum1) + ); + } + + [Test] + public void EqualStringTest() + { + AssertResults( + new Dictionary> + { + {"\"London\"", o => o is StringType stringType && stringType.SqlType.Length == 15} + }, + db.Orders.Where(o => o.ShippingAddress.City == "London"), + db.Orders.Where(o => "London" == o.ShippingAddress.City) + ); + } + + [Test] + public void EqualEntityTest() + { + var order = new Order(); + AssertResults( + new Dictionary> + { + { + $"value({typeof(Order).FullName})", + o => o is ManyToOneType manyToOne && manyToOne.Name == typeof(Order).FullName + } + }, + db.Orders.Where(o => o == order), + db.Orders.Where(o => order == o) + ); + } + + [Test] + public void DoubleEqualTest() + { + AssertResults( + new Dictionary> + { + {"3", o => o is EnumStoredAsStringType}, + {"1", o => o is PersistentEnumType} + }, + db.Users.Where(o => o.Enum1 == EnumStoredAsString.Large && o.Enum2 == EnumStoredAsInt32.High), + db.Users.Where(o => EnumStoredAsInt32.High == o.Enum2 && EnumStoredAsString.Large == o.Enum1) + ); + } + + [Test] + public void NotEqualTest() + { + AssertResults( + new Dictionary> + { + {"3", o => o is EnumStoredAsStringType} + }, + db.Users.Where(o => o.Enum1 != EnumStoredAsString.Large), + db.Users.Where(o => EnumStoredAsString.Large != o.Enum1) + ); + } + + [Test] + public void DoubleNotEqualTest() + { + AssertResults( + new Dictionary> + { + {"3", o => o is EnumStoredAsStringType}, + {"1", o => o is PersistentEnumType} + }, + db.Users.Where(o => o.Enum1 != EnumStoredAsString.Large || o.NullableEnum2 != EnumStoredAsInt32.High), + db.Users.Where(o => EnumStoredAsInt32.High != o.NullableEnum2 || o.Enum1 != EnumStoredAsString.Large) + ); + } + + [Test] + public void CoalesceTest() + { + AssertResults( + new Dictionary> + { + {"2", o => o is EnumStoredAsStringType}, + {"Large", o => o is EnumStoredAsStringType} + }, + db.Users.Where(o => (o.NullableEnum1 ?? EnumStoredAsString.Large) == EnumStoredAsString.Medium), + db.Users.Where(o => EnumStoredAsString.Medium == (o.NullableEnum1 ?? EnumStoredAsString.Large)) + ); + } + + [Test] + public void DoubleCoalesceTest() + { + AssertResults( + new Dictionary> + { + {"2", o => o is EnumStoredAsStringType}, + {"Large", o => o is EnumStoredAsStringType}, + }, + db.Users.Where(o => ((o.NullableEnum1 ?? (EnumStoredAsString?) EnumStoredAsString.Large) ?? o.Enum1) == EnumStoredAsString.Medium), + db.Users.Where(o => EnumStoredAsString.Medium == ((o.NullableEnum1 ?? (EnumStoredAsString?) EnumStoredAsString.Large) ?? o.Enum1)) + ); + } + + [Test] + public void ConditionalTest() + { + AssertResults( + new Dictionary> + { + {"2", o => o is EnumStoredAsStringType}, + {"Unspecified", o => o is EnumStoredAsStringType} + }, + db.Users.Where(o => (o.NullableEnum2.HasValue ? o.Enum1 : EnumStoredAsString.Unspecified) == EnumStoredAsString.Medium), + db.Users.Where(o => EnumStoredAsString.Medium == (o.NullableEnum2.HasValue ? EnumStoredAsString.Unspecified : o.Enum1)) + ); + } + + [Test] + public void DoubleConditionalTest() + { + AssertResults( + new Dictionary> + { + {"0", o => o is PersistentEnumType}, + {"2", o => o is EnumStoredAsStringType}, + {"Small", o => o is EnumStoredAsStringType}, + {"Unspecified", o => o is EnumStoredAsStringType} + }, + db.Users.Where(o => (o.Enum2 != EnumStoredAsInt32.Unspecified + ? (o.NullableEnum2.HasValue ? o.Enum1 : EnumStoredAsString.Unspecified) + : EnumStoredAsString.Small) == EnumStoredAsString.Medium), + db.Users.Where(o => EnumStoredAsString.Medium == (o.Enum2 != EnumStoredAsInt32.Unspecified + ? EnumStoredAsString.Small + : (o.NullableEnum2.HasValue ? EnumStoredAsString.Unspecified : o.Enum1))) + ); + } + + [Test] + public void CoalesceMemberTest() + { + AssertResults( + new Dictionary> + { + {"2", o => o is EnumStoredAsStringType} + }, + db.Users.Where(o => (o.NotMappedUser ?? o).Enum1 == EnumStoredAsString.Medium), + db.Users.Where(o => EnumStoredAsString.Medium == (o ?? o.NotMappedUser).Enum1) + ); + } + + [Test] + public void ConditionalMemberTest() + { + AssertResults( + new Dictionary> + { + {"2", o => o is EnumStoredAsStringType}, + {"\"test\"", o => o is AnsiStringType}, + }, + db.Users.Where(o => (o.Name == "test" ? o.NotMappedUser : o).Enum1 == EnumStoredAsString.Medium), + db.Users.Where(o => EnumStoredAsString.Medium == (o.Name == "test" ? o : o.NotMappedUser).Enum1) + ); + } + + [Test] + public void DynamicMemberTest() + { + AssertResults( + new Dictionary> + { + {"\"test\"", o => o is AnsiStringType}, + }, + db.DynamicUsers.Where("Properties.Name == @0", "test"), + db.DynamicUsers.Where("@0 == Properties.Name", "test") + ); + } + + [Test] + public void DynamicDictionaryMemberTest() + { + AssertResults( + new Dictionary> + { + {"\"test\"", o => o is AnsiStringType}, + }, +#pragma warning disable CS0252 + db.DynamicUsers.Where(o => o.Settings["Property1"] == "test"), +#pragma warning restore CS0252 +#pragma warning disable CS0253 + db.DynamicUsers.Where(o => "test" == o.Settings["Property1"]) +#pragma warning restore CS0253 + ); + } + + [Test] + public void AssignMemberTest() + { + AssertResult( + new Dictionary> + { + {"0", o => o is Int32Type}, + {"\"val\"", o => o is AnsiStringType}, + {"Large", o => o is EnumStoredAsStringType}, + }, + QueryMode.Insert, + db.Users.Where(o => o.InvalidLoginAttempts > 0), + o => new User {Name = "val", Enum1 = EnumStoredAsString.Large} + ); + } + + [Test] + public void AssignComponentMemberTest() + { + AssertResult( + new Dictionary> + { + {"0", o => o is Int32Type}, + {"\"prop1\"", o => o is AnsiStringType} + }, + QueryMode.Insert, + db.Users.Where(o => o.InvalidLoginAttempts > 0), + o => new User {Component = new UserComponent {Property1 = "prop1"}} + ); + } + + [Test] + public void AssignNestedComponentMemberTest() + { + AssertResult( + new Dictionary> + { + {"0", o => o is Int32Type}, + {"\"other\"", o => o is AnsiStringType} + }, + QueryMode.Insert, + db.Users.Where(o => o.InvalidLoginAttempts > 0), + o => new User + { + Component = new UserComponent {OtherComponent = new UserComponent2 {OtherProperty1 = "other"}} + } + ); + } + + [Test] + public void AnonymousAssignMemberTest() + { + AssertResult( + new Dictionary> + { + {"0", o => o is Int32Type}, + {"\"val\"", o => o is AnsiStringType}, + {"Large", o => o is EnumStoredAsStringType}, + }, + QueryMode.Insert, + db.Users.Where(o => o.InvalidLoginAttempts > 0), + o => new {Name = "val", Enum1 = EnumStoredAsString.Large} + ); + } + + [Test] + public void AnonymousAssignComponentMemberTest() + { + AssertResult( + new Dictionary> + { + {"0", o => o is Int32Type}, + {"\"prop1\"", o => o is AnsiStringType} + }, + QueryMode.Insert, + db.Users.Where(o => o.InvalidLoginAttempts > 0), + o => new {Component = new {Property1 = "prop1"}} + ); + } + + [Test] + public void AnonymousAssignNestedComponentMemberTest() + { + AssertResult( + new Dictionary> + { + {"0", o => o is Int32Type}, + {"\"other\"", o => o is AnsiStringType} + }, + QueryMode.Insert, + db.Users.Where(o => o.InvalidLoginAttempts > 0), + o => new {Component = new {OtherComponent = new {OtherProperty1 = "other"}}} + ); + } + + private void AssertResults( + Dictionary> expectedResults, + params IQueryable[] queries) + { + foreach (var query in queries) + { + AssertResult(expectedResults, query); + } + } + + private void AssertResult( + Dictionary> expectedResults, + IQueryable query) + { + AssertResult(expectedResults, QueryMode.Select, query.Expression, query.Expression.Type); + } + + private void AssertResult( + Dictionary> expectedResults, + QueryMode queryMode, + IQueryable query, + Expression> expression) + { + var dmlExpression = expression != null + ? DmlExpressionRewriter.PrepareExpression(query.Expression, expression) + : query.Expression; + + AssertResult(expectedResults, queryMode, dmlExpression, typeof(T)); + } + + private void AssertResult( + Dictionary> expectedResults, + QueryMode queryMode, + IQueryable query, + Expression> expression) + { + var dmlExpression = expression != null + ? DmlExpressionRewriter.PrepareExpressionFromAnonymous(query.Expression, expression) + : query.Expression; + + AssertResult(expectedResults, queryMode, dmlExpression, typeof(T)); + } + + private void AssertResult( + Dictionary> expectedResults, + QueryMode queryMode, + Expression expression, + System.Type targetType) + { + var result = NhRelinqQueryParser.PreTransform(expression, new PreTransformationParameters(queryMode, Sfi)); + var parameters = ExpressionParameterVisitor.Visit(result); + expression = result.Expression; + var queryModel = NhRelinqQueryParser.Parse(expression); + ParameterTypeLocator.SetParameterTypes(parameters, queryModel, targetType, Sfi); + Assert.That(parameters.Count, Is.EqualTo(expectedResults.Count), "Incorrect number of parameters"); + foreach (var pair in parameters) + { + var origCulture = CultureInfo.CurrentCulture; + try + { + CultureInfo.CurrentCulture = CultureInfo.InvariantCulture; + var expressionText = pair.Key.ToString(); + Assert.That(expectedResults.ContainsKey(expressionText), Is.True, $"{expressionText} constant is not expected"); + Assert.That(expectedResults[expressionText](pair.Value.Type), Is.True, $"Invalid type, actual type: {pair.Value?.Type?.Name ?? "null"}"); + } + finally + { + CultureInfo.CurrentCulture = origCulture; + } + } + } + } +} diff --git a/src/NHibernate.Test/Linq/ParameterisedQueries.cs b/src/NHibernate.Test/Linq/ParameterisedQueries.cs index 1d243c619f2..1bafa452ca1 100644 --- a/src/NHibernate.Test/Linq/ParameterisedQueries.cs +++ b/src/NHibernate.Test/Linq/ParameterisedQueries.cs @@ -49,10 +49,10 @@ public void Expressions_Differing_Only_By_Constants_Return_The_Same_Key() var nhNewYork = new NhLinqExpression(newYork.Body, Sfi); Assert.AreEqual(nhLondon.Key, nhNewYork.Key); - Assert.AreEqual(1, nhLondon.ParameterValuesByName.Count); - Assert.AreEqual(1, nhNewYork.ParameterValuesByName.Count); - Assert.AreEqual("London", nhLondon.ParameterValuesByName.First().Value.Item1); - Assert.AreEqual("New York", nhNewYork.ParameterValuesByName.First().Value.Item1); + Assert.AreEqual(1, nhLondon.NamedParameters.Count); + Assert.AreEqual(1, nhNewYork.NamedParameters.Count); + Assert.AreEqual("London", nhLondon.NamedParameters.First().Value.Value); + Assert.AreEqual("New York", nhNewYork.NamedParameters.First().Value.Value); } } @@ -72,13 +72,13 @@ public void CanSpecifyParameterTypeInfo() var nhLondon = new NhLinqExpression(london.Body, Sfi); var nhNewYork = new NhLinqExpression(newYork.Body, Sfi); - var londonParameter = nhLondon.ParameterValuesByName.Single().Value; - Assert.That(londonParameter.Item1, Is.EqualTo("London")); - Assert.That(londonParameter.Item2, Is.EqualTo(NHibernateUtil.StringClob)); + var londonParameter = nhLondon.NamedParameters.Single().Value; + Assert.That(londonParameter.Value, Is.EqualTo("London")); + Assert.That(londonParameter.Type, Is.EqualTo(NHibernateUtil.StringClob)); - var newYorkParameter = nhNewYork.ParameterValuesByName.Single().Value; - Assert.That(newYorkParameter.Item1, Is.EqualTo("New York")); - Assert.That(newYorkParameter.Item2, Is.EqualTo(NHibernateUtil.AnsiString)); + var newYorkParameter = nhNewYork.NamedParameters.Single().Value; + Assert.That(newYorkParameter.Value, Is.EqualTo("New York")); + Assert.That(newYorkParameter.Type, Is.EqualTo(NHibernateUtil.AnsiString)); } } @@ -239,4 +239,4 @@ protected override string[] Mappings get { return Array.Empty(); } } } -} \ No newline at end of file +} diff --git a/src/NHibernate.Test/Linq/QueryPlanTests.cs b/src/NHibernate.Test/Linq/QueryPlanTests.cs new file mode 100644 index 00000000000..3eb8fa0eb48 --- /dev/null +++ b/src/NHibernate.Test/Linq/QueryPlanTests.cs @@ -0,0 +1,164 @@ +using System.Linq; +using NHibernate.Dialect; +using NHibernate.Linq; +using NSubstitute; +using NSubstitute.Extensions; +using NUnit.Framework; + +namespace NHibernate.Test.Linq +{ + [TestFixture] + public class QueryPlanTests : LinqTestCase + { + [Test] + public void SelectConstantShouldBeCached() + { + ClearQueryPlanCache(); + + var c1 = db.Customers.Select(o => new {o.CustomerId, Constant = "constant"}).First(); + var c2 = db.Customers.Select(o => new {o.CustomerId, Constant = "constant2"}).First(); + var constant = "constant3"; + var c3 = db.Customers.Select(o => new {o.CustomerId, Constant = constant}).First(); + constant = "constant4"; + var c4 = db.Customers.Select(o => new {o.CustomerId, Constant = constant}).First(); + + var queryCache = GetQueryPlanCache(); + Assert.That(queryCache.Count, Is.EqualTo(1)); + + Assert.That(c1.Constant, Is.EqualTo("constant")); + Assert.That(c2.Constant, Is.EqualTo("constant2")); + Assert.That(c3.Constant, Is.EqualTo("constant3")); + Assert.That(c4.Constant, Is.EqualTo("constant4")); + } + + [Test] + public void GroupByConstantShouldBeCached() + { + ClearQueryPlanCache(); + + var c1 = db.Customers.GroupBy(o => new {o.CustomerId, Constant = "constant"}).Select(o => o.Key).First(); + var c2 = db.Customers.GroupBy(o => new {o.CustomerId, Constant = "constant2"}).Select(o => o.Key).First(); + var constant = "constant3"; + var c3 = db.Customers.GroupBy(o => new {o.CustomerId, Constant = constant}).Select(o => o.Key).First(); + constant = "constant4"; + var c4 = db.Customers.GroupBy(o => new {o.CustomerId, Constant = constant}).Select(o => o.Key).First(); + + var queryCache = GetQueryPlanCache(); + Assert.That(queryCache.Count, Is.EqualTo(1)); + + Assert.That(c1.Constant, Is.EqualTo("constant")); + Assert.That(c2.Constant, Is.EqualTo("constant2")); + Assert.That(c3.Constant, Is.EqualTo("constant3")); + Assert.That(c4.Constant, Is.EqualTo("constant4")); + } + + [Test] + public void WithLockShouldBeCached() + { + ClearQueryPlanCache(); + // Limit to a few dialects where we know the "nowait" keyword is used to make life easier. + Assume.That(Dialect is MsSql2000Dialect || Dialect is Oracle8iDialect || Dialect is PostgreSQL81Dialect); + + db.Customers.WithLock(LockMode.Upgrade).ToList(); + db.Customers.WithLock(LockMode.UpgradeNoWait).ToList(); + var lockMode = LockMode.None; + db.Customers.WithLock(lockMode).ToList(); + lockMode = LockMode.Read; + db.Customers.WithLock(lockMode).ToList(); + + var queryCache = GetQueryPlanCache(); + Assert.That(queryCache.Count, Is.EqualTo(4)); + } + + [TestCase(true)] + [TestCase(false)] + public void SkipShouldBeCached(bool supportsVariableLimit) + { + if (!Dialect.SupportsLimit || (supportsVariableLimit && !Dialect.SupportsVariableLimit)) + { + Assert.Ignore(); + } + + ClearQueryPlanCache(); + using (var substitute = SubstituteDialect()) + { + substitute.Value.Configure().SupportsVariableLimit.Returns(supportsVariableLimit); + + var c1 = db.Customers.Skip(1).ToList(); + var c2 = db.Customers.Skip(2).ToList(); + var skip = 3; + var c3 = db.Customers.Skip(skip).ToList(); + skip = 4; + var c4 = db.Customers.Skip(skip).ToList(); + + var queryCache = GetQueryPlanCache(); + Assert.That(c1.Count, Is.Not.EqualTo(c2.Count)); + Assert.That(c2.Count, Is.Not.EqualTo(c3.Count)); + Assert.That(c3.Count, Is.Not.EqualTo(c4.Count)); + Assert.That(queryCache.Count, Is.EqualTo(supportsVariableLimit ? 1 : 4)); + } + } + + [TestCase(true)] + [TestCase(false)] + public void TakeShouldBeCached(bool supportsVariableLimit) + { + if (!Dialect.SupportsLimit || (supportsVariableLimit && !Dialect.SupportsVariableLimit)) + { + Assert.Ignore(); + } + + ClearQueryPlanCache(); + using (var substitute = SubstituteDialect()) + { + substitute.Value.Configure().SupportsVariableLimit.Returns(supportsVariableLimit); + + var c1 = db.Customers.Take(1).ToList(); + var c2 = db.Customers.Take(2).ToList(); + var skip = 3; + var c3 = db.Customers.Take(skip).ToList(); + skip = 4; + var c4 = db.Customers.Take(skip).ToList(); + + var queryCache = GetQueryPlanCache(); + Assert.That(c1.Count, Is.EqualTo(1)); + Assert.That(c2.Count, Is.EqualTo(2)); + Assert.That(c3.Count, Is.EqualTo(3)); + Assert.That(c4.Count, Is.EqualTo(4)); + Assert.That(queryCache.Count, Is.EqualTo(supportsVariableLimit ? 1 : 4)); + } + } + + [Test] + public void TrimFunctionShouldNotBeCached() + { + ClearQueryPlanCache(); + + db.Customers.Select(o => new {CustomerId = o.CustomerId.Trim('-')}).First(); + db.Customers.Select(o => new {CustomerId = o.CustomerId.Trim('+')}).First(); + + var queryCache = GetQueryPlanCache(); + Assert.That(queryCache.Count, Is.EqualTo(0)); + } + + [Test] + public void SubstringFunctionShouldBeCached() + { + ClearQueryPlanCache(); + + var queryCache = GetQueryPlanCache(); + var c1 = db.Customers.Select(o => new {Name = o.ContactName.Substring(1)}).First(); + var c2 = db.Customers.Select(o => new {Name = o.ContactName.Substring(2)}).First(); + + Assert.That(c1.Name, Is.Not.EqualTo(c2.Name)); + Assert.That(queryCache.Count, Is.EqualTo(1)); + + ClearQueryPlanCache(); + c1 = db.Customers.Select(o => new { Name = o.ContactName.Substring(1, 2) }).First(); + c2 = db.Customers.Select(o => new { Name = o.ContactName.Substring(2, 1) }).First(); + + Assert.That(c1.Name, Is.Not.EqualTo(c2.Name)); + Assert.That(queryCache.Count, Is.EqualTo(1)); + } + } +} diff --git a/src/NHibernate.Test/Linq/TryGetMappedTests.cs b/src/NHibernate.Test/Linq/TryGetMappedTests.cs index 20610d32bad..880ce2e3e69 100644 --- a/src/NHibernate.Test/Linq/TryGetMappedTests.cs +++ b/src/NHibernate.Test/Linq/TryGetMappedTests.cs @@ -774,7 +774,8 @@ private void AssertResult( var expression = query.Expression; var preTransformResult = NhRelinqQueryParser.PreTransform(expression, new PreTransformationParameters(QueryMode.Select, Sfi)); - expression = ExpressionParameterVisitor.Visit(preTransformResult, out var constantToParameterMap); + expression = preTransformResult.Expression; + var constantToParameterMap = ExpressionParameterVisitor.Visit(preTransformResult); var queryModel = NhRelinqQueryParser.Parse(expression); var requiredHqlParameters = new List(); var visitorParameters = new VisitorParameters( diff --git a/src/NHibernate.Test/NHSpecificTest/GH1526/Fixture.cs b/src/NHibernate.Test/NHSpecificTest/GH1526/Fixture.cs index 5318d771ec2..ee2b82c02f9 100644 --- a/src/NHibernate.Test/NHSpecificTest/GH1526/Fixture.cs +++ b/src/NHibernate.Test/NHSpecificTest/GH1526/Fixture.cs @@ -71,7 +71,7 @@ public void ShouldCreateDifferentKeys_TypeBinaryExpression() private static string GetCacheKey(Expression exp) { - return ExpressionKeyVisitor.Visit(exp, new Dictionary()); + return ExpressionKeyVisitor.Visit(exp, new Dictionary(), null); } } } diff --git a/src/NHibernate.Test/NHSpecificTest/NH3850/MainFixture.cs b/src/NHibernate.Test/NHSpecificTest/NH3850/MainFixture.cs index 448c83a63dc..6218fdfaf07 100644 --- a/src/NHibernate.Test/NHSpecificTest/NH3850/MainFixture.cs +++ b/src/NHibernate.Test/NHSpecificTest/NH3850/MainFixture.cs @@ -967,7 +967,7 @@ protected override void Max(int? expectedResult) "Non nullable decimal max has failed"); var futureNonNullableDec = dcQuery.ToFutureValue(qdc => qdc.Max(dc => dc.NonNullableDecimal)); Assert.That(() => futureNonNullableDec.Value, - Throws.TargetInvocationException.And.InnerException.InstanceOf(), + Throws.InstanceOf(), "Future non nullable decimal max has failed"); } } @@ -1039,7 +1039,7 @@ protected override void Min(int? expectedResult) "Non nullable decimal min has failed"); var futureNonNullableDec = dcQuery.ToFutureValue(qdc => qdc.Min(dc => dc.NonNullableDecimal)); Assert.That(() => futureNonNullableDec.Value, - Throws.TargetInvocationException.And.InnerException.InstanceOf(), + Throws.InstanceOf(), "Future non nullable decimal min has failed"); } } @@ -1054,7 +1054,7 @@ public void SingleOrDefaultBBase() var query = session.Query(); Assert.That(() => query.SingleOrDefault(), Throws.InvalidOperationException); var futureQuery = query.ToFutureValue(qdc => qdc.SingleOrDefault()); - Assert.That(() => futureQuery.Value, Throws.TargetInvocationException.And.InnerException.TypeOf(), "Future"); + Assert.That(() => futureQuery.Value, Throws.InstanceOf(), "Future"); } } @@ -1087,7 +1087,7 @@ public void SingleOrDefaultCBase() var query = session.Query(); Assert.That(() => query.SingleOrDefault(), Throws.InvalidOperationException); var futureQuery = query.ToFutureValue(qdc => qdc.SingleOrDefault()); - Assert.That(() => futureQuery.Value, Throws.TargetInvocationException.And.InnerException.TypeOf(), "Future"); + Assert.That(() => futureQuery.Value, Throws.InstanceOf(), "Future"); } } @@ -1120,7 +1120,7 @@ public void SingleOrDefaultE() var query = session.Query(); Assert.That(() => query.SingleOrDefault(), Throws.InvalidOperationException); var futureQuery = query.ToFutureValue(qdc => qdc.SingleOrDefault()); - Assert.That(() => futureQuery.Value, Throws.TargetInvocationException.And.InnerException.TypeOf(), "Future"); + Assert.That(() => futureQuery.Value, Throws.InstanceOf(), "Future"); } } @@ -1183,7 +1183,7 @@ public void SingleOrDefaultGBase() var query = session.Query(); Assert.That(() => query.SingleOrDefault(), Throws.InvalidOperationException); var futureQuery = query.ToFutureValue(qdc => qdc.SingleOrDefault()); - Assert.That(() => futureQuery.Value, Throws.TargetInvocationException.And.InnerException.TypeOf(), "Future"); + Assert.That(() => futureQuery.Value, Throws.InstanceOf(), "Future"); } } @@ -1196,7 +1196,7 @@ public void SingleOrDefaultGBaseWithName() var query = session.Query(); Assert.That(() => query.SingleOrDefault(dc => dc.Name == SearchName1), Throws.InvalidOperationException); var futureQuery = query.ToFutureValue(qdc => qdc.SingleOrDefault(dc => dc.Name == SearchName1)); - Assert.That(() => futureQuery.Value, Throws.TargetInvocationException.And.InnerException.TypeOf(), "Future"); + Assert.That(() => futureQuery.Value, Throws.InstanceOf(), "Future"); } } @@ -1209,7 +1209,7 @@ public void SingleOrDefaultObject() var query = session.Query(); Assert.That(() => query.SingleOrDefault(), Throws.InvalidOperationException); var futureQuery = query.ToFutureValue(qdc => qdc.SingleOrDefault()); - Assert.That(() => futureQuery.Value, Throws.TargetInvocationException.And.InnerException.TypeOf(), "Future"); + Assert.That(() => futureQuery.Value, Throws.InstanceOf(), "Future"); } } @@ -1313,7 +1313,7 @@ private void Sum(int? expectedResult) where DC : DomainClassBase "Non nullable decimal sum has failed"); var futureNonNullableDec = dcQuery.ToFutureValue(qdc => qdc.Sum(dc => dc.NonNullableDecimal)); Assert.That(() => futureNonNullableDec.Value, - Throws.TargetInvocationException.And.InnerException.InstanceOf(), + Throws.InstanceOf(), "Future non nullable decimal sum has failed"); } } diff --git a/src/NHibernate.Test/NHibernate.Test.csproj b/src/NHibernate.Test/NHibernate.Test.csproj index 8b57857cf85..30b630e90c9 100644 --- a/src/NHibernate.Test/NHibernate.Test.csproj +++ b/src/NHibernate.Test/NHibernate.Test.csproj @@ -56,7 +56,7 @@ - + diff --git a/src/NHibernate.Test/TestCase.cs b/src/NHibernate.Test/TestCase.cs index 14fc4ad1460..e44348a9fdb 100644 --- a/src/NHibernate.Test/TestCase.cs +++ b/src/NHibernate.Test/TestCase.cs @@ -29,6 +29,15 @@ public abstract class TestCase private SchemaExport _schemaExport; private static readonly ILog log = LogManager.GetLogger(typeof(TestCase)); + private static readonly FieldInfo PlanCacheField; + + static TestCase() + { + PlanCacheField = typeof(QueryPlanCache) + .GetField("planCache", BindingFlags.NonPublic | BindingFlags.Instance) + ?? throw new InvalidOperationException( + "planCache field does not exist in QueryPlanCache."); + } protected Dialect.Dialect Dialect { @@ -489,13 +498,14 @@ protected void AssumeFunctionSupported(string functionName) $"{dialect} doesn't support {functionName} standard function."); } - protected void ClearQueryPlanCache() + protected SoftLimitMRUCache GetQueryPlanCache() { - var planCacheField = typeof(QueryPlanCache) - .GetField("planCache", BindingFlags.NonPublic | BindingFlags.Instance) - ?? throw new InvalidOperationException("planCache field does not exist in QueryPlanCache."); + return (SoftLimitMRUCache) PlanCacheField.GetValue(Sfi.QueryPlanCache); + } - var planCache = (SoftLimitMRUCache) planCacheField.GetValue(Sfi.QueryPlanCache); + protected void ClearQueryPlanCache() + { + var planCache = GetQueryPlanCache(); planCache.Clear(); } diff --git a/src/NHibernate.Test/TransformTests/ImplementationOfEqualityTests.cs b/src/NHibernate.Test/TransformTests/ImplementationOfEqualityTests.cs index 149f1e05aea..2ca1f61400d 100644 --- a/src/NHibernate.Test/TransformTests/ImplementationOfEqualityTests.cs +++ b/src/NHibernate.Test/TransformTests/ImplementationOfEqualityTests.cs @@ -80,14 +80,14 @@ public void AliasToBeanConstructorResultTransformer_ShouldHaveEqualityBasedOnCto [Test] public void LinqResultTransformer_ShouldHaveEqualityBasedOnCtorParameter() { - Func d1 = x => new object(); - Func, IEnumerable> d2 = x => x; + Func d1 = (x, p) => new object(); + Func, object[], IEnumerable> d2 = (x, p) => x; var transformer1 = new ResultTransformer(d1, d2); var transformer2 = new ResultTransformer(d1, d2); Assert.That(transformer1, Is.EqualTo(transformer2)); Assert.That(transformer1.GetHashCode(), Is.EqualTo(transformer2.GetHashCode())); - Func, IEnumerable> d3 = x => new [] { 1, 2, 3 }; + Func, object[], IEnumerable> d3 = (x, p) => new [] { 1, 2, 3 }; var transformer3 = new ResultTransformer(d1, d3); Assert.That(transformer1, Is.Not.EqualTo(transformer3)); Assert.That(transformer1.GetHashCode(), Is.Not.EqualTo(transformer3.GetHashCode())); diff --git a/src/NHibernate/Async/Linq/DefaultQueryProvider.cs b/src/NHibernate/Async/Linq/DefaultQueryProvider.cs index 4d0344a27eb..a1a27899166 100644 --- a/src/NHibernate/Async/Linq/DefaultQueryProvider.cs +++ b/src/NHibernate/Async/Linq/DefaultQueryProvider.cs @@ -21,13 +21,10 @@ using NHibernate.Util; using System.Threading.Tasks; using NHibernate.Multi; +using NHibernate.Param; namespace NHibernate.Linq { - public partial interface INhQueryProvider : IQueryProvider - { - Task ExecuteDmlAsync(QueryMode queryMode, Expression expression, CancellationToken cancellationToken); - } public partial class DefaultQueryProvider : INhQueryProvider, IQueryProviderWithOptions, ISupportFutureBatchNhQueryProvider { @@ -37,7 +34,7 @@ public virtual async Task> ExecuteListAsync(Expression e { cancellationToken.ThrowIfCancellationRequested(); var linqExpression = PrepareQuery(expression, out var query); - var resultTransformer = linqExpression.ExpressionToHqlTranslationResults?.PostExecuteTransformer; + var resultTransformer = linqExpression.ExpressionToHqlTranslationResults?.PostResultTransformer; if (resultTransformer == null) { return await (query.ListAsync(cancellationToken)).ConfigureAwait(false); @@ -45,7 +42,7 @@ public virtual async Task> ExecuteListAsync(Expression e return new List { - (TResult) resultTransformer.DynamicInvoke((await (query.ListAsync(cancellationToken)).ConfigureAwait(false)).AsQueryable()) + (TResult) resultTransformer.Transform((await (query.ListAsync(cancellationToken)).ConfigureAwait(false)).AsQueryable()) }; } @@ -56,16 +53,9 @@ protected virtual async Task ExecuteQueryAsync(NhLinqExpression nhLinqEx cancellationToken.ThrowIfCancellationRequested(); IList results = await (query.ListAsync(cancellationToken)).ConfigureAwait(false); - if (nhQuery.ExpressionToHqlTranslationResults?.PostExecuteTransformer != null) + if (nhQuery.ExpressionToHqlTranslationResults?.PostResultTransformer != null) { - try - { - return nhQuery.ExpressionToHqlTranslationResults.PostExecuteTransformer.DynamicInvoke(results.AsQueryable()); - } - catch (TargetInvocationException e) - { - throw ReflectHelper.UnwrapTargetInvocationException(e); - } + return nhQuery.ExpressionToHqlTranslationResults.PostResultTransformer.Transform(results.AsQueryable()); } if (nhLinqExpression.ReturnType == NhLinqExpressionReturnType.Sequence) @@ -103,7 +93,8 @@ public Task ExecuteDmlAsync(QueryMode queryMode, Expression expression, var query = Session.CreateQuery(nhLinqExpression); - SetParameters(query, nhLinqExpression.ParameterValuesByName); + nhLinqExpression.Prepare(); + SetParameters(query, nhLinqExpression.NamedParameters); _options?.Apply(query); return query.ExecuteUpdateAsync(cancellationToken); } diff --git a/src/NHibernate/Async/Linq/INhQueryProvider.cs b/src/NHibernate/Async/Linq/INhQueryProvider.cs new file mode 100644 index 00000000000..d7b9d8a9b93 --- /dev/null +++ b/src/NHibernate/Async/Linq/INhQueryProvider.cs @@ -0,0 +1,27 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by AsyncGenerator. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Threading; +using System.Threading.Tasks; +using NHibernate.Param; +using NHibernate.Type; + +namespace NHibernate.Linq +{ + public partial interface INhQueryProvider : IQueryProvider + { + + Task ExecuteDmlAsync(QueryMode queryMode, Expression expression, CancellationToken cancellationToken); + } +} diff --git a/src/NHibernate/Driver/OdbcDriver.cs b/src/NHibernate/Driver/OdbcDriver.cs index cf8df041cea..5ac80336849 100644 --- a/src/NHibernate/Driver/OdbcDriver.cs +++ b/src/NHibernate/Driver/OdbcDriver.cs @@ -78,10 +78,18 @@ private void SetVariableLengthParameterSize(DbParameter dbParam, SqlType sqlType { switch (dbParam.DbType) { - case DbType.AnsiString: + case DbType.StringFixedLength: case DbType.AnsiStringFixedLength: + // For types that are using one character (CharType, AnsiCharType, TrueFalseType, YesNoType and EnumCharType), + // we have to specify the length otherwise sql function like charindex won't work as expected. + if (sqlType.Length == 1) + { + dbParam.Size = sqlType.Length; + } + + break; case DbType.String: - case DbType.StringFixedLength: + case DbType.AnsiString: // NH-4083: do not limit to column length if above 2000. Setting size may trigger conversion from // nvarchar to ntext when size is superior or equal to 2000, causing some queries to fail: // https://stackoverflow.com/q/8569844/1178314 diff --git a/src/NHibernate/Driver/SqlClientDriver.cs b/src/NHibernate/Driver/SqlClientDriver.cs index 682d755472a..1002005395d 100644 --- a/src/NHibernate/Driver/SqlClientDriver.cs +++ b/src/NHibernate/Driver/SqlClientDriver.cs @@ -161,7 +161,9 @@ protected override void InitializeParameter(DbParameter dbParam, string name, Sq { case DbType.AnsiString: case DbType.AnsiStringFixedLength: - dbParam.Size = IsAnsiText(dbParam, sqlType) ? MsSql2000Dialect.MaxSizeForAnsiClob : MsSql2000Dialect.MaxSizeForLengthLimitedAnsiString; + dbParam.Size = IsAnsiText(dbParam, sqlType) + ? MsSql2000Dialect.MaxSizeForAnsiClob + : IsChar(dbParam, sqlType) ? sqlType.Length : MsSql2000Dialect.MaxSizeForLengthLimitedAnsiString; break; case DbType.Binary: dbParam.Size = IsBlob(dbParam, sqlType) ? MsSql2000Dialect.MaxSizeForBlob : MsSql2000Dialect.MaxSizeForLengthLimitedBinary; @@ -174,7 +176,9 @@ protected override void InitializeParameter(DbParameter dbParam, string name, Sq break; case DbType.String: case DbType.StringFixedLength: - dbParam.Size = IsText(dbParam, sqlType) ? MsSql2000Dialect.MaxSizeForClob : MsSql2000Dialect.MaxSizeForLengthLimitedString; + dbParam.Size = IsText(dbParam, sqlType) + ? MsSql2000Dialect.MaxSizeForClob + : IsChar(dbParam, sqlType) ? sqlType.Length : MsSql2000Dialect.MaxSizeForLengthLimitedString; break; case DbType.DateTime2: dbParam.Size = MsSql2000Dialect.MaxDateTime2; @@ -283,6 +287,18 @@ protected static bool IsBlob(DbParameter dbParam, SqlType sqlType) return (sqlType is BinaryBlobSqlType) || ((DbType.Binary == dbParam.DbType) && sqlType.LengthDefined && (sqlType.Length > MsSql2000Dialect.MaxSizeForLengthLimitedBinary)); } + /// + /// Interprets if a parameter is a character (for the purposes of setting its default size) + /// + /// The parameter + /// The of the parameter + /// True, if the parameter should be interpreted as a character, otherwise False + protected static bool IsChar(DbParameter dbParam, SqlType sqlType) + { + return (DbType.StringFixedLength == dbParam.DbType || DbType.AnsiStringFixedLength == dbParam.DbType) && + sqlType.LengthDefined && sqlType.Length == 1; + } + public override IResultSetsCommand GetResultSetsCommand(ISessionImplementor session) { return new BasicResultSetsCommand(session); diff --git a/src/NHibernate/Driver/SqlServerCeDriver.cs b/src/NHibernate/Driver/SqlServerCeDriver.cs index eb4f03316ea..0b6a4ad93bc 100644 --- a/src/NHibernate/Driver/SqlServerCeDriver.cs +++ b/src/NHibernate/Driver/SqlServerCeDriver.cs @@ -75,6 +75,12 @@ public override IResultSetsCommand GetResultSetsCommand(Engine.ISessionImplement protected override void InitializeParameter(DbParameter dbParam, string name, SqlType sqlType) { base.InitializeParameter(dbParam, name, AdjustSqlType(sqlType)); + // For types that are using one character (CharType, AnsiCharType, TrueFalseType, YesNoType and EnumCharType), + // we have to specify the length otherwise sql function like charindex won't work as expected. + if (sqlType.LengthDefined && sqlType.Length == 1) + { + dbParam.Size = sqlType.Length; + } AdjustDbParamTypeForLargeObjects(dbParam, sqlType); } diff --git a/src/NHibernate/Engine/QueryParameters.cs b/src/NHibernate/Engine/QueryParameters.cs index 517e4770607..d6611da35fd 100644 --- a/src/NHibernate/Engine/QueryParameters.cs +++ b/src/NHibernate/Engine/QueryParameters.cs @@ -31,7 +31,7 @@ public QueryParameters(IType[] positionalParameterTypes, object[] postionalParam } public QueryParameters(IType[] positionalParameterTypes, object[] postionalParameterValues) - : this(positionalParameterTypes, postionalParameterValues, null, null, false, false, false, null, null, false, null) {} + : this(positionalParameterTypes, postionalParameterValues, null, null, null, false, false, false, null, null, null, null) {} public QueryParameters(IType[] positionalParameterTypes, object[] postionalParameterValues, object[] collectionKeys) : this(positionalParameterTypes, postionalParameterValues, null, collectionKeys) {} @@ -39,6 +39,8 @@ public QueryParameters(IType[] positionalParameterTypes, object[] postionalParam public QueryParameters(IType[] positionalParameterTypes, object[] postionalParameterValues, IDictionary namedParameters, object[] collectionKeys) : this(positionalParameterTypes, postionalParameterValues, namedParameters, null, null, false, false, false, null, null, collectionKeys, null) {} + // Since v5.3 + [Obsolete("This constructor has no more usage in NHibernate and will be removed in a future version.")] public QueryParameters(IType[] positionalParameterTypes, object[] positionalParameterValues, IDictionary lockModes, RowSelection rowSelection, bool isReadOnlyInitialized, bool readOnly, bool cacheable, string cacheRegion, string comment, bool isLookupByNaturalKey, IResultTransformer transformer) : this(positionalParameterTypes, positionalParameterValues, null, lockModes, rowSelection, isReadOnlyInitialized, readOnly, cacheable, cacheRegion, comment, null, transformer) diff --git a/src/NHibernate/Hql/Ast/ANTLR/ASTQueryTranslatorFactory.cs b/src/NHibernate/Hql/Ast/ANTLR/ASTQueryTranslatorFactory.cs index 7b9e937bd18..e7b95eed2cb 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/ASTQueryTranslatorFactory.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/ASTQueryTranslatorFactory.cs @@ -1,6 +1,7 @@ using System.Collections.Generic; using NHibernate.Engine; using NHibernate.Hql.Ast.ANTLR.Tree; +using NHibernate.Linq; using NHibernate.Util; namespace NHibernate.Hql.Ast.ANTLR @@ -16,15 +17,24 @@ public class ASTQueryTranslatorFactory : IQueryTranslatorFactory { public IQueryTranslator[] CreateQueryTranslators(IQueryExpression queryExpression, string collectionRole, bool shallow, IDictionary filters, ISessionFactoryImplementor factory) { - return CreateQueryTranslators(queryExpression.Translate(factory, collectionRole != null), queryExpression.Key, collectionRole, shallow, filters, factory); + return CreateQueryTranslators(queryExpression, queryExpression.Translate(factory, collectionRole != null), queryExpression.Key, collectionRole, shallow, filters, factory); } - static IQueryTranslator[] CreateQueryTranslators(IASTNode ast, string queryIdentifier, string collectionRole, bool shallow, IDictionary filters, ISessionFactoryImplementor factory) + static IQueryTranslator[] CreateQueryTranslators( + IQueryExpression queryExpression, + IASTNode ast, + string queryIdentifier, + string collectionRole, + bool shallow, + IDictionary filters, + ISessionFactoryImplementor factory) { var polymorphicParsers = AstPolymorphicProcessor.Process(ast, factory); var translators = polymorphicParsers - .ToArray(hql => new QueryTranslatorImpl(queryIdentifier, hql, filters, factory)); + .ToArray(hql => queryExpression is NhLinqExpression linqExpression + ? new QueryTranslatorImpl(queryIdentifier, hql, filters, factory, linqExpression.NamedParameters) + : new QueryTranslatorImpl(queryIdentifier, hql, filters, factory)); foreach (var translator in translators) { diff --git a/src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.cs b/src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.cs index 37e5eaffc6e..d6f3a7f0861 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.cs @@ -36,7 +36,7 @@ public partial class HqlSqlWalker private string _statementTypeName; private int _positionalParameterCount; private int _parameterCount; - private readonly NullableDictionary _namedParameters = new NullableDictionary(); + private readonly NullableDictionary _namedParameterLocations = new NullableDictionary(); private readonly List _parameters = new List(); private FromClause _currentFromClause; private SelectClause _selectClause; @@ -54,6 +54,7 @@ public partial class HqlSqlWalker private readonly LiteralProcessor _literalProcessor; private readonly IDictionary _tokenReplacements; + private readonly IDictionary _namedParameters; private JoinType _impliedJoinType; @@ -64,17 +65,30 @@ public partial class HqlSqlWalker private int numberOfParametersInSetClause; private Stack clauseStack=new Stack(); - public HqlSqlWalker(QueryTranslatorImpl qti, - ISessionFactoryImplementor sfi, - ITreeNodeStream input, - IDictionary tokenReplacements, - string collectionRole) + public HqlSqlWalker( + QueryTranslatorImpl qti, + ISessionFactoryImplementor sfi, + ITreeNodeStream input, + IDictionary tokenReplacements, + string collectionRole) + : this(qti, sfi, input, tokenReplacements, null, collectionRole) + { + } + + internal HqlSqlWalker( + QueryTranslatorImpl qti, + ISessionFactoryImplementor sfi, + ITreeNodeStream input, + IDictionary tokenReplacements, + IDictionary namedParameters, + string collectionRole) : this(input) { _sessionFactoryHelper = new SessionFactoryHelperExtensions(sfi); _qti = qti; _literalProcessor = new LiteralProcessor(this); _tokenReplacements = tokenReplacements; + _namedParameters = namedParameters; _collectionFilterRole = collectionRole; } @@ -122,7 +136,7 @@ public ISet QuerySpaces public IDictionary NamedParameters { - get { return _namedParameters; } + get { return _namedParameterLocations; } } internal SessionFactoryHelperExtensions SessionFactoryHelper @@ -1033,13 +1047,20 @@ IASTNode GenerateNamedParameter(IASTNode delimiterNode, IASTNode nameNode) ); parameter.HqlParameterSpecification = paramSpec; + if (_namedParameters != null && _namedParameters.TryGetValue(name, out var namedParameter)) + { + // Add the parameter type information so that we are able to calculate functions return types + // when the parameter is used as an argument. + parameter.ExpectedType = namedParameter.Type; + } + _parameters.Add(paramSpec); return parameter; } IASTNode GeneratePositionalParameter(IASTNode inputNode) { - if (_namedParameters.Count > 0) + if (_namedParameterLocations.Count > 0) { // NH TODO: remove this limitation throw new SemanticException("cannot define positional parameter after any named parameters have been defined"); @@ -1171,15 +1192,15 @@ public void AddQuerySpaces(string[] spaces) private void TrackNamedParameterPositions(string name) { int loc = _parameterCount++; - object o = _namedParameters[name]; + object o = _namedParameterLocations[name]; if ( o == null ) { - _namedParameters.Add(name, loc); + _namedParameterLocations.Add(name, loc); } else if (o is int) { List list = new List(4) {(int) o, loc}; - _namedParameters[name] = list; + _namedParameterLocations[name] = list; } else { diff --git a/src/NHibernate/Hql/Ast/ANTLR/QueryTranslatorImpl.cs b/src/NHibernate/Hql/Ast/ANTLR/QueryTranslatorImpl.cs index 2e27559d1dd..bcf3dc14e11 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/QueryTranslatorImpl.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/QueryTranslatorImpl.cs @@ -29,7 +29,8 @@ public partial class QueryTranslatorImpl : IFilterTranslator private readonly string _queryIdentifier; private readonly IASTNode _stageOneAst; private readonly ISessionFactoryImplementor _factory; - + private readonly IDictionary _namedParameters; + private bool _shallowQuery; private bool _compiled; private IDictionary _enabledFilters; @@ -47,10 +48,28 @@ public partial class QueryTranslatorImpl : IFilterTranslator /// Currently enabled filters /// The session factory constructing this translator instance. public QueryTranslatorImpl( - string queryIdentifier, - IASTNode parsedQuery, - IDictionary enabledFilters, - ISessionFactoryImplementor factory) + string queryIdentifier, + IASTNode parsedQuery, + IDictionary enabledFilters, + ISessionFactoryImplementor factory) + : this(queryIdentifier, parsedQuery, enabledFilters, factory, null) + { + } + + /// + /// Creates a new AST-based query translator. + /// + /// The query-identifier (used in stats collection) + /// The hql query to translate + /// Currently enabled filters + /// The session factory constructing this translator instance. + /// The named parameters information. + internal QueryTranslatorImpl( + string queryIdentifier, + IASTNode parsedQuery, + IDictionary enabledFilters, + ISessionFactoryImplementor factory, + IDictionary namedParameters) { _queryIdentifier = queryIdentifier; _stageOneAst = parsedQuery; @@ -58,6 +77,7 @@ public QueryTranslatorImpl( _shallowQuery = false; _enabledFilters = enabledFilters; _factory = factory; + _namedParameters = namedParameters; } /// @@ -434,7 +454,7 @@ private static IStatementExecutor BuildAppropriateStatementExecutor(IStatement s private HqlSqlTranslator Analyze(string collectionRole) { - var translator = new HqlSqlTranslator(_stageOneAst, this, _factory, _tokenReplacements, collectionRole); + var translator = new HqlSqlTranslator(_stageOneAst, this, _factory, _tokenReplacements, _namedParameters, collectionRole); translator.Translate(); @@ -548,15 +568,23 @@ internal class HqlSqlTranslator private readonly QueryTranslatorImpl _qti; private readonly ISessionFactoryImplementor _sfi; private readonly IDictionary _tokenReplacements; + private readonly IDictionary _namedParameters; private readonly string _collectionRole; private IStatement _resultAst; - public HqlSqlTranslator(IASTNode ast, QueryTranslatorImpl qti, ISessionFactoryImplementor sfi, IDictionary tokenReplacements, string collectionRole) + public HqlSqlTranslator( + IASTNode ast, + QueryTranslatorImpl qti, + ISessionFactoryImplementor sfi, + IDictionary tokenReplacements, + IDictionary namedParameters, + string collectionRole) { _inputAst = ast; _qti = qti; _sfi = sfi; _tokenReplacements = tokenReplacements; + _namedParameters = namedParameters; _collectionRole = collectionRole; } @@ -576,7 +604,7 @@ public IStatement Translate() var nodes = new BufferedTreeNodeStream(_inputAst); - var hqlSqlWalker = new HqlSqlWalker(_qti, _sfi, nodes, _tokenReplacements, _collectionRole); + var hqlSqlWalker = new HqlSqlWalker(_qti, _sfi, nodes, _tokenReplacements, _namedParameters, _collectionRole); hqlSqlWalker.TreeAdaptor = new HqlSqlWalkerTreeAdaptor(hqlSqlWalker); try diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/BetweenOperatorNode.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/BetweenOperatorNode.cs index fd91e09fd3a..8b909121987 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/BetweenOperatorNode.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/BetweenOperatorNode.cs @@ -63,25 +63,27 @@ private IASTNode GetHighOperand() private static void Check(IASTNode check, IASTNode first, IASTNode second) { - var expectedTypeAwareNode = check as IExpectedTypeAwareNode; - if (expectedTypeAwareNode != null) + if (!(check is IExpectedTypeAwareNode expectedTypeAwareNode) || + expectedTypeAwareNode.ExpectedType != null) { - IType expectedType = null; - var firstNode = first as SqlNode; - if (firstNode != null) - { - expectedType = firstNode.DataType; - } - if (expectedType == null) + return; + } + + IType expectedType = null; + if (first is SqlNode firstNode) + { + expectedType = firstNode.DataType; + } + + if (expectedType == null) + { + if (second is SqlNode secondNode) { - var secondNode = second as SqlNode; - if (secondNode != null) - { - expectedType = secondNode.DataType; - } + expectedType = secondNode.DataType; } - expectedTypeAwareNode.ExpectedType = expectedType; } + + expectedTypeAwareNode.ExpectedType = expectedType; } } } diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryArithmeticOperatorNode.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryArithmeticOperatorNode.cs index 9b706facb49..60ca24ca379 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryArithmeticOperatorNode.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryArithmeticOperatorNode.cs @@ -32,32 +32,34 @@ public void Initialize() IType lhType = (lhs is SqlNode) ? ((SqlNode)lhs).DataType : null; IType rhType = (rhs is SqlNode) ? ((SqlNode)rhs).DataType : null; - if (lhs is IExpectedTypeAwareNode && rhType != null) + TrySetExpectedType(lhs, rhType, true); + TrySetExpectedType(rhs, lhType, false); + } + + private void TrySetExpectedType(IASTNode operand, IType otherOperandType, bool leftHandOperand) + { + if (!(operand is IExpectedTypeAwareNode typeAwareNode) || + otherOperandType == null || + typeAwareNode.ExpectedType != null) { - IType expectedType; + return; + } + + IType expectedType = null; - // we have something like : "? [op] rhs" - if (IsDateTimeType(rhType)) + // we have something like : "lhs [op] ?" or "? [op] rhs" + if (IsDateTimeType(otherOperandType)) + { + if (leftHandOperand) { // more specifically : "? [op] datetime" // 1) if the operator is MINUS, the param needs to be of // some datetime type // 2) if the operator is PLUS, the param needs to be of // some numeric type - expectedType = Type == HqlSqlWalker.PLUS ? NHibernateUtil.Double : rhType; + expectedType = Type == HqlSqlWalker.PLUS ? NHibernateUtil.Double : otherOperandType; } - else - { - expectedType = rhType; - } - ((IExpectedTypeAwareNode)lhs).ExpectedType = expectedType; - } - else if (rhs is ParameterNode && lhType != null) - { - IType expectedType = null; - - // we have something like : "lhs [op] ?" - if (IsDateTimeType(lhType)) + else if (Type == HqlSqlWalker.PLUS) { // more specifically : "datetime [op] ?" // 1) if the operator is MINUS, we really cannot determine @@ -65,17 +67,15 @@ public void Initialize() // numeric would be valid // 2) if the operator is PLUS, the param needs to be of // some numeric type - if (Type == HqlSqlWalker.PLUS) - { - expectedType = NHibernateUtil.Double; - } - } - else - { - expectedType = lhType; + expectedType = NHibernateUtil.Double; } - ((IExpectedTypeAwareNode)rhs).ExpectedType = expectedType; } + else + { + expectedType = otherOperandType; + } + + typeAwareNode.ExpectedType = expectedType; } public override IType DataType diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryLogicOperatorNode.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryLogicOperatorNode.cs index bf0560dfc76..cae4b920ec8 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryLogicOperatorNode.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryLogicOperatorNode.cs @@ -65,15 +65,14 @@ public virtual void Initialize() rhsType = lhsType; } - var lshExpectedTypeAwareNode = lhs as IExpectedTypeAwareNode; - if (lshExpectedTypeAwareNode != null) + if (lhs is IExpectedTypeAwareNode lshTypeAwareNode && lshTypeAwareNode.ExpectedType == null) { - lshExpectedTypeAwareNode.ExpectedType = rhsType; + lshTypeAwareNode.ExpectedType = rhsType; } - var rshExpectedTypeAwareNode = rhs as IExpectedTypeAwareNode; - if (rshExpectedTypeAwareNode != null) + + if (rhs is IExpectedTypeAwareNode rshTypeAwareNode && rshTypeAwareNode.ExpectedType == null) { - rshExpectedTypeAwareNode.ExpectedType = lhsType; + rshTypeAwareNode.ExpectedType = lhsType; } MutateRowValueConstructorSyntaxesIfNecessary( lhsType, rhsType ); diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/InLogicOperatorNode.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/InLogicOperatorNode.cs index 0ad4e404bda..f0b9856f76a 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/InLogicOperatorNode.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/InLogicOperatorNode.cs @@ -47,11 +47,12 @@ public override void Initialize() IASTNode inListChild = inList.GetChild(0); while (inListChild != null) { - var expectedTypeAwareNode = inListChild as IExpectedTypeAwareNode; - if (expectedTypeAwareNode != null) + if (inListChild is IExpectedTypeAwareNode expectedTypeAwareNode && + expectedTypeAwareNode.ExpectedType == null) { expectedTypeAwareNode.ExpectedType = lhsType; } + inListChild = inListChild.NextSibling; } } diff --git a/src/NHibernate/Hql/StringQueryExpression.cs b/src/NHibernate/Hql/StringQueryExpression.cs index 0b6a1b50098..068d8653042 100644 --- a/src/NHibernate/Hql/StringQueryExpression.cs +++ b/src/NHibernate/Hql/StringQueryExpression.cs @@ -1,4 +1,5 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; using NHibernate.Engine; using NHibernate.Engine.Query; using NHibernate.Hql.Ast.ANTLR; @@ -26,6 +27,9 @@ public string Key } public System.Type Type { get { return typeof (object); } } + + // Since v5.3 + [Obsolete("This property has no usages and will be removed in a future version")] public IList ParameterDescriptors { get; private set; } } @@ -36,4 +40,4 @@ public static StringQueryExpression ToQueryExpression(this string queryString) return new StringQueryExpression(queryString); } } -} \ No newline at end of file +} diff --git a/src/NHibernate/IQueryExpression.cs b/src/NHibernate/IQueryExpression.cs index 65bbe121f68..86181b9555b 100755 --- a/src/NHibernate/IQueryExpression.cs +++ b/src/NHibernate/IQueryExpression.cs @@ -1,3 +1,4 @@ +using System; using System.Collections.Generic; using NHibernate.Engine; using NHibernate.Engine.Query; @@ -16,6 +17,8 @@ public interface IQueryExpression IASTNode Translate(ISessionFactoryImplementor sessionFactory, bool filter); string Key { get; } System.Type Type { get; } + // Since v5.3 + [Obsolete("This property has no usages and will be removed in a future version")] IList ParameterDescriptors { get; } } } diff --git a/src/NHibernate/Impl/AbstractQueryImpl.cs b/src/NHibernate/Impl/AbstractQueryImpl.cs index 9ff4c712b0d..ba46b665466 100644 --- a/src/NHibernate/Impl/AbstractQueryImpl.cs +++ b/src/NHibernate/Impl/AbstractQueryImpl.cs @@ -142,7 +142,8 @@ protected internal virtual IType DetermineType(int paramPosition, object paramVa protected internal virtual IType DetermineType(int paramPosition, object paramValue) { - IType type = parameterMetadata.GetOrdinalParameterExpectedType(paramPosition + 1) ?? GuessType(paramValue); + IType type = parameterMetadata.GetOrdinalParameterExpectedType(paramPosition + 1) ?? + ParameterHelper.GuessType(paramValue, session.Factory); return type; } @@ -154,67 +155,15 @@ protected internal virtual IType DetermineType(string paramName, object paramVal protected internal virtual IType DetermineType(string paramName, object paramValue) { - IType type = parameterMetadata.GetNamedParameterExpectedType(paramName) ?? GuessType(paramValue); + IType type = parameterMetadata.GetNamedParameterExpectedType(paramName) ?? + ParameterHelper.GuessType(paramValue, session.Factory); return type; } protected internal virtual IType DetermineType(string paramName, System.Type clazz) { - IType type = parameterMetadata.GetNamedParameterExpectedType(paramName) ?? GuessType(clazz); - return type; - } - - /// - /// Guesses the from the param's value. - /// - /// The object to guess the of. - /// An for the object. - /// - /// Thrown when the param is null because the - /// can't be guess from a null value. - /// - private IType GuessType(object param) - { - if (param == null) - { - throw new ArgumentNullException("param", "The IType can not be guessed for a null value."); - } - - System.Type clazz = NHibernateProxyHelper.GetClassWithoutInitializingProxy(param); - return GuessType(clazz); - } - - /// - /// Guesses the from the . - /// - /// The to guess the of. - /// An for the . - /// - /// Thrown when the clazz is null because the - /// can't be guess from a null type. - /// - private IType GuessType(System.Type clazz) - { - if (clazz == null) - { - throw new ArgumentNullException("clazz", "The IType can not be guessed for a null value."); - } - - var type = TypeFactory.HeuristicType(clazz); - if (type == null || type is SerializableType) - { - if (session.Factory.TryGetEntityPersister(clazz.FullName) != null) - { - return NHibernateUtil.Entity(clazz); - } - - if (type == null) - { - throw new HibernateException( - "Could not determine a type for class: " + clazz.AssemblyQualifiedName); - } - } - + IType type = parameterMetadata.GetNamedParameterExpectedType(paramName) ?? + ParameterHelper.GuessType(clazz, session.Factory); return type; } @@ -310,7 +259,11 @@ public IQuery SetParameter(int position, T val) { CheckPositionalParameter(position); - return SetParameter(position, val, parameterMetadata.GetOrdinalParameterExpectedType(position + 1) ?? GuessType(typeof(T))); + return SetParameter( + position, + val, + parameterMetadata.GetOrdinalParameterExpectedType(position + 1) ?? + ParameterHelper.GuessType(typeof(T), session.Factory)); } private void CheckPositionalParameter(int position) @@ -327,7 +280,11 @@ private void CheckPositionalParameter(int position) public IQuery SetParameter(string name, T val) { - return SetParameter(name, val, parameterMetadata.GetNamedParameterExpectedType(name) ?? GuessType(typeof (T))); + return SetParameter( + name, + val, + parameterMetadata.GetNamedParameterExpectedType(name) ?? + ParameterHelper.GuessType(typeof(T), session.Factory)); } public IQuery SetParameter(string name, object val) @@ -792,7 +749,12 @@ public IQuery SetParameterList(string name, IEnumerable vals) } object firstValue = vals.Cast().FirstOrDefault(); - SetParameterList(name, vals, firstValue == null ? GuessType(vals.GetCollectionElementType()) : DetermineType(name, firstValue)); + SetParameterList( + name, + vals, + firstValue == null + ? ParameterHelper.GuessType(vals.GetCollectionElementType(), session.Factory) + : DetermineType(name, firstValue)); return this; } diff --git a/src/NHibernate/Impl/ExpressionQueryImpl.cs b/src/NHibernate/Impl/ExpressionQueryImpl.cs index 8868e3b7cf2..5ce3bedb902 100644 --- a/src/NHibernate/Impl/ExpressionQueryImpl.cs +++ b/src/NHibernate/Impl/ExpressionQueryImpl.cs @@ -202,8 +202,10 @@ public ExpandedQueryExpression(IQueryExpression queryExpression, IASTNode tree, _tree = tree; Key = key; Type = queryExpression.Type; +#pragma warning disable CS0618 ParameterDescriptors = queryExpression.ParameterDescriptors; - _cacheableExpression = queryExpression as ICacheableQueryExpression; +#pragma warning restore CS0618 + _cacheableExpression = queryExpression as ICacheableQueryExpression; } #region IQueryExpression Members @@ -217,6 +219,8 @@ public IASTNode Translate(ISessionFactoryImplementor sessionFactory, bool filter public System.Type Type { get; private set; } + // Since v5.3 + [Obsolete("This property has no usages and will be removed in a future version")] public IList ParameterDescriptors { get; private set; } #endregion diff --git a/src/NHibernate/Linq/DefaultQueryProvider.cs b/src/NHibernate/Linq/DefaultQueryProvider.cs index c8de5a37a5e..4d0399d4178 100644 --- a/src/NHibernate/Linq/DefaultQueryProvider.cs +++ b/src/NHibernate/Linq/DefaultQueryProvider.cs @@ -11,23 +11,10 @@ using NHibernate.Util; using System.Threading.Tasks; using NHibernate.Multi; +using NHibernate.Param; namespace NHibernate.Linq { - public partial interface INhQueryProvider : IQueryProvider - { - //Since 5.2 - [Obsolete("Replaced by ISupportFutureBatchNhQueryProvider interface")] - IFutureEnumerable ExecuteFuture(Expression expression); - - //Since 5.2 - [Obsolete("Replaced by ISupportFutureBatchNhQueryProvider interface")] - IFutureValue ExecuteFutureValue(Expression expression); - void SetResultTransformerAndAdditionalCriteria(IQuery query, NhLinqExpression nhExpression, IDictionary> parameters); - int ExecuteDml(QueryMode queryMode, Expression expression); - Task ExecuteAsync(Expression expression, CancellationToken cancellationToken); - } - // 6.0 TODO: merge into INhQueryProvider. public interface ISupportFutureBatchNhQueryProvider { @@ -104,7 +91,7 @@ public TResult Execute(Expression expression) public virtual IList ExecuteList(Expression expression) { var linqExpression = PrepareQuery(expression, out var query); - var resultTransformer = linqExpression.ExpressionToHqlTranslationResults?.PostExecuteTransformer; + var resultTransformer = linqExpression.ExpressionToHqlTranslationResults?.PostResultTransformer; if (resultTransformer == null) { return query.List(); @@ -112,7 +99,7 @@ public virtual IList ExecuteList(Expression expression) return new List { - (TResult) resultTransformer.DynamicInvoke(query.List().AsQueryable()) + (TResult) resultTransformer.Transform(query.List().AsQueryable()) }; } @@ -170,10 +157,10 @@ public virtual IFutureValue ExecuteFutureValue(Expression expr [Obsolete] private static void SetupFutureResult(NhLinqExpression nhExpression, IDelayedValue result) { - if (nhExpression.ExpressionToHqlTranslationResults.PostExecuteTransformer == null) + if (nhExpression.ExpressionToHqlTranslationResults.PostResultTransformer == null) return; - result.ExecuteOnEval = nhExpression.ExpressionToHqlTranslationResults.PostExecuteTransformer; + result.ExecuteOnEval = nhExpression.ExpressionToHqlTranslationResults.PostResultTransformer.GetDelegate(); } public async Task ExecuteAsync(Expression expression, CancellationToken cancellationToken) @@ -211,9 +198,10 @@ protected virtual NhLinqExpression PrepareQuery(Expression expression, out IQuer query = Session.CreateFilter(Collection, nhLinqExpression); } - SetParameters(query, nhLinqExpression.ParameterValuesByName); + nhLinqExpression.Prepare(); + SetParameters(query, nhLinqExpression.NamedParameters); _options?.Apply(query); - SetResultTransformerAndAdditionalCriteria(query, nhLinqExpression, nhLinqExpression.ParameterValuesByName); + SetResultTransformerAndExecuteRegisteredDelegates(query, nhLinqExpression, nhLinqExpression.NamedParameters); return nhLinqExpression; } @@ -224,16 +212,9 @@ protected virtual object ExecuteQuery(NhLinqExpression nhLinqExpression, IQuery { IList results = query.List(); - if (nhQuery.ExpressionToHqlTranslationResults?.PostExecuteTransformer != null) + if (nhQuery.ExpressionToHqlTranslationResults?.PostResultTransformer != null) { - try - { - return nhQuery.ExpressionToHqlTranslationResults.PostExecuteTransformer.DynamicInvoke(results.AsQueryable()); - } - catch (TargetInvocationException e) - { - throw ReflectHelper.UnwrapTargetInvocationException(e); - } + return nhQuery.ExpressionToHqlTranslationResults.PostResultTransformer.Transform(results.AsQueryable()); } if (nhLinqExpression.ReturnType == NhLinqExpressionReturnType.Sequence) @@ -252,42 +233,25 @@ protected virtual object ExecuteQuery(NhLinqExpression nhLinqExpression, IQuery #pragma warning restore 618 } - private static void SetParameters(IQuery query, IDictionary> parameters) + private static void SetParameters(IQuery query, IDictionary parameters) { foreach (var parameterName in query.NamedParameters) { - var param = parameters[parameterName]; - - if (param.Item1 == null) + // The parameter type will be taken from the parameter metadata + var parameter = parameters[parameterName]; + if (parameter.IsCollection) { - if (typeof(IEnumerable).IsAssignableFrom(param.Item2.ReturnedClass) && - param.Item2.ReturnedClass != typeof(string)) - { - query.SetParameterList(parameterName, null, param.Item2); - } - else - { - query.SetParameter(parameterName, null, param.Item2); - } + query.SetParameterList(parameter.Name, (IEnumerable) parameter.Value); } else { - if (param.Item1 is IEnumerable && !(param.Item1 is string)) - { - query.SetParameterList(parameterName, (IEnumerable)param.Item1); - } - else if (param.Item2 != null) - { - query.SetParameter(parameterName, param.Item1, param.Item2); - } - else - { - query.SetParameter(parameterName, param.Item1); - } + query.SetParameter(parameter.Name, parameter.Value); } } } + // Since v5.3 + [Obsolete("Use SetResultTransformerAndExecuteRegisteredDelegates method instead")] public virtual void SetResultTransformerAndAdditionalCriteria(IQuery query, NhLinqExpression nhExpression, IDictionary> parameters) { if (nhExpression.ExpressionToHqlTranslationResults != null) @@ -301,6 +265,30 @@ public virtual void SetResultTransformerAndAdditionalCriteria(IQuery query, NhLi } } + public virtual void SetResultTransformerAndExecuteRegisteredDelegates(IQuery query, NhLinqExpression nhExpression, IDictionary parameters) + { + if (nhExpression.ExpressionToHqlTranslationResults == null) + { + return; + } + + // For avoiding breaking derived classes, call the obsolete method until it is dropped. +#pragma warning disable CS0618 + var param = parameters.ToDictionary( + o => o.Key, + o => new Tuple(o.Value.Value, o.Value.Type)); + SetResultTransformerAndAdditionalCriteria(query, nhExpression, param); +#pragma warning restore CS0618 + + if (nhExpression.ExpressionToHqlTranslationResults.PreQueryExecuteDelegates?.Count > 0) + { + foreach (var action in nhExpression.ExpressionToHqlTranslationResults.PreQueryExecuteDelegates) + { + action(query, parameters); + } + } + } + public int ExecuteDml(QueryMode queryMode, Expression expression) { if (Collection != null) @@ -310,7 +298,8 @@ public int ExecuteDml(QueryMode queryMode, Expression expression) var query = Session.CreateQuery(nhLinqExpression); - SetParameters(query, nhLinqExpression.ParameterValuesByName); + nhLinqExpression.Prepare(); + SetParameters(query, nhLinqExpression.NamedParameters); _options?.Apply(query); return query.ExecuteUpdate(); } diff --git a/src/NHibernate/Linq/ExpressionToHqlTranslationResults.cs b/src/NHibernate/Linq/ExpressionToHqlTranslationResults.cs index 33d76353176..e19f33b2d28 100644 --- a/src/NHibernate/Linq/ExpressionToHqlTranslationResults.cs +++ b/src/NHibernate/Linq/ExpressionToHqlTranslationResults.cs @@ -1,7 +1,9 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Linq.Expressions; using NHibernate.Hql.Ast; +using NHibernate.Param; using NHibernate.Type; namespace NHibernate.Linq @@ -10,36 +12,104 @@ public class ExpressionToHqlTranslationResults { public HqlTreeNode Statement { get; } public ResultTransformer ResultTransformer { get; } - public Delegate PostExecuteTransformer { get; } + // Since v5.3 + [Obsolete("Use PostResultTransformer method instead")] + public Delegate PostExecuteTransformer => PostResultTransformer?.GetDelegate(); + + public PostResultTransformer PostResultTransformer { get; } + + // Since v5.3 + [Obsolete("Use instead PreQueryExecuteDelegates property instead.")] public List>>> AdditionalCriteria { get; } + public List>> PreQueryExecuteDelegates { get; } + /// /// If execute result type does not match expected final result type (implying a post execute transformer /// will yield expected result type), the intermediate execute type. /// public System.Type ExecuteResultTypeOverride { get; } - public ExpressionToHqlTranslationResults(HqlTreeNode statement, - IList itemTransformers, + public ExpressionToHqlTranslationResults( + HqlTreeNode statement, + IList itemTransformers, IList listTransformers, IList postExecuteTransformers, - List>>> additionalCriteria, + List>> preQueryExecuteDelegates, System.Type executeResultTypeOverride) { Statement = statement; - PostExecuteTransformer = MergeLambdasAndCompile(postExecuteTransformers); - - var itemTransformer = MergeLambdasAndCompile>(itemTransformers); - var listTransformer = MergeLambdasAndCompile, object>>(listTransformers); + var postResultTransformer = MergeLambdasAndCompile(postExecuteTransformers); + if (postResultTransformer != null) + { + PostResultTransformer = new PostResultTransformer(postResultTransformer); + } + var listTransformer = MergeLambdasAndCompile, object[], object>>(listTransformers); + var itemTransformer = MergeLambdasAndCompile>(itemTransformers); if (itemTransformer != null || listTransformer != null) { - ResultTransformer = new ResultTransformer(itemTransformer, listTransformer); + ResultTransformer = new ResultTransformer(itemTransformer, listTransformer); } + PreQueryExecuteDelegates = preQueryExecuteDelegates; + ExecuteResultTypeOverride = executeResultTypeOverride; +#pragma warning disable 618 + AdditionalCriteria = new List>>>(); +#pragma warning restore 618 + } + + // Since v5.3 + [Obsolete("Use overload with preQueryExecuteDelegates parameter.")] + public ExpressionToHqlTranslationResults( + HqlTreeNode statement, + IList itemTransformers, + IList listTransformers, + IList postExecuteTransformers, + List>>> additionalCriteria, + System.Type executeResultTypeOverride) + : this( + statement, + itemTransformers, + listTransformers, + postExecuteTransformers, + (List>>) null, + executeResultTypeOverride) + { AdditionalCriteria = additionalCriteria; + } + + private ExpressionToHqlTranslationResults( + HqlTreeNode statement, + ResultTransformer resultTransformer, + PostResultTransformer postResultTransformer, + System.Type executeResultTypeOverride, + List>>> additionalCriteria, // TODO 6.0: Remove + List>> preQueryExecuteDelegates) + { + Statement = statement; + ResultTransformer = resultTransformer; + PostResultTransformer = postResultTransformer; ExecuteResultTypeOverride = executeResultTypeOverride; + PreQueryExecuteDelegates = preQueryExecuteDelegates; +#pragma warning disable 618 + AdditionalCriteria = additionalCriteria; +#pragma warning restore 618 + } + + internal ExpressionToHqlTranslationResults WithParameterValues(object[] parameterValues) + { + return new ExpressionToHqlTranslationResults( + Statement, + ResultTransformer?.WithParameterValues(parameterValues), + PostResultTransformer?.WithParameterValues(parameterValues), + ExecuteResultTypeOverride, +#pragma warning disable 618 + AdditionalCriteria, +#pragma warning restore 618 + PreQueryExecuteDelegates + ); } private static TDelegate MergeLambdasAndCompile(IList itemTransformers) @@ -55,13 +125,25 @@ private static TDelegate MergeLambdasAndCompile(IList(body, lambda.Parameters).Compile(); } - private static Delegate MergeLambdasAndCompile(IList transformations) + private static Func MergeLambdasAndCompile(IList transformations) { var lambda = MergeLambdas(transformations); if (lambda == null) return null; - - return lambda.Compile(); + + // Convert from Func object[], TResult> to Func + // in order to avoid using DynamicInvoke which wraps exceptions in a TargetInvocationException. + var inputListParameter = Expression.Parameter(typeof(object), "result"); + var parameterValuesParameter = Expression.Parameter(typeof(object[]), "parameterValues"); + var invoked = Expression.Convert( + Expression.Invoke( + lambda, + Expression.Convert(inputListParameter, lambda.Parameters[0].Type), + parameterValuesParameter), + typeof(object)); + + return Expression.Lambda>( + invoked, inputListParameter, parameterValuesParameter).Compile(); } private static LambdaExpression MergeLambdas(IList transformations) @@ -70,15 +152,26 @@ private static LambdaExpression MergeLambdas(IList transformat return null; var lambda = transformations[0]; - - for (int i = 1; i < transformations.Count; i++) + ParameterExpression parameter; + if (lambda.Parameters.Count < 2) { - var invoked = Expression.Invoke(transformations[i], lambda.Body); + parameter = Expression.Parameter(typeof(object[]), "parameterValues"); + lambda = Expression.Lambda(lambda.Body, lambda.Parameters.Concat(new []{parameter})); + } + else + { + parameter = lambda.Parameters[1]; + } + for (var i = 1; i < transformations.Count; i++) + { + var invoked = transformations[i].Parameters.Count == 2 + ? Expression.Invoke(transformations[i], lambda.Body, parameter) + : Expression.Invoke(transformations[i], lambda.Body); lambda = Expression.Lambda(invoked, lambda.Parameters); } return lambda; } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/INhQueryProvider.cs b/src/NHibernate/Linq/INhQueryProvider.cs new file mode 100644 index 00000000000..5b9bd15c222 --- /dev/null +++ b/src/NHibernate/Linq/INhQueryProvider.cs @@ -0,0 +1,52 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Threading; +using System.Threading.Tasks; +using NHibernate.Param; +using NHibernate.Type; + +namespace NHibernate.Linq +{ + public partial interface INhQueryProvider : IQueryProvider + { + //Since 5.2 + [Obsolete("Replaced by ISupportFutureBatchNhQueryProvider interface")] + IFutureEnumerable ExecuteFuture(Expression expression); + + //Since 5.2 + [Obsolete("Replaced by ISupportFutureBatchNhQueryProvider interface")] + IFutureValue ExecuteFutureValue(Expression expression); + //Since v5.3 + [Obsolete("Use SetResultTransformerAndExecuteRegisteredDelegates extension method instead.")] + void SetResultTransformerAndAdditionalCriteria(IQuery query, NhLinqExpression nhExpression, IDictionary> parameters); + + int ExecuteDml(QueryMode queryMode, Expression expression); + Task ExecuteAsync(Expression expression, CancellationToken cancellationToken); + } + + // TODO 6.0 Move to INhQueryProvider + public static class NhQueryProviderExtensions + { + public static void SetResultTransformerAndExecuteRegisteredDelegates( + this INhQueryProvider nhQueryProvider, + IQuery query, + NhLinqExpression nhExpression, + IDictionary parameters) + { + if (nhQueryProvider is DefaultQueryProvider defaultQueryProvider) + { + defaultQueryProvider.SetResultTransformerAndExecuteRegisteredDelegates(query, nhExpression, parameters); + return; + } + + var param = parameters.ToDictionary( + o => o.Key, + o => new Tuple(o.Value.Value, o.Value.Type)); +#pragma warning disable CS0618 + nhQueryProvider.SetResultTransformerAndAdditionalCriteria(query, nhExpression, param); +#pragma warning restore CS0618 + } + } +} diff --git a/src/NHibernate/Linq/IntermediateHqlTree.cs b/src/NHibernate/Linq/IntermediateHqlTree.cs index e9fa4bba70e..84d445ef90f 100644 --- a/src/NHibernate/Linq/IntermediateHqlTree.cs +++ b/src/NHibernate/Linq/IntermediateHqlTree.cs @@ -5,6 +5,7 @@ using System.Linq.Expressions; using NHibernate.Hql.Ast; using NHibernate.Hql.Ast.ANTLR; +using NHibernate.Param; using NHibernate.Transform; using NHibernate.Type; @@ -20,7 +21,7 @@ public class IntermediateHqlTree * We ***shouldn't*** change the behavior of the query just because we are translating it in SQL. */ private readonly bool _isRoot; - private readonly List>>> _additionalCriteria = new List>>>(); + private readonly List>> _preQueryExecuteDelegates = new List>>(); private readonly List _listTransformers = new List(); private readonly List _itemTransformers = new List(); private readonly List _postExecuteTransformers = new List(); @@ -101,7 +102,7 @@ public ExpressionToHqlTranslationResults GetTranslation() _itemTransformers, _listTransformers, _postExecuteTransformers, - _additionalCriteria, + _preQueryExecuteDelegates, ExecuteResultTypeOverride); } @@ -109,8 +110,8 @@ public void AddDistinctRootOperator() { if (!_hasDistinctRootOperator) { - Expression, IList>> x = - l => DistinctRootEntityResultTransformer.TransformList(l); + Expression, object[], IList>> x = + (l, p) => DistinctRootEntityResultTransformer.TransformList(l); _listTransformers.Add(x); _hasDistinctRootOperator = true; @@ -271,9 +272,20 @@ public void AddSet(HqlEquality equality) } } + // Since v5.3 + [Obsolete("Use AddPreQueryExecuteDelegate method instead")] public void AddAdditionalCriteria(Action>> criteria) { - _additionalCriteria.Add(criteria); + _preQueryExecuteDelegates.Add(Action); + + void Action(IQuery q, IDictionary p) => criteria( + q, + p.ToDictionary(o => o.Key, o => new Tuple(o.Value.Value, o.Value.Type))); + } + + public void AddPreQueryExecuteDelegate(Action> action) + { + _preQueryExecuteDelegates.Add(action); } public void AddPostExecuteTransformer(LambdaExpression lambda) diff --git a/src/NHibernate/Linq/LockResultOperator.cs b/src/NHibernate/Linq/LockResultOperator.cs index 2fc841258d4..dad233fa84e 100644 --- a/src/NHibernate/Linq/LockResultOperator.cs +++ b/src/NHibernate/Linq/LockResultOperator.cs @@ -13,7 +13,7 @@ internal class LockResultOperator : ResultOperatorBase public IQuerySource QuerySource => _qsrExpression.ReferencedQuerySource; - public ConstantExpression LockMode { get; } + public ConstantExpression LockMode { get; private set; } public LockResultOperator(QuerySourceReferenceExpression qsrExpression, ConstantExpression lockMode) { @@ -39,6 +39,7 @@ public override ResultOperatorBase Clone(CloneContext cloneContext) public override void TransformExpressions(Func transformation) { _qsrExpression = (QuerySourceReferenceExpression) transformation(_qsrExpression); + LockMode = (ConstantExpression) transformation(LockMode); } } } diff --git a/src/NHibernate/Linq/NestedSelects/NestedSelectRewriter.cs b/src/NHibernate/Linq/NestedSelects/NestedSelectRewriter.cs index cf55ac05a7e..422b498e351 100644 --- a/src/NHibernate/Linq/NestedSelects/NestedSelectRewriter.cs +++ b/src/NHibernate/Linq/NestedSelects/NestedSelectRewriter.cs @@ -27,7 +27,7 @@ static class NestedSelectRewriter private static readonly PropertyInfo IGroupingKeyProperty = (PropertyInfo) ReflectHelper.GetProperty, Tuple>(g => g.Key); - public static void ReWrite(QueryModel queryModel, ISessionFactory sessionFactory) + public static void ReWrite(QueryModel queryModel, VisitorParameters parameters, ISessionFactory sessionFactory) { var nsqmv = new NestedSelectDetector(sessionFactory); nsqmv.Visit(queryModel.SelectClause.Selector); @@ -60,9 +60,11 @@ public static void ReWrite(QueryModel queryModel, ISessionFactory sessionFactory elementExpression.AddRange(expressions); - var keySelector = CreateSelector(elementExpression, 0); + var parameter = Expression.Parameter(typeof(object[]), "parameterValues"); - var elementSelector = CreateSelector(elementExpression, 1); + var keySelector = CreateSelector(elementExpression, 0, parameters, parameter); + + var elementSelector = CreateSelector(elementExpression, 1, parameters, parameter); var input = Expression.Parameter(typeof (IEnumerable), "input"); @@ -71,7 +73,8 @@ public static void ReWrite(QueryModel queryModel, ISessionFactory sessionFactory Expression.Call(CastMethod, input), keySelector, elementSelector), - input); + input, + parameter); queryModel.ResultOperators.Add(new ClientSideSelect2(lambda)); queryModel.ResultOperators.Add(new ClientSideSelect(Expression.Lambda(resultSelector, @group))); @@ -238,7 +241,11 @@ private static Expression GetIdentifier(ISessionFactory sessionFactory, Expressi return ConvertToObject(Expression.PropertyOrField(expression, propertyName)); } - private static LambdaExpression CreateSelector(IEnumerable expressions, int tuple) + private static LambdaExpression CreateSelector( + IEnumerable expressions, + int tuple, + VisitorParameters parameters, + ParameterExpression parameterValuesParameter) { var parameter = Expression.Parameter(typeof (object[]), "x"); @@ -246,7 +253,10 @@ private static LambdaExpression CreateSelector(IEnumerable exp .Where(x => x.Tuple == tuple) .Select(x => ArrayIndex(parameter, x.index)); - var newArrayInit = Expression.NewArrayInit(typeof (object), initializers); + var newArrayInit = ConstantParametersRewriter.Rewrite( + Expression.NewArrayInit(typeof(object), initializers), + parameters, + parameterValuesParameter); return Expression.Lambda( Expression.New(Tuple.Constructor, newArrayInit), diff --git a/src/NHibernate/Linq/NhLinqExpression.cs b/src/NHibernate/Linq/NhLinqExpression.cs index 817bfe459e2..aedd02ee848 100644 --- a/src/NHibernate/Linq/NhLinqExpression.cs +++ b/src/NHibernate/Linq/NhLinqExpression.cs @@ -24,16 +24,24 @@ public class NhLinqExpression : IQueryExpression, ICacheableQueryExpression /// protected virtual System.Type TargetType => Type; + // Since v5.3 + [Obsolete("This property has no usages and will be removed in a future version")] public IList ParameterDescriptors { get; private set; } public NhLinqExpressionReturnType ReturnType { get; } + // Since v5.3 + [Obsolete("Use NamedParameters property instead.")] public IDictionary> ParameterValuesByName { get; } public ExpressionToHqlTranslationResults ExpressionToHqlTranslationResults { get; private set; } protected virtual QueryMode QueryMode { get; } + public IDictionary NamedParameters { get; } + + internal object[] ParameterValues { get; } + private readonly Expression _expression; private readonly IDictionary _constantToParameterMap; @@ -56,12 +64,28 @@ internal NhLinqExpression(QueryMode queryMode, Expression expression, ISessionFa // referenced from the main query. LinqLogging.LogExpression("Expression (partially evaluated)", _expression); - _expression = ExpressionParameterVisitor.Visit(preTransformResult, out _constantToParameterMap); - - ParameterValuesByName = _constantToParameterMap.Values.Distinct().ToDictionary(p => p.Name, - p => System.Tuple.Create(p.Value, p.Type)); + _constantToParameterMap = ExpressionParameterVisitor.Visit(preTransformResult); - Key = ExpressionKeyVisitor.Visit(_expression, _constantToParameterMap); + var parameterValuesByName = new Dictionary>(); + var namedParameters = new Dictionary(); + var parameterValues = new object[_constantToParameterMap.Count]; + foreach (var pair in _constantToParameterMap) + { + var parameter = pair.Value; + if (!parameterValuesByName.ContainsKey(parameter.Name)) + { + parameterValuesByName.Add(parameter.Name, System.Tuple.Create(parameter.Value, parameter.Type)); + namedParameters.Add(parameter.Name, parameter); + } + + parameterValues[parameter.Index] = parameter.Value; + } +#pragma warning disable 618 + ParameterValuesByName = parameterValuesByName; +#pragma warning restore 618 + NamedParameters = namedParameters; + ParameterValues = parameterValues; + Key = ExpressionKeyVisitor.Visit(_expression, _constantToParameterMap, sessionFactory); Type = _expression.Type; @@ -87,6 +111,7 @@ public IASTNode Translate(ISessionFactoryImplementor sessionFactory, bool filter var requiredHqlParameters = new List(); var queryModel = NhRelinqQueryParser.Parse(_expression); + ParameterTypeLocator.SetParameterTypes(_constantToParameterMap, queryModel, TargetType, sessionFactory, true); var visitorParameters = new VisitorParameters(sessionFactory, _constantToParameterMap, requiredHqlParameters, new QuerySourceNamer(), TargetType, QueryMode); @@ -94,16 +119,10 @@ public IASTNode Translate(ISessionFactoryImplementor sessionFactory, bool filter if (ExpressionToHqlTranslationResults.ExecuteResultTypeOverride != null) Type = ExpressionToHqlTranslationResults.ExecuteResultTypeOverride; - +#pragma warning disable CS0618 ParameterDescriptors = requiredHqlParameters.AsReadOnly(); - - CanCachePlan = CanCachePlan && - // If some constants do not have matching HQL parameters, their values from first query will - // be embedded in the plan and reused for subsequent queries: do not cache the plan. - !ParameterValuesByName - .Keys - .Except(requiredHqlParameters.Select(p => p.Name)) - .Any(); +#pragma warning restore CS0618 + CanCachePlan &= visitorParameters.CanCachePlan; // The ast node may be altered by caller, duplicate it for preserving the original one. return DuplicateTree(ExpressionToHqlTranslationResults.Statement.AstNode); @@ -112,11 +131,18 @@ public IASTNode Translate(ISessionFactoryImplementor sessionFactory, bool filter internal void CopyExpressionTranslation(NhLinqExpression other) { ExpressionToHqlTranslationResults = other.ExpressionToHqlTranslationResults; +#pragma warning disable CS0618 ParameterDescriptors = other.ParameterDescriptors; +#pragma warning restore CS0618 // Type could have been overridden by translation. Type = other.Type; } + internal void Prepare() + { + ExpressionToHqlTranslationResults = ExpressionToHqlTranslationResults?.WithParameterValues(ParameterValues); + } + private static IASTNode DuplicateTree(IASTNode ast) { var thisNode = ast.DupNode(); diff --git a/src/NHibernate/Linq/PostResultTransformer.cs b/src/NHibernate/Linq/PostResultTransformer.cs new file mode 100644 index 00000000000..6ab473a072a --- /dev/null +++ b/src/NHibernate/Linq/PostResultTransformer.cs @@ -0,0 +1,61 @@ +using System; + +namespace NHibernate.Linq +{ + /// + /// A Linq query transformer that is used to transform the result returned by . + /// + public class PostResultTransformer + { + private readonly Func _transformer; + private readonly object[] _parameterValues; + + internal PostResultTransformer(Func transformer) + : this(transformer, null) + { + } + + private PostResultTransformer(Func transformer, object[] parameterValues) + { + _transformer = transformer ?? throw new ArgumentNullException(nameof(transformer)); + _parameterValues = parameterValues; + } + + internal PostResultTransformer WithParameterValues(object[] parameterValues) + { + return new PostResultTransformer(_transformer, parameterValues); + } + + /// + /// Transform the given query result. + /// + /// The query result to transform. + /// The transformed query result. + public object Transform(object result) + { + return _transformer(result, _parameterValues); + } + + // TODO 6.0: Remove + internal Delegate GetDelegate() + { + return (Func) Func; + + object Func(object l) => Transform(l); + } + + public override int GetHashCode() + { + unchecked + { + return (_transformer.GetHashCode() * 397) ^ (_parameterValues?.GetHashCode() ?? 0); + } + } + + public bool Equals(PostResultTransformer other) + { + return _transformer == other._transformer && + _parameterValues == other._parameterValues; + } + } +} diff --git a/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs b/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs index 99a8e009571..9826ffdbf06 100644 --- a/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs +++ b/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs @@ -29,8 +29,8 @@ private AddJoinsReWriter(ISessionFactoryImplementor sessionFactory, QueryModel q { _sessionFactory = sessionFactory; var joiner = new Joiner(queryModel, AddJoin); - _memberExpressionJoinDetector = new MemberExpressionJoinDetector(this, joiner); - _whereJoinDetector = new WhereJoinDetector(this, joiner); + _memberExpressionJoinDetector = new MemberExpressionJoinDetector(this, joiner, _sessionFactory); + _whereJoinDetector = new WhereJoinDetector(this, joiner, _sessionFactory); } public static void ReWrite(QueryModel queryModel, VisitorParameters parameters) diff --git a/src/NHibernate/Linq/ResultTransformer.cs b/src/NHibernate/Linq/ResultTransformer.cs index 20a95f6e104..3c150a0b2cc 100644 --- a/src/NHibernate/Linq/ResultTransformer.cs +++ b/src/NHibernate/Linq/ResultTransformer.cs @@ -9,31 +9,78 @@ namespace NHibernate.Linq [Serializable] public class ResultTransformer : IResultTransformer, IEquatable { - private readonly Func _itemTransformation; - private readonly Func, object> _listTransformation; + private readonly Func _itemTransformation; // TODO 6.0: Remove + private readonly Func, object> _listTransformation; // TODO 6.0: Remove + private readonly Func _itemTransformationParams; // TODO 6.0: Rename to _itemTransformation + private readonly Func, object[], object> _listTransformationParams; // TODO 6.0: Rename to _listTransformation + private readonly object[] _parameterValues; + // Since v5.3 + [Obsolete("Use overload with Func parameter instead.")] public ResultTransformer(Func itemTransformation, Func, object> listTransformation) { _itemTransformation = itemTransformation; _listTransformation = listTransformation; } + public ResultTransformer( + Func itemTransformation, + Func, object[], object> listTransformation) + { + _itemTransformationParams = itemTransformation; + _listTransformationParams = listTransformation; + } + + private ResultTransformer( + Func itemTransformation, + Func, object[], object> listTransformation, + object[] parameterValues, + Func itemTransformationOld, // TODO 6.0: Remove + Func, object> listTransformationOld // TODO 6.0: Remove + ) + { + _itemTransformationParams = itemTransformation; + _listTransformationParams = listTransformation; + _parameterValues = parameterValues; + _itemTransformation = itemTransformationOld; + _listTransformation = listTransformationOld; + } + + internal ResultTransformer WithParameterValues(object[] parameterValues) + { + return new ResultTransformer( + _itemTransformationParams, + _listTransformationParams, + parameterValues, + _itemTransformation, + _listTransformation); + } + #region IResultTransformer Members public object TransformTuple(object[] tuple, string[] aliases) { - return _itemTransformation == null ? tuple : _itemTransformation(tuple); + if (_itemTransformationParams == null && _itemTransformation == null) + { + return tuple; + } + + return _itemTransformationParams != null + ? _itemTransformationParams(tuple, _parameterValues) + : _itemTransformation(tuple); } public IList TransformList(IList collection) { - if (_listTransformation == null) + if (_listTransformation == null && _listTransformationParams == null) { return collection; } var toTransform = GetToTransform(collection); - var transformResult = _listTransformation(toTransform); + var transformResult = _listTransformationParams != null + ? _listTransformationParams(toTransform, _parameterValues) + : _listTransformation(toTransform); var resultList = transformResult as IList; return resultList ?? new List { transformResult }; @@ -64,7 +111,10 @@ public bool Equals(ResultTransformer other) { return true; } - return Equals(other._listTransformation, _listTransformation) && Equals(other._itemTransformation, _itemTransformation); + return Equals(other._listTransformation, _listTransformation) && + Equals(other._itemTransformation, _itemTransformation) && + Equals(other._listTransformationParams, _listTransformationParams) && + Equals(other._itemTransformationParams, _itemTransformationParams); } public override bool Equals(object obj) @@ -78,8 +128,10 @@ public override int GetHashCode() { int lt = (_listTransformation != null ? _listTransformation.GetHashCode() : 0); int it = (_itemTransformation != null ? _itemTransformation.GetHashCode() : 0); - return (lt*397) ^ (it*17); + int lt2 = (_listTransformationParams != null ? _listTransformationParams.GetHashCode() : 0); + int it2 = (_itemTransformationParams != null ? _itemTransformationParams.GetHashCode() : 0); + return (lt*397) ^ (it*17) ^ (lt2 * 397) ^ (it2 * 17); } } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/Visitors/ConstantParametersRewriter.cs b/src/NHibernate/Linq/Visitors/ConstantParametersRewriter.cs new file mode 100644 index 00000000000..2e57ceb1b9e --- /dev/null +++ b/src/NHibernate/Linq/Visitors/ConstantParametersRewriter.cs @@ -0,0 +1,51 @@ +using System.Linq.Expressions; +using Remotion.Linq.Parsing; + +namespace NHibernate.Linq.Visitors +{ + internal class ConstantParametersRewriter : RelinqExpressionVisitor + { + private readonly VisitorParameters _parameters; + + public ConstantParametersRewriter(VisitorParameters parameters) + { + _parameters = parameters; + Parameter = Expression.Parameter(typeof(object[]), "parameterValues"); + } + + public ConstantParametersRewriter(VisitorParameters parameters, ParameterExpression parameter) + { + _parameters = parameters; + Parameter = parameter; + } + + public ParameterExpression Parameter { get; } + + public static Expression Rewrite(Expression expression, VisitorParameters parameters, out ParameterExpression parameter) + { + var rewriter = new ConstantParametersRewriter(parameters); + expression = rewriter.Visit(expression); + parameter = rewriter.Parameter; + return expression; + } + + public static Expression Rewrite(Expression expression, VisitorParameters parameters, ParameterExpression parameter) + { + var rewriter = new ConstantParametersRewriter(parameters, parameter); + expression = rewriter.Visit(expression); + return expression; + } + + protected override Expression VisitConstant(ConstantExpression expression) + { + if (_parameters.ConstantToParameterMap.TryGetValue(expression, out var namedParameter)) + { + return Expression.Convert( + Expression.ArrayIndex(Parameter, Expression.Constant(namedParameter.Index)), + expression.Type); + } + + return expression; + } + } +} diff --git a/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs b/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs index ef4981d2aec..fdce29edd9a 100644 --- a/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs +++ b/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs @@ -7,7 +7,10 @@ using System.Reflection; using System.Runtime.CompilerServices; using System.Text; +using NHibernate.Engine; using NHibernate.Param; +using NHibernate.Type; +using NHibernate.Util; using Remotion.Linq.Parsing; namespace NHibernate.Linq.Visitors @@ -22,22 +25,46 @@ namespace NHibernate.Linq.Visitors public class ExpressionKeyVisitor : RelinqExpressionVisitor { private readonly IDictionary _constantToParameterMap; + private readonly ISessionFactoryImplementor _sessionFactory; readonly StringBuilder _string = new StringBuilder(); - private ExpressionKeyVisitor(IDictionary constantToParameterMap) + private ExpressionKeyVisitor( + IDictionary constantToParameterMap, + ISessionFactoryImplementor sessionFactory) { _constantToParameterMap = constantToParameterMap; + _sessionFactory = sessionFactory; } + // Since v5.3 + [Obsolete("Use the overload with ISessionFactoryImplementor parameter")] public static string Visit(Expression expression, IDictionary parameters) { - var visitor = new ExpressionKeyVisitor(parameters); + var visitor = new ExpressionKeyVisitor(parameters, null); visitor.Visit(expression); return visitor.ToString(); } + /// + /// Generates the key for the expression. + /// + /// The expression. + /// The session factory. + /// Parameters found in . + /// The key for the expression. + public static string Visit( + Expression rootExpression, + IDictionary parameters, + ISessionFactoryImplementor sessionFactory) + { + var visitor = new ExpressionKeyVisitor(parameters, sessionFactory); + visitor.Visit(rootExpression); + + return visitor.ToString(); + } + public override string ToString() { return _string.ToString(); @@ -86,49 +113,70 @@ protected override Expression VisitConstant(ConstantExpression expression) throw new InvalidOperationException("Cannot visit a constant without a constant to parameter map."); if (_constantToParameterMap.TryGetValue(expression, out param)) { - // Nulls generate different query plans. X = variable generates a different query depending on if variable is null or not. - if (param.Value == null) - { - _string.Append("NULL"); - } - else - { - var value = param.Value as IEnumerable; - if (value != null && !(value is string) && !value.Cast().Any()) - { - _string.Append("EmptyList"); - } - else - { - _string.Append(param.Name); - } - } + VisitParameter(param); } else { - if (expression.Value == null) - { - _string.Append("NULL"); - } - else - { - var value = expression.Value as IEnumerable; - if (value != null && !(value is string) && !(value is IQueryable)) - { - _string.Append("{"); - _string.Append(String.Join(",", value.Cast())); - _string.Append("}"); - } - else - { - _string.Append(expression.Value); - } - } + VisitConstantValue(expression.Value); } return base.VisitConstant(expression); } + private void VisitConstantValue(object value) + { + if (value == null) + { + _string.Append("NULL"); + return; + } + + if (value is IEnumerable enumerable && !(value is IQueryable)) + { + _string.Append("{"); + _string.Append(string.Join(",", enumerable.Cast())); + _string.Append("}"); + return; + } + + // When MappedAs is used we have to put all sql types information in the key in order to + // distinct when different precisions/sizes are used. + if (_sessionFactory != null && value is IType type) + { + _string.Append(type.Name); + _string.Append('['); + _string.Append(string.Join(",", type.SqlTypes(_sessionFactory).Select(o => o.ToString()))); + _string.Append(']'); + return; + } + + _string.Append(value); + } + + private void VisitParameter(NamedParameter param) + { + // Nulls generate different query plans. X = variable generates a different query depending on if variable is null or not. + if (param.Value == null) + { + _string.Append("NULL"); + return; + } + + if (param.IsCollection && !((IEnumerable) param.Value).Cast().Any()) + { + _string.Append("EmptyList"); + } + else + { + _string.Append(param.Name); + } + + // Add the type in order to avoid invalid parameter conversions (string -> char) + _string.Append("<"); + _string.Append(param.Value.GetType()); + _string.Append(">"); + } + private T AppendCommas(T expression) where T : Expression { Visit(expression); @@ -159,6 +207,20 @@ protected override Expression VisitMember(MemberExpression expression) return expression; } +#if NETCOREAPP2_0 + protected override Expression VisitInvocation(InvocationExpression expression) + { + if (ExpressionsHelper.TryGetDynamicMemberBinder(expression, out var memberBinder)) + { + Visit(expression.Arguments[1]); + FormatBinder(memberBinder); + return expression; + } + + return base.VisitInvocation(expression); + } +#endif + protected override Expression VisitMethodCall(MethodCallExpression expression) { Visit(expression.Object); @@ -218,8 +280,8 @@ protected override Expression VisitQuerySourceReference(Remotion.Linq.Clauses.Ex protected override Expression VisitDynamic(DynamicExpression expression) { - FormatBinder(expression.Binder); Visit(expression.Arguments, AppendCommas); + FormatBinder(expression.Binder); return expression; } diff --git a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs index 45134248a51..9ee5adedfa5 100644 --- a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs +++ b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs @@ -1,9 +1,11 @@ using System; +using System.Collections; using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; using System.Reflection; using NHibernate.Engine; +using NHibernate.Linq.Functions; using NHibernate.Param; using NHibernate.Type; using NHibernate.Util; @@ -18,23 +20,24 @@ public class ExpressionParameterVisitor : RelinqExpressionVisitor { private readonly Dictionary _parameters = new Dictionary(); private readonly Dictionary _variableParameters = new Dictionary(); + private readonly HashSet _collectionParameters = new HashSet(); private readonly IDictionary _queryVariables; private readonly ISessionFactoryImplementor _sessionFactory; + private readonly ILinqToHqlGeneratorsRegistry _functionRegistry; - private static readonly MethodInfo QueryableSkipDefinition = - ReflectHelper.FastGetMethodDefinition(Queryable.Skip, default(IQueryable), 0); - private static readonly MethodInfo QueryableTakeDefinition = - ReflectHelper.FastGetMethodDefinition(Queryable.Take, default(IQueryable), 0); - private static readonly MethodInfo EnumerableSkipDefinition = - ReflectHelper.FastGetMethodDefinition(Enumerable.Skip, default(IEnumerable), 0); - private static readonly MethodInfo EnumerableTakeDefinition = - ReflectHelper.FastGetMethodDefinition(Enumerable.Take, default(IEnumerable), 0); + private static readonly ISet PagingMethods = new HashSet + { + ReflectionCache.EnumerableMethods.SkipDefinition, + ReflectionCache.EnumerableMethods.TakeDefinition, + ReflectionCache.QueryableMethods.SkipDefinition, + ReflectionCache.QueryableMethods.TakeDefinition + }; - private readonly ICollection _pagingMethods = new HashSet - { - QueryableSkipDefinition, QueryableTakeDefinition, - EnumerableSkipDefinition, EnumerableTakeDefinition - }; + private static readonly ISet LockMethods = new HashSet + { + ReflectHelper.FastGetMethodDefinition(LinqExtensionMethods.WithLock, default(IQueryable), default(LockMode)), + ReflectHelper.FastGetMethodDefinition(LinqExtensionMethods.WithLock, default(IEnumerable), default(LockMode)) + }; // Since v5.3 [Obsolete("Please use overload with preTransformationResult parameter instead.")] @@ -47,6 +50,7 @@ public ExpressionParameterVisitor(PreTransformationResult preTransformationResul { _sessionFactory = preTransformationResult.SessionFactory; _queryVariables = preTransformationResult.QueryVariables; + _functionRegistry = _sessionFactory.Settings.LinqToHqlGeneratorsRegistry; } // Since v5.3 @@ -59,22 +63,19 @@ public static IDictionary Visit(Expression e return visitor._parameters; } - public static Expression Visit( - PreTransformationResult preTransformationResult, - out IDictionary parameters) + public static IDictionary Visit(PreTransformationResult preTransformationResult) { var visitor = new ExpressionParameterVisitor(preTransformationResult); - var expression = visitor.Visit(preTransformationResult.Expression); - parameters = visitor._parameters; - - return expression; + visitor.Visit(preTransformationResult.Expression); + return visitor._parameters; } protected override Expression VisitMethodCall(MethodCallExpression expression) { - if (expression.Method.Name == nameof(LinqExtensionMethods.MappedAs) && expression.Method.DeclaringType == typeof(LinqExtensionMethods)) + if (VisitorUtil.IsMappedAs(expression.Method)) { var rawParameter = Visit(expression.Arguments[0]); + // TODO 6.0: Remove below code and return expression as this logic is now inside ConstantTypeLocator var parameter = rawParameter as ConstantExpression; var type = expression.Arguments[1] as ConstantExpression; if (parameter == null) @@ -95,10 +96,10 @@ protected override Expression VisitMethodCall(MethodCallExpression expression) ? expression.Method.GetGenericMethodDefinition() : expression.Method; - if (_pagingMethods.Contains(method) && !_sessionFactory.Dialect.SupportsVariableLimit) + if ((PagingMethods.Contains(method) && !_sessionFactory.Dialect.SupportsVariableLimit) || LockMethods.Contains(method)) { - //TODO: find a way to make this code cleaner var query = Visit(expression.Arguments[0]); + //TODO 6.0: Remove the below code and return expression var arg = expression.Arguments[1]; if (query == expression.Arguments[0]) @@ -107,6 +108,17 @@ protected override Expression VisitMethodCall(MethodCallExpression expression) return Expression.Call(null, expression.Method, query, arg); } + if (_functionRegistry != null && + _functionRegistry.TryGetGenerator(method, out var generator) && + generator is CollectionContainsGenerator) + { + var argument = method.IsStatic ? expression.Arguments[0] : expression.Object; + if (argument is ConstantExpression constantExpression) + { + _collectionParameters.Add(constantExpression); + } + } + if (VisitorUtil.IsDynamicComponentDictionaryGetter(expression, _sessionFactory)) { return expression; @@ -115,6 +127,20 @@ protected override Expression VisitMethodCall(MethodCallExpression expression) return base.VisitMethodCall(expression); } +#if NETCOREAPP2_0 + protected override Expression VisitInvocation(InvocationExpression expression) + { + if (ExpressionsHelper.TryGetDynamicMemberBinder(expression, out _)) + { + // Avoid adding System.Runtime.CompilerServices.CallSite instance as a parameter + base.Visit(expression.Arguments[1]); + return expression; + } + + return base.VisitInvocation(expression); + } +#endif + protected override Expression VisitConstant(ConstantExpression expression) { if (!_parameters.ContainsKey(expression) && !typeof(IQueryable).IsAssignableFrom(expression.Type) && !IsNullObject(expression)) @@ -125,11 +151,14 @@ protected override Expression VisitConstant(ConstantExpression expression) // We have a bit more information about the null parameter value. // Figure out a type so that HQL doesn't break on the null. (Related to NH-2430) + // In v5.3 types are calculated by ConstantTypeLocator, this logic is only for back compatibility. + // TODO 6.0: Remove if (expression.Value == null) type = NHibernateUtil.GuessType(expression.Type); // Constant characters should be sent as strings - if (expression.Type == typeof(char)) + // TODO 6.0: Remove + if (_queryVariables == null && expression.Type == typeof(char)) { value = value.ToString(); } @@ -144,13 +173,13 @@ protected override Expression VisitConstant(ConstantExpression expression) _queryVariables.TryGetValue(expression, out var variable) && !_variableParameters.TryGetValue(variable, out parameter)) { - parameter = new NamedParameter("p" + (_parameters.Count + 1), value, type); + parameter = CreateParameter(expression, value, type); _variableParameters.Add(variable, parameter); } if (parameter == null) { - parameter = new NamedParameter("p" + (_parameters.Count + 1), value, type); + parameter = CreateParameter(expression, value, type); } _parameters.Add(expression, parameter); @@ -161,6 +190,15 @@ protected override Expression VisitConstant(ConstantExpression expression) return base.VisitConstant(expression); } + private NamedParameter CreateParameter(ConstantExpression expression, object value, IType type) + { + var index = _parameters.Count; + var parameterName = "p" + (_parameters.Count + 1); + return _collectionParameters.Contains(expression) + ? new NamedListParameter(parameterName, value, type, index) + : new NamedParameter(parameterName, value, type, index); + } + private static bool IsNullObject(ConstantExpression expression) { return expression.Type == typeof(Object) && expression.Value == null; diff --git a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs index cd9cd49eadb..fa8acbf6ebf 100644 --- a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs @@ -226,18 +226,14 @@ private HqlTreeNode VisitNhNominated(NhNominatedExpression nhNominatedExpression private HqlTreeNode VisitInvocationExpression(InvocationExpression expression) { - //This is an ugly workaround for dynamic expressions. - //Unfortunately we can not tap into the expression tree earlier to intercept the dynamic expression - if (expression.Arguments.Count == 2 && - expression.Arguments[0] is ConstantExpression constant && - constant.Value is CallSite site && - site.Binder is GetMemberBinder binder) +#if NETCOREAPP2_0 + if (ExpressionsHelper.TryGetDynamicMemberBinder(expression, out var binder)) { return _hqlTreeBuilder.Dot( VisitExpression(expression.Arguments[1]).AsExpression(), _hqlTreeBuilder.Ident(binder.Name)); } - +#endif return VisitExpression(expression.Expression); } @@ -564,7 +560,9 @@ protected HqlTreeNode VisitMethodCallExpression(MethodCallExpression expression) throw new NotSupportedException(method.ToString()); } - return generator.BuildHql(method, expression.Object, expression.Arguments, _hqlTreeBuilder, this); + return _parameters.UpdateCanCachePlan( + () => generator.BuildHql(method, expression.Object, expression.Arguments, _hqlTreeBuilder, this), + visitor => visitor.Visit(expression)); } protected HqlTreeNode VisitLambdaExpression(LambdaExpression expression) diff --git a/src/NHibernate/Linq/Visitors/MemberExpressionJoinDetector.cs b/src/NHibernate/Linq/Visitors/MemberExpressionJoinDetector.cs index 580ba3cf00c..019769fccb1 100644 --- a/src/NHibernate/Linq/Visitors/MemberExpressionJoinDetector.cs +++ b/src/NHibernate/Linq/Visitors/MemberExpressionJoinDetector.cs @@ -19,16 +19,18 @@ internal class MemberExpressionJoinDetector : RelinqExpressionVisitor { private readonly IIsEntityDecider _isEntityDecider; private readonly IJoiner _joiner; + private readonly ISessionFactoryImplementor _sessionFactory; private bool _requiresJoinForNonIdentifier; private bool _preventJoinsInConditionalTest; private bool _hasIdentifier; private int _memberExpressionDepth; - public MemberExpressionJoinDetector(IIsEntityDecider isEntityDecider, IJoiner joiner) + public MemberExpressionJoinDetector(IIsEntityDecider isEntityDecider, IJoiner joiner, ISessionFactoryImplementor sessionFactory) { _isEntityDecider = isEntityDecider; _joiner = joiner; + _sessionFactory = sessionFactory; } protected override Expression VisitMember(MemberExpression expression) @@ -55,7 +57,7 @@ protected override Expression VisitMember(MemberExpression expression) ((_requiresJoinForNonIdentifier && !_hasIdentifier) || _memberExpressionDepth > 0) && _joiner.CanAddJoin(expression)) { - var key = ExpressionKeyVisitor.Visit(expression, null); + var key = ExpressionKeyVisitor.Visit(expression, null, _sessionFactory); return _joiner.AddJoin(result, key); } diff --git a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs new file mode 100644 index 00000000000..34326640169 --- /dev/null +++ b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs @@ -0,0 +1,321 @@ +using System.Collections.Generic; +using System.Dynamic; +using System.Linq.Expressions; +using NHibernate.Engine; +using NHibernate.Param; +using NHibernate.Type; +using NHibernate.Util; +using Remotion.Linq; +using Remotion.Linq.Clauses.Expressions; +using Remotion.Linq.Parsing; + +namespace NHibernate.Linq.Visitors +{ + /// + /// Locates parameter actual type based on its usage. + /// + public static class ParameterTypeLocator + { + /// + /// List of for which the should be related to the other side + /// of a (e.g. o.MyEnum == MyEnum.Option -> MyEnum.Option should have o.MyEnum as a related + /// ). + /// + private static readonly HashSet ValidBinaryExpressionTypes = new HashSet + { + ExpressionType.Equal, + ExpressionType.NotEqual, + ExpressionType.GreaterThanOrEqual, + ExpressionType.GreaterThan, + ExpressionType.LessThan, + ExpressionType.LessThanOrEqual, + ExpressionType.Coalesce, + ExpressionType.Assign + }; + + /// + /// List of for which the should be copied across + /// as related (e.g. (o.MyEnum ?? MyEnum.Option) == MyEnum.Option2 -> MyEnum.Option2 should have o.MyEnum as a related + /// ). + /// + private static readonly HashSet NonVoidOperators = new HashSet + { + ExpressionType.Coalesce, + ExpressionType.Conditional + }; + + /// + /// Set query parameter types based on the given query model. + /// + /// The query parameters. + /// The query model. + /// The target entity type. + /// The session factory. + public static void SetParameterTypes( + IDictionary parameters, + QueryModel queryModel, + System.Type targetType, + ISessionFactoryImplementor sessionFactory) + { + SetParameterTypes(parameters, queryModel, targetType, sessionFactory, false); + } + + internal static void SetParameterTypes( + IDictionary parameters, + QueryModel queryModel, + System.Type targetType, + ISessionFactoryImplementor sessionFactory, + bool removeMappedAsCalls) + { + if (parameters.Count == 0) + { + return; + } + + var visitor = new ConstantTypeLocatorVisitor(removeMappedAsCalls, targetType, parameters, sessionFactory); + queryModel.TransformExpressions(visitor.Visit); + + foreach (var pair in visitor.ConstantExpressions) + { + var type = pair.Value; + var constantExpression = pair.Key; + if (!parameters.TryGetValue(constantExpression, out var namedParameter)) + { + continue; + } + + if (type != null) + { + // MappedAs was used + namedParameter.Type = type; + continue; + } + + // In order to get the actual type we have to check first the related member expressions, as + // an enum is translated in a numeric type when used in a BinaryExpression and also it can be mapped as string. + // By getting the type from a related member expression we also get the correct length in case of StringType + // or precision when having a DecimalType. + if (visitor.RelatedExpressions.TryGetValue(constantExpression, out var memberExpressions)) + { + foreach (var memberExpression in memberExpressions) + { + if (ExpressionsHelper.TryGetMappedType( + sessionFactory, + memberExpression, + out type, + out _, + out _, + out _)) + { + break; + } + } + } + + // No related MemberExpressions was found, guess the type by value or its type when null. + if (type == null) + { + type = constantExpression.Value != null + ? ParameterHelper.TryGuessType(constantExpression.Value, sessionFactory, namedParameter.IsCollection) + : ParameterHelper.TryGuessType(constantExpression.Type, sessionFactory, namedParameter.IsCollection); + } + + namedParameter.Type = type; + } + } + + private class ConstantTypeLocatorVisitor : RelinqExpressionVisitor + { + private readonly bool _removeMappedAsCalls; + private readonly System.Type _targetType; + private readonly IDictionary _parameters; + private readonly ISessionFactoryImplementor _sessionFactory; + public readonly Dictionary ConstantExpressions = + new Dictionary(); + public readonly Dictionary> RelatedExpressions = + new Dictionary>(); + + public ConstantTypeLocatorVisitor( + bool removeMappedAsCalls, + System.Type targetType, + IDictionary parameters, + ISessionFactoryImplementor sessionFactory) + { + _removeMappedAsCalls = removeMappedAsCalls; + _targetType = targetType; + _sessionFactory = sessionFactory; + _parameters = parameters; + } + + protected override Expression VisitBinary(BinaryExpression node) + { + node = (BinaryExpression) base.VisitBinary(node); + if (!ValidBinaryExpressionTypes.Contains(node.NodeType)) + { + return node; + } + + var left = Unwrap(node.Left); + var right = Unwrap(node.Right); + if (node.NodeType == ExpressionType.Assign) + { + VisitAssign(left, right); + } + else + { + AddRelatedExpression(node, left, right); + AddRelatedExpression(node, right, left); + } + + return node; + } + + protected override Expression VisitConditional(ConditionalExpression node) + { + node = (ConditionalExpression) base.VisitConditional(node); + var ifTrue = Unwrap(node.IfTrue); + var ifFalse = Unwrap(node.IfFalse); + AddRelatedExpression(node, ifTrue, ifFalse); + AddRelatedExpression(node, ifFalse, ifTrue); + + return node; + } + + protected override Expression VisitMethodCall(MethodCallExpression node) + { + if (VisitorUtil.IsMappedAs(node.Method)) + { + var rawParameter = Visit(node.Arguments[0]); + var parameter = rawParameter as ConstantExpression; + var type = node.Arguments[1] as ConstantExpression; + if (parameter == null) + throw new HibernateException( + $"{nameof(LinqExtensionMethods.MappedAs)} must be called on an expression which can be evaluated as " + + $"{nameof(ConstantExpression)}. It was call on {rawParameter?.GetType().Name ?? "null"} instead."); + if (type == null) + throw new HibernateException( + $"{nameof(LinqExtensionMethods.MappedAs)} type must be supplied as {nameof(ConstantExpression)}. " + + $"It was {node.Arguments[1]?.GetType().Name ?? "null"} instead."); + + ConstantExpressions[parameter] = (IType) type.Value; + + return _removeMappedAsCalls + ? rawParameter + : node; + } + + return base.VisitMethodCall(node); + } + + protected override Expression VisitConstant(ConstantExpression node) + { + if (node.Value is IEntityNameProvider || RelatedExpressions.ContainsKey(node) || !_parameters.ContainsKey(node)) + { + return node; + } + + RelatedExpressions.Add(node, new HashSet()); + ConstantExpressions.Add(node, null); + return node; + } + + public override Expression Visit(Expression node) + { + if (node is SubQueryExpression subQueryExpression) + { + subQueryExpression.QueryModel.TransformExpressions(Visit); + } + + return base.Visit(node); + } + + private void VisitAssign(Expression leftNode, Expression rightNode) + { + // Insert and Update statements have assign expressions, where the left side is a parameter and its name + // represents the property path to be assigned + if (!(leftNode is ParameterExpression parameterExpression) || + !(rightNode is ConstantExpression constantExpression)) + { + return; + } + + var entityName = _sessionFactory.TryGetGuessEntityName(_targetType); + if (entityName == null) + { + return; + } + + var persister = _sessionFactory.GetEntityPersister(entityName); + ConstantExpressions[constantExpression] = persister.EntityMetamodel.GetPropertyType(parameterExpression.Name); + } + + private void AddRelatedExpression(Expression node, Expression left, Expression right) + { + if (left.NodeType == ExpressionType.MemberAccess || + IsDynamicMember(left) || + left is QuerySourceReferenceExpression) + { + AddRelatedExpression(right, left); + if (NonVoidOperators.Contains(node.NodeType)) + { + AddRelatedExpression(node, left); + } + } + + // Copy all found MemberExpressions to the other side + // (e.g. (o.Prop ?? constant1) == constant2 -> copy o.Prop to constant2) + if (RelatedExpressions.TryGetValue(left, out var set)) + { + foreach (var nestedMemberExpression in set) + { + AddRelatedExpression(right, nestedMemberExpression); + if (NonVoidOperators.Contains(node.NodeType)) + { + AddRelatedExpression(node, nestedMemberExpression); + } + } + } + } + + private void AddRelatedExpression(Expression expression, Expression relatedExpression) + { + if (!RelatedExpressions.TryGetValue(expression, out var set)) + { + set = new HashSet(); + RelatedExpressions.Add(expression, set); + } + + set.Add(relatedExpression); + } + + private bool IsDynamicMember(Expression expression) + { + switch (expression) + { +#if NETCOREAPP2_0 + case InvocationExpression invocationExpression: + // session.Query().Where("Properties.Name == @0", "First Product") + return ExpressionsHelper.TryGetDynamicMemberBinder(invocationExpression, out _); +#endif + case DynamicExpression dynamicExpression: + return dynamicExpression.Binder is GetMemberBinder; + case MethodCallExpression methodCallExpression: + // session.Query() where p.Properties["Name"] == "First Product" select p + return VisitorUtil.TryGetPotentialDynamicComponentDictionaryMember(methodCallExpression, out _); + default: + return false; + } + } + + private static Expression Unwrap(Expression expression) + { + if (expression is UnaryExpression unaryExpression) + { + return unaryExpression.Operand; + } + + return expression; + } + } + } +} diff --git a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs index 040e9b38932..555ae8b6d8c 100644 --- a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs +++ b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs @@ -37,7 +37,7 @@ public static ExpressionToHqlTranslationResults GenerateHqlQuery(QueryModel quer SubQueryConditionalExpander.ReWrite(queryModel); } - NestedSelectRewriter.ReWrite(queryModel, parameters.SessionFactory); + NestedSelectRewriter.ReWrite(queryModel, parameters, parameters.SessionFactory); // Remove unnecessary body operators RemoveUnnecessaryBodyOperators.ReWrite(queryModel); @@ -200,8 +200,12 @@ private void AddPostExecuteTransformerForCount() var inputList = Expression.Parameter(inputListType, "inputList"); // Sum has no suitable generic overload, throw in Sum on int, then the code using it // will check and adjust it if it is long instead of int (GetAggregateMethodCall does that). - var aggregateCall = GetAggregateMethodCall(ReflectionCache.EnumerableMethods.SumOnInt, inputListType, elementType, inputList); - _hqlTree.AddPostExecuteTransformer(Expression.Lambda(aggregateCall, inputList)); + var aggregateCall = ConstantParametersRewriter.Rewrite( + GetAggregateMethodCall(ReflectionCache.EnumerableMethods.SumOnInt, inputListType, elementType, inputList), + VisitorParameters, + out var parameter); + + _hqlTree.AddPostExecuteTransformer(Expression.Lambda(aggregateCall, inputList, parameter)); } private void AddPostExecuteTransformerForResultAggregate(MethodInfo aggregateMethodTemplate) @@ -215,8 +219,12 @@ private void AddPostExecuteTransformerForResultAggregate(MethodInfo aggregateMet { var inputListType = typeof(IEnumerable<>).MakeGenericType(elementType); var inputList = Expression.Parameter(inputListType, "inputList"); - var aggregateCall = GetAggregateMethodCall(aggregateMethodTemplate, inputListType, elementType, inputList); - aggregateLambda = Expression.Lambda(aggregateCall, inputList); + var aggregateCall = ConstantParametersRewriter.Rewrite( + GetAggregateMethodCall(aggregateMethodTemplate, inputListType, elementType, inputList), + VisitorParameters, + out var parameter); + + aggregateLambda = Expression.Lambda(aggregateCall, inputList, parameter); } else { @@ -235,8 +243,12 @@ private void AddPostExecuteTransformerForResultAggregate(MethodInfo aggregateMet var aggregateCall = GetAggregateMethodCall(aggregateMethodTemplate, nullableInputListType, nullableElementType, nullableInputList); - var convert = Expression.Convert(aggregateCall, elementType); - aggregateLambda = Expression.Lambda(convert, nullableInputList); + var convert = ConstantParametersRewriter.Rewrite( + Expression.Convert(aggregateCall, elementType), + VisitorParameters, + out var parameter); + + aggregateLambda = Expression.Lambda(convert, nullableInputList, parameter); } _hqlTree.AddPostExecuteTransformer(aggregateLambda); @@ -276,7 +288,13 @@ private void AddPostExecuteTransformerForSum() aggregateCall); if (!elementTypeIsNullable) conditionalAggregateCall = Expression.Convert(conditionalAggregateCall, elementType); - _hqlTree.AddPostExecuteTransformer(Expression.Lambda(conditionalAggregateCall, inputList)); + + conditionalAggregateCall = ConstantParametersRewriter.Rewrite( + conditionalAggregateCall, + VisitorParameters, + out var parameter); + + _hqlTree.AddPostExecuteTransformer(Expression.Lambda(conditionalAggregateCall, inputList, parameter)); } private MethodCallExpression GetAggregateMethodCall(MethodInfo aggregateMethodTemplate, System.Type inputListType, diff --git a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAggregate.cs b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAggregate.cs index 574533820b7..496b237d7f1 100644 --- a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAggregate.cs +++ b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAggregate.cs @@ -27,14 +27,14 @@ public void Process(AggregateResultOperator resultOperator, QueryModelVisitor qu // queries, this is not specific to LINQ provider.) var inputList = Expression.Parameter(typeof(IEnumerable<>).MakeGenericType(inputType), "inputList"); var aggregate = ReflectionCache.EnumerableMethods.AggregateDefinition.MakeGenericMethod(inputType); - MethodCallExpression call = Expression.Call( - aggregate, - inputList, - accumulatorFunc - ); - tree.AddPostExecuteTransformer(Expression.Lambda(call, inputList)); + var call = ConstantParametersRewriter.Rewrite( + Expression.Call(aggregate, inputList, accumulatorFunc), + queryModelVisitor.VisitorParameters, + out var parameter); + + tree.AddPostExecuteTransformer(Expression.Lambda(call, inputList, parameter)); // There is no more a list transformer yielding an IList, but this aggregate case // have inputType = resultType, so no further action is required. } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAggregateFromSeed.cs b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAggregateFromSeed.cs index d369066a575..bdb5d113d72 100644 --- a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAggregateFromSeed.cs +++ b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAggregateFromSeed.cs @@ -22,8 +22,7 @@ public void Process(AggregateFromSeedResultOperator resultOperator, QueryModelVi var accumulatorType = resultOperator.Func.Parameters[0].Type; var inputList = Expression.Parameter(typeof(IEnumerable<>).MakeGenericType(inputType), "inputList"); - MethodCallExpression call; - + Expression call; if (resultOperator.OptionalResultSelector == null) { var aggregate = ReflectionCache.EnumerableMethods.AggregateWithSeedDefinition @@ -51,13 +50,15 @@ public void Process(AggregateFromSeedResultOperator resultOperator, QueryModelVi ); } + call = ConstantParametersRewriter.Rewrite(call, queryModelVisitor.VisitorParameters, out var parameter); + // NH-3850: changed from list transformer (working on IEnumerable) to post execute // transformer (working on IEnumerable) for globally aggregating polymorphic results // instead of aggregating results for each class separately and yielding only the first. // If the aggregation relies on ordering, final result will still be wrong due to // polymorphic results being union-ed without re-ordering. (This is a limitation of all polymorphic // queries, this is not specific to LINQ provider.) - tree.AddPostExecuteTransformer(Expression.Lambda(call, inputList)); + tree.AddPostExecuteTransformer(Expression.Lambda(call, inputList, parameter)); // There is no more a list transformer yielding an IList, have to override the execute // result type. tree.ExecuteResultTypeOverride = inputType; diff --git a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAll.cs b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAll.cs index b22237e6f58..0fa7005bdfa 100644 --- a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAll.cs +++ b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAll.cs @@ -19,11 +19,11 @@ public void Process(AllResultOperator resultOperator, QueryModelVisitor queryMod { tree.AddTakeClause(tree.TreeBuilder.Constant(1)); - Expression, bool>> x = l => !l.Any(); + Expression, object[], bool>> x = (l, p) => !l.Any(); tree.AddListTransformer(x); // NH-3850: Queries with polymorphism yields many results which must be combined. - Expression, bool>> px = l => l.All(r => r); + Expression, object[], bool>> px = (l, p) => l.All(r => r); tree.AddPostExecuteTransformer(px); } else @@ -32,4 +32,4 @@ public void Process(AllResultOperator resultOperator, QueryModelVisitor queryMod } } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAny.cs b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAny.cs index 8712418f94a..4e3a1aa4513 100644 --- a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAny.cs +++ b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAny.cs @@ -15,11 +15,11 @@ public void Process(AnyResultOperator anyOperator, QueryModelVisitor queryModelV { tree.AddTakeClause(tree.TreeBuilder.Constant(1)); - Expression, bool>> x = l => l.Any(); + Expression, object[], bool>> x = (l, p) => l.Any(); tree.AddListTransformer(x); // NH-3850: Queries with polymorphism yields many results which must be combined. - Expression, bool>> px = l => l.Any(r => r); + Expression, object[], bool>> px = (l, p) => l.Any(r => r); tree.AddPostExecuteTransformer(px); } else @@ -28,4 +28,4 @@ public void Process(AnyResultOperator anyOperator, QueryModelVisitor queryModelV } } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessClientSideSelect.cs b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessClientSideSelect.cs index 81bf0fc1c7d..f18f3b43646 100644 --- a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessClientSideSelect.cs +++ b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessClientSideSelect.cs @@ -1,4 +1,5 @@ using System.Collections.Generic; +using System.Linq; using System.Linq.Expressions; using NHibernate.Linq.GroupBy; using NHibernate.Util; @@ -17,10 +18,12 @@ public void Process(ClientSideSelect resultOperator, QueryModelVisitor queryMode var selectMethod = ReflectionCache.EnumerableMethods.SelectDefinition.MakeGenericMethod(new[] { inputType, outputType }); var toListMethod = ReflectionCache.EnumerableMethods.ToListDefinition.MakeGenericMethod(new[] { outputType }); - var lambda = Expression.Lambda( - Expression.Call(toListMethod, - Expression.Call(selectMethod, inputList, resultOperator.SelectClause)), - inputList); + var argument = ConstantParametersRewriter.Rewrite( + Expression.Call(selectMethod, inputList, resultOperator.SelectClause), + queryModelVisitor.VisitorParameters, + out var parameter); + + var lambda = Expression.Lambda(Expression.Call(toListMethod, argument), inputList, parameter); tree.AddListTransformer(lambda); } @@ -32,4 +35,4 @@ public void Process(ClientSideSelect2 resultOperator, QueryModelVisitor queryMod tree.AddListTransformer(resultOperator.SelectClause); } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessFirstOrSingleBase.cs b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessFirstOrSingleBase.cs index a7bca0100b9..970e5055550 100644 --- a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessFirstOrSingleBase.cs +++ b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessFirstOrSingleBase.cs @@ -15,12 +15,11 @@ protected static void AddClientSideEval(MethodInfo target, QueryModelVisitor que var parameter = Expression.Parameter(typeof(IQueryable<>).MakeGenericType(type), null); var lambda = Expression.Lambda( - Expression.Call( - target, - parameter), - parameter); + Expression.Call(target, parameter), + parameter, + Expression.Parameter(typeof(object[]), "parameterValues")); tree.AddPostExecuteTransformer(lambda); } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessLock.cs b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessLock.cs index 5ddbe360b8e..84679996403 100644 --- a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessLock.cs +++ b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessLock.cs @@ -5,7 +5,7 @@ internal class ProcessLock : IResultOperatorProcessor public void Process(LockResultOperator resultOperator, QueryModelVisitor queryModelVisitor, IntermediateHqlTree tree) { var alias = queryModelVisitor.VisitorParameters.QuerySourceNamer.GetName(resultOperator.QuerySource); - tree.AddAdditionalCriteria((q, p) => q.SetLockMode(alias, (LockMode) resultOperator.LockMode.Value)); + tree.AddPreQueryExecuteDelegate((q, p) => q.SetLockMode(alias, (LockMode) resultOperator.LockMode.Value)); } } } diff --git a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessNonAggregatingGroupBy.cs b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessNonAggregatingGroupBy.cs index 21c82a87eb0..c59f0e8b414 100644 --- a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessNonAggregatingGroupBy.cs +++ b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessNonAggregatingGroupBy.cs @@ -37,9 +37,12 @@ public void Process(NonAggregatingGroupBy resultOperator, QueryModelVisitor quer var groupByExpr = Expression.Call(groupByMethod, castToItemExpr, keySelectorExpr, elementSelectorExpr); - var toListExpr = Expression.Call(toList, groupByExpr); + var toListExpr = ConstantParametersRewriter.Rewrite( + Expression.Call(toList, groupByExpr), + queryModelVisitor.VisitorParameters, + out var parameter); - var lambdaExpr = Expression.Lambda(toListExpr, listParameter); + var lambdaExpr = Expression.Lambda(toListExpr, listParameter, parameter); tree.AddListTransformer(lambdaExpr); } diff --git a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ResultOperatorProcessor.cs b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ResultOperatorProcessor.cs index 9ff81003cc1..582dccde27b 100644 --- a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ResultOperatorProcessor.cs +++ b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ResultOperatorProcessor.cs @@ -14,7 +14,9 @@ public ResultOperatorProcessor(IResultOperatorProcessor processor) public override void Process(ResultOperatorBase resultOperator, QueryModelVisitor queryModel, IntermediateHqlTree tree) { - _processor.Process((T)resultOperator, queryModel, tree); + queryModel.VisitorParameters.UpdateCanCachePlan( + () => _processor.Process((T) resultOperator, queryModel, tree), + visitor => resultOperator.TransformExpressions(visitor.Visit)); } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/Visitors/SelectClauseVisitor.cs b/src/NHibernate/Linq/Visitors/SelectClauseVisitor.cs index df1cdfb3daa..886581da12f 100644 --- a/src/NHibernate/Linq/Visitors/SelectClauseVisitor.cs +++ b/src/NHibernate/Linq/Visitors/SelectClauseVisitor.cs @@ -15,6 +15,7 @@ public class SelectClauseVisitor : RelinqExpressionVisitor private readonly HqlTreeBuilder _hqlTreeBuilder = new HqlTreeBuilder(); private HashSet _hqlNodes; private readonly ParameterExpression _inputParameter; + private readonly ParameterExpression _parameterValuesParameter; private readonly VisitorParameters _parameters; private int _iColumn; private List _hqlTreeNodes = new List(); @@ -23,6 +24,7 @@ public class SelectClauseVisitor : RelinqExpressionVisitor public SelectClauseVisitor(System.Type inputType, VisitorParameters parameters) { _inputParameter = Expression.Parameter(inputType, "input"); + _parameterValuesParameter = Expression.Parameter(typeof(object[]), "parameterValues"); _parameters = parameters; _hqlVisitor = new HqlGeneratorExpressionVisitor(_parameters); } @@ -59,7 +61,7 @@ public void VisitSelector(Expression expression) if ((projection != expression) && !_hqlNodes.Contains(expression)) { - ProjectionExpression = Expression.Lambda(projection, _inputParameter); + ProjectionExpression = Expression.Lambda(projection, _inputParameter, _parameterValuesParameter); } // Handle any boolean results in the output nodes @@ -92,6 +94,18 @@ public override Expression Visit(Expression expression) return base.Visit(expression); } + protected override Expression VisitConstant(ConstantExpression expression) + { + if (_parameters.ConstantToParameterMap.TryGetValue(expression, out var namedParameter)) + { + return Expression.Convert( + Expression.ArrayIndex(_parameterValuesParameter, Expression.Constant(namedParameter.Index)), + expression.Type); + } + + return expression; + } + private static readonly MethodInfo ConvertChangeType = ReflectHelper.FastGetMethod(System.Convert.ChangeType, default(object), default(System.Type)); diff --git a/src/NHibernate/Linq/Visitors/VisitorParameters.cs b/src/NHibernate/Linq/Visitors/VisitorParameters.cs index a4d2a1f2c65..a0501f60afa 100644 --- a/src/NHibernate/Linq/Visitors/VisitorParameters.cs +++ b/src/NHibernate/Linq/Visitors/VisitorParameters.cs @@ -1,8 +1,11 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; +using System.Linq; using System.Linq.Expressions; using NHibernate.Engine; using NHibernate.Engine.Query; using NHibernate.Param; +using Remotion.Linq.Parsing; namespace NHibernate.Linq.Visitors { @@ -23,6 +26,8 @@ public class VisitorParameters public QueryMode RootQueryMode { get; } + internal bool CanCachePlan { get; private set; } = true; + public VisitorParameters( ISessionFactoryImplementor sessionFactory, IDictionary constantToParameterMap, @@ -38,5 +43,84 @@ public VisitorParameters( TargetEntityType = targetEntityType; RootQueryMode = rootQueryMode; } + + internal void UpdateCanCachePlan( + System.Action action, + Action visitAction) + { + UpdateCanCachePlan( + () => + { + action(); + return true; + }, + visitAction); + } + internal T UpdateCanCachePlan( + Func function, + Action visitAction) + { + var totalHqlParameters = RequiredHqlParameters.Count; + var result = function(); + var visitor = new ParameterMatcher(this, totalHqlParameters); + if (!visitor.MatchHqlParameters(visitAction)) + { + CanCachePlan = false; + } + + return result; + } + + private class ParameterMatcher : RelinqExpressionVisitor + { + private readonly VisitorParameters _parameters; + private readonly List _namedParameters = new List(); + private readonly int _totalHqlParameters; + + public ParameterMatcher(VisitorParameters parameters, int totalHqlParameters) + { + _parameters = parameters; + _totalHqlParameters = totalHqlParameters; + } + + public bool MatchHqlParameters(Action visitAction) + { + visitAction(this); + if (_namedParameters.Count == 0 && _parameters.RequiredHqlParameters.Count == _totalHqlParameters) + { + return true; + } + + return MatchHqlParameters(_parameters.RequiredHqlParameters.Skip(_totalHqlParameters).ToList()); + } + + protected override Expression VisitConstant(ConstantExpression node) + { + if (_parameters.ConstantToParameterMap.TryGetValue(node, out var parameter)) + { + _namedParameters.Add(parameter); + } + + return base.VisitConstant(node); + } + + private bool MatchHqlParameters(List hqlParameters) + { + if (_namedParameters.Count != hqlParameters.Count) + { + return false; + } + + for (var i = 0; i < hqlParameters.Count; i++) + { + if (_namedParameters[i].Name != hqlParameters[i].Name) + { + return false; + } + } + + return true; + } + } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/Visitors/VisitorUtil.cs b/src/NHibernate/Linq/Visitors/VisitorUtil.cs index 22ac89dd0aa..40dbaeb4fb7 100644 --- a/src/NHibernate/Linq/Visitors/VisitorUtil.cs +++ b/src/NHibernate/Linq/Visitors/VisitorUtil.cs @@ -13,25 +13,12 @@ public static class VisitorUtil { public static bool IsDynamicComponentDictionaryGetter(MethodInfo method, Expression targetObject, IEnumerable arguments, ISessionFactory sessionFactory, out string memberName) { - memberName = null; - - // A dynamic component must be an IDictionary with a string key. - - if (method.Name != "get_Item" || !typeof(IDictionary).IsAssignableFrom(targetObject.Type) && !typeof(IDictionary).IsAssignableFrom(targetObject.Type)) - return false; - - var key = arguments.First() as ConstantExpression; - if (key == null || key.Type != typeof(string)) - return false; - - // The potential member name - memberName = (string)key.Value; - - // Need the owning member (the dictionary). - var member = targetObject as MemberExpression; - if (member == null) + if (!TryGetPotentialDynamicComponentDictionaryMember(method, targetObject, arguments, out memberName)) + { return false; + } + var member = (MemberExpression) targetObject; var memberPath = member.Member.Name; var metaData = sessionFactory.GetClassMetadata(member.Expression.Type); @@ -131,5 +118,42 @@ public static string GetMemberPath(this MemberExpression memberExpression) } return path; } + + internal static bool TryGetPotentialDynamicComponentDictionaryMember(MethodCallExpression expression, out string memberName) + { + return TryGetPotentialDynamicComponentDictionaryMember( + expression.Method, + expression.Object, + expression.Arguments, + out memberName); + } + + internal static bool TryGetPotentialDynamicComponentDictionaryMember( + MethodInfo method, + Expression targetObject, + IEnumerable arguments, + out string memberName) + { + memberName = null; + // A dynamic component must be an IDictionary with a string key. + if (method.Name != "get_Item" || + targetObject.NodeType != ExpressionType.MemberAccess || // Need the owning member (the dictionary). + !(arguments.First() is ConstantExpression key) || + key.Type != typeof(string) || + (!typeof(IDictionary).IsAssignableFrom(targetObject.Type) && !typeof(IDictionary).IsAssignableFrom(targetObject.Type))) + { + return false; + } + + // The potential member name + memberName = (string) key.Value; + return true; + } + + internal static bool IsMappedAs(MethodInfo methodInfo) + { + return methodInfo.Name == nameof(LinqExtensionMethods.MappedAs) && + methodInfo.DeclaringType == typeof(LinqExtensionMethods); + } } } diff --git a/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs b/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs index 886d4e0e2b1..689457a7403 100644 --- a/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs +++ b/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs @@ -62,6 +62,7 @@ internal class WhereJoinDetector : RelinqExpressionVisitor // TODO: There are a number of types of expressions that we didn't handle here due to time constraints. For example, the ?: operator could be checked easily. private readonly IIsEntityDecider _isEntityDecider; private readonly IJoiner _joiner; + private readonly ISessionFactoryImplementor _sessionFactory; private readonly Stack _handled = new Stack(); @@ -71,10 +72,11 @@ internal class WhereJoinDetector : RelinqExpressionVisitor // The following is used for member expressions traversal. private int _memberExpressionDepth; - internal WhereJoinDetector(IIsEntityDecider isEntityDecider, IJoiner joiner) + internal WhereJoinDetector(IIsEntityDecider isEntityDecider, IJoiner joiner, ISessionFactoryImplementor sessionFactory) { _isEntityDecider = isEntityDecider; _joiner = joiner; + _sessionFactory = sessionFactory; } public Expression Transform(Expression expression) @@ -329,7 +331,7 @@ protected override Expression VisitMember(MemberExpression expression) { // Don't add joins for things like a.B == a.C where B and C are entities. // We only need to join B when there's something like a.B.D. - var key = ExpressionKeyVisitor.Visit(expression, null); + var key = ExpressionKeyVisitor.Visit(expression, null, _sessionFactory); if (_memberExpressionDepth > 0 && _joiner.CanAddJoin(expression)) { diff --git a/src/NHibernate/Multi/LinqBatchItem.cs b/src/NHibernate/Multi/LinqBatchItem.cs index 733b3115505..9e9f4ebc72a 100644 --- a/src/NHibernate/Multi/LinqBatchItem.cs +++ b/src/NHibernate/Multi/LinqBatchItem.cs @@ -48,7 +48,7 @@ private static LinqBatchItem GetForQuery(IQueryable query, Exp /// Result type public partial class LinqBatchItem : QueryBatchItem, ILinqBatchItem { - private readonly Delegate _postExecuteTransformer; + private readonly PostResultTransformer _postExecuteTransformer; private readonly System.Type _resultTypeOverride; public LinqBatchItem(IQuery query) : base(query) @@ -57,7 +57,7 @@ public LinqBatchItem(IQuery query) : base(query) internal LinqBatchItem(IQuery query, NhLinqExpression linq) : base(query) { - _postExecuteTransformer = linq.ExpressionToHqlTranslationResults.PostExecuteTransformer; + _postExecuteTransformer = linq.ExpressionToHqlTranslationResults.PostResultTransformer; _resultTypeOverride = linq.ExpressionToHqlTranslationResults.ExecuteResultTypeOverride; } @@ -88,7 +88,7 @@ protected override List DoGetResults() private List GetTransformedResults(IList transformerList) { - var res = _postExecuteTransformer.DynamicInvoke(transformerList.AsQueryable()); + var res = _postExecuteTransformer.Transform(transformerList.AsQueryable()); return new List { (T) res diff --git a/src/NHibernate/Param/NamedListParameter.cs b/src/NHibernate/Param/NamedListParameter.cs new file mode 100644 index 00000000000..30fd53c8e15 --- /dev/null +++ b/src/NHibernate/Param/NamedListParameter.cs @@ -0,0 +1,17 @@ +using NHibernate.Type; + +namespace NHibernate.Param +{ + internal class NamedListParameter : NamedParameter + { + public NamedListParameter(string name, object value, IType elementType) : base(name, value, elementType) + { + } + + internal NamedListParameter(string name, object value, IType type, int index) : base(name, value, type, index) + { + } + + public override bool IsCollection => true; + } +} diff --git a/src/NHibernate/Param/NamedParameter.cs b/src/NHibernate/Param/NamedParameter.cs index b42f69925f0..2285e7ea819 100644 --- a/src/NHibernate/Param/NamedParameter.cs +++ b/src/NHibernate/Param/NamedParameter.cs @@ -11,10 +11,19 @@ public NamedParameter(string name, object value, IType type) Type = type; } + internal NamedParameter(string name, object value, IType type, int index) : this(name, value, type) + { + Index = index; + } + public string Name { get; private set; } public object Value { get; internal set; } public IType Type { get; internal set; } + internal int Index { get; } + + public virtual bool IsCollection => false; + public bool Equals(NamedParameter other) { if (ReferenceEquals(null, other)) @@ -38,4 +47,4 @@ public override int GetHashCode() return (Name != null ? Name.GetHashCode() : 0); } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Util/ExpressionsHelper.cs b/src/NHibernate/Util/ExpressionsHelper.cs index 08a60aeeb66..eebee36c8dd 100644 --- a/src/NHibernate/Util/ExpressionsHelper.cs +++ b/src/NHibernate/Util/ExpressionsHelper.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Dynamic; using System.Linq; using System.Linq.Expressions; using System.Reflection; @@ -15,6 +16,7 @@ using NHibernate.Type; using Remotion.Linq.Clauses; using Remotion.Linq.Clauses.Expressions; +using Remotion.Linq.Parsing; namespace NHibernate.Util { @@ -30,6 +32,33 @@ public static MemberInfo DecodeMemberAccessExpression(Expressi return ((MemberExpression)expression.Body).Member; } +#if NETCOREAPP2_0 + /// + /// Try to retrieve from a reduced expression. + /// + /// The reduced dynamic expression. + /// The out binder parameter. + /// Whether the binder was found. + internal static bool TryGetDynamicMemberBinder(InvocationExpression expression, out GetMemberBinder memberBinder) + { + // This is an ugly workaround for dynamic expressions in .NET Core. In .NET Core a dynamic expression is reduced + // when first visited by a expression visitor that is not a DynamicExpressionVisitor, where in .NET Framework it is never reduced. + // As RelinqExpressionVisitor does not extend DynamicExpressionVisitor, we will always have a reduced dynamic expression in .NET Core. + // Unfortunately we can not tap into the expression tree earlier to intercept the dynamic expression + if (expression.Arguments.Count == 2 && + expression.Arguments[0] is ConstantExpression constant && + constant.Value is CallSite site && + site.Binder is GetMemberBinder binder) + { + memberBinder = binder; + return true; + } + + memberBinder = null; + return false; + } +#endif + /// /// Check whether the given expression represent a variable. /// @@ -635,6 +664,34 @@ protected override Expression VisitMember(MemberExpression node) return base.Visit(node.Expression); } +#if NETCOREAPP2_0 + protected override Expression VisitInvocation(InvocationExpression node) + { + if (TryGetDynamicMemberBinder(node, out var binder)) + { + _memberPaths.Push(new MemberMetadata(binder.Name, _convertType, _hasIndexer)); + _convertType = null; + _hasIndexer = false; + return base.Visit(node.Arguments[1]); + } + + return base.VisitInvocation(node); + } +#endif + + protected override Expression VisitDynamic(DynamicExpression node) + { + if (node.Binder is GetMemberBinder binder) + { + _memberPaths.Push(new MemberMetadata(binder.Name, _convertType, _hasIndexer)); + _convertType = null; + _hasIndexer = false; + return base.Visit(node.Arguments[0]); + } + + return Visit(node); + } + protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression node) { if (node.ReferencedQuerySource is IFromClause fromClause) @@ -721,6 +778,14 @@ protected override Expression VisitMethodCall(MethodCallExpression node) ); } + if (VisitorUtil.TryGetPotentialDynamicComponentDictionaryMember(node, out var memberName)) + { + _memberPaths.Push(new MemberMetadata(memberName, _convertType, _hasIndexer)); + _convertType = null; + _hasIndexer = false; + return base.Visit(node.Object); + } + return Visit(node); } diff --git a/src/NHibernate/Util/ParameterHelper.cs b/src/NHibernate/Util/ParameterHelper.cs new file mode 100644 index 00000000000..d0b6bd14625 --- /dev/null +++ b/src/NHibernate/Util/ParameterHelper.cs @@ -0,0 +1,139 @@ +using System; +using System.Collections; +using System.Linq; +using NHibernate.Engine; +using NHibernate.Proxy; +using NHibernate.Type; + +namespace NHibernate.Util +{ + internal static class ParameterHelper + { + /// + /// Guesses the from the param's value. + /// + /// The object to guess the of. + /// The session factory to search for entity persister. + /// Whether is a collection. + /// An for the object. + /// + /// Thrown when the param is null because the + /// can't be guess from a null value. + /// + public static IType TryGuessType(object param, ISessionFactoryImplementor sessionFactory, bool isCollection) + { + if (param == null) + { + return null; + } + + if (param is IEnumerable enumerable && isCollection) + { + var firstValue = enumerable.Cast().FirstOrDefault(); + return firstValue == null + ? TryGuessType(enumerable.GetCollectionElementType(), sessionFactory) + : TryGuessType(firstValue, sessionFactory, false); + } + + var clazz = NHibernateProxyHelper.GetClassWithoutInitializingProxy(param); + return TryGuessType(clazz, sessionFactory); + } + + /// + /// Guesses the from the param's value. + /// + /// The object to guess the of. + /// The session factory to search for entity persister. + /// An for the object. + /// + /// Thrown when the param is null because the + /// can't be guess from a null value. + /// + public static IType GuessType(object param, ISessionFactoryImplementor sessionFactory) + { + if (param == null) + { + throw new ArgumentNullException(nameof(param), "The IType can not be guessed for a null value."); + } + + System.Type clazz = NHibernateProxyHelper.GetClassWithoutInitializingProxy(param); + return GuessType(clazz, sessionFactory); + } + + /// + /// Guesses the from the . + /// + /// The to guess the of. + /// The session factory to search for entity persister. + /// Whether is a collection. + /// An for the . + /// + /// Thrown when the clazz is null because the + /// can't be guess from a null type. + /// + public static IType TryGuessType(System.Type clazz, ISessionFactoryImplementor sessionFactory, bool isCollection) + { + if (clazz == null) + { + return null; + } + + if (isCollection) + { + return TryGuessType(ReflectHelper.GetCollectionElementType(clazz), sessionFactory, false); + } + + return TryGuessType(clazz, sessionFactory); + } + + /// + /// Guesses the from the . + /// + /// The to guess the of. + /// The session factory to search for entity persister. + /// An for the . + /// + /// Thrown when the clazz is null because the + /// can't be guess from a null type. + /// + public static IType GuessType(System.Type clazz, ISessionFactoryImplementor sessionFactory) + { + if (clazz == null) + { + throw new ArgumentNullException(nameof(clazz), "The IType can not be guessed for a null value."); + } + + return TryGuessType(clazz, sessionFactory) ?? + throw new HibernateException("Could not determine a type for class: " + clazz.AssemblyQualifiedName); + } + + /// + /// Guesses the from the . + /// + /// The to guess the of. + /// The session factory to search for entity persister. + /// An for the . + /// + /// Thrown when the clazz is null because the + /// can't be guess from a null type. + /// + public static IType TryGuessType(System.Type clazz, ISessionFactoryImplementor sessionFactory) + { + if (clazz == null) + { + return null; + } + + var type = TypeFactory.HeuristicType(clazz); + if (type == null || type is SerializableType) + { + if (sessionFactory.TryGetEntityPersister(clazz.FullName) != null) + { + return NHibernateUtil.Entity(clazz); + } + } + + return type; + } + } +} diff --git a/src/NHibernate/Util/ReflectionCache.cs b/src/NHibernate/Util/ReflectionCache.cs index c40a395f98d..47fde15950d 100644 --- a/src/NHibernate/Util/ReflectionCache.cs +++ b/src/NHibernate/Util/ReflectionCache.cs @@ -54,6 +54,11 @@ internal static class EnumerableMethods internal static readonly MethodInfo ToListDefinition = ReflectHelper.FastGetMethodDefinition(Enumerable.ToList, default(IEnumerable)); + + internal static readonly MethodInfo SkipDefinition = + ReflectHelper.FastGetMethodDefinition(Enumerable.Skip, default(IEnumerable), default(int)); + internal static readonly MethodInfo TakeDefinition = + ReflectHelper.FastGetMethodDefinition(Enumerable.Take, default(IEnumerable), default(int)); } internal static class MethodBaseMethods @@ -215,6 +220,11 @@ internal static class QueryableMethods ReflectHelper.FastGetMethodDefinition(Queryable.Average, default(IQueryable), default(Expression>)); internal static readonly MethodInfo AverageWithSelectorOfNullableDecimalDefinition = ReflectHelper.FastGetMethodDefinition(Queryable.Average, default(IQueryable), default(Expression>)); + + internal static readonly MethodInfo SkipDefinition = + ReflectHelper.FastGetMethodDefinition(Queryable.Skip, default(IQueryable), default(int)); + internal static readonly MethodInfo TakeDefinition = + ReflectHelper.FastGetMethodDefinition(Queryable.Take, default(IQueryable), default(int)); } internal static class TypeMethods