diff --git a/src/NHibernate.Test/Async/Linq/ParameterTests.cs b/src/NHibernate.Test/Async/Linq/ParameterTests.cs new file mode 100644 index 00000000000..4fbebe3e78b --- /dev/null +++ b/src/NHibernate.Test/Async/Linq/ParameterTests.cs @@ -0,0 +1,398 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by AsyncGenerator. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Text.RegularExpressions; +using NHibernate.DomainModel.Northwind.Entities; +using NHibernate.Engine.Query; +using NHibernate.Linq; +using NHibernate.Util; +using NUnit.Framework; + +namespace NHibernate.Test.Linq +{ + using System.Threading.Tasks; + using System.Threading; + [TestFixture] + public class ParameterTestsAsync : LinqTestCase + { + [Test] + public async Task UsingArrayParameterTwiceAsync() + { + var ids = new[] {11008, 11019, 11039}; + await (AssertTotalParametersAsync( + db.Orders.Where(o => ids.Contains(o.OrderId) && ids.Contains(o.OrderId)), + ids.Length, + 1)); + } + + [Test] + public async Task UsingTwoArrayParametersAsync() + { + var ids = new[] {11008, 11019, 11039}; + var ids2 = new[] {11008, 11019, 11039}; + await (AssertTotalParametersAsync( + db.Orders.Where(o => ids.Contains(o.OrderId) && ids2.Contains(o.OrderId)), + ids.Length + ids2.Length, + 2)); + } + + [Test] + public async Task UsingListParameterTwiceAsync() + { + var ids = new List {11008, 11019, 11039}; + await (AssertTotalParametersAsync( + db.Orders.Where(o => ids.Contains(o.OrderId) && ids.Contains(o.OrderId)), + ids.Count, + 1)); + } + + [Test] + public async Task UsingTwoListParametersAsync() + { + var ids = new List {11008, 11019, 11039}; + var ids2 = new List {11008, 11019, 11039}; + await (AssertTotalParametersAsync( + db.Orders.Where(o => ids.Contains(o.OrderId) && ids2.Contains(o.OrderId)), + ids.Count + ids2.Count, + 2)); + } + + [Test] + public async Task UsingEntityParameterTwiceAsync() + { + var order = await (db.Orders.FirstAsync()); + await (AssertTotalParametersAsync( + db.Orders.Where(o => o == order && o != order), + 1)); + } + + [Test] + public async Task UsingTwoEntityParametersAsync() + { + var order = await (db.Orders.FirstAsync()); + var order2 = await (db.Orders.FirstAsync()); + await (AssertTotalParametersAsync( + db.Orders.Where(o => o == order && o != order2), + 2)); + } + + [Test] + public async Task UsingValueTypeParameterTwiceAsync() + { + var value = 1; + await (AssertTotalParametersAsync( + db.Orders.Where(o => o.OrderId == value && o.OrderId != value), + 1)); + } + + [Test] + public async Task UsingNegateValueTypeParameterTwiceAsync() + { + var value = 1; + await (AssertTotalParametersAsync( + db.Orders.Where(o => o.OrderId == -value && o.OrderId != -value), + 1)); + } + + [Test] + public async Task UsingNegateValueTypeParameterAsync() + { + var value = 1; + await (AssertTotalParametersAsync( + db.Orders.Where(o => o.OrderId == value && o.OrderId != -value), + 1)); + } + + [Test] + public async Task UsingValueTypeParameterInArrayAsync() + { + var id = 11008; + await (AssertTotalParametersAsync( + db.Orders.Where(o => new[] {id, 11019}.Contains(o.OrderId) && new[] {id, 11019}.Contains(o.OrderId)), + 4, + 2)); + } + + [Test] + public async Task UsingTwoValueTypeParametersAsync() + { + var value = 1; + var value2 = 1; + await (AssertTotalParametersAsync( + db.Orders.Where(o => o.OrderId == value && o.OrderId != value2), + 2)); + } + + [Test] + public async Task UsingStringParameterTwiceAsync() + { + var value = "test"; + await (AssertTotalParametersAsync( + db.Products.Where(o => o.Name == value && o.Name != value), + 1)); + } + + [Test] + public async Task UsingTwoStringParametersAsync() + { + var value = "test"; + var value2 = "test"; + await (AssertTotalParametersAsync( + db.Products.Where(o => o.Name == value && o.Name != value2), + 2)); + } + + [Test] + public async Task UsingObjectPropertyParameterTwiceAsync() + { + var value = new Product {Name = "test"}; + await (AssertTotalParametersAsync( + db.Products.Where(o => o.Name == value.Name && o.Name != value.Name), + 1)); + } + + [Test] + public async Task UsingTwoObjectPropertyParametersAsync() + { + var value = new Product {Name = "test"}; + var value2 = new Product {Name = "test"}; + await (AssertTotalParametersAsync( + db.Products.Where(o => o.Name == value.Name && o.Name != value2.Name), + 2)); + } + + [Test] + public async Task UsingParameterInWhereSkipTakeAsync() + { + var value3 = 1; + var q1 = db.Products.Where(o => o.ProductId < value3).Take(value3).Skip(value3); + await (AssertTotalParametersAsync(q1, 3)); + } + + [Test] + public async Task UsingParameterInTwoWhereAsync() + { + var value3 = 1; + var q1 = db.Products.Where(o => o.ProductId < value3).Where(o => o.ProductId < value3); + await (AssertTotalParametersAsync(q1, 1)); + } + + [Test] + public async Task UsingObjectNestedPropertyParameterTwiceAsync() + { + var value = new Employee {Superior = new Employee {Superior = new Employee {FirstName = "test"}}}; + await (AssertTotalParametersAsync( + db.Employees.Where(o => o.FirstName == value.Superior.Superior.FirstName && o.FirstName != value.Superior.Superior.FirstName), + 1)); + } + + [Test] + public async Task UsingDifferentObjectNestedPropertyParameterAsync() + { + var value = new Employee {Superior = new Employee {FirstName = "test", Superior = new Employee {FirstName = "test"}}}; + await (AssertTotalParametersAsync( + db.Employees.Where(o => o.FirstName == value.Superior.FirstName && o.FirstName != value.Superior.Superior.FirstName), + 2)); + } + + [Test] + public async Task UsingMethodObjectPropertyParameterTwiceAsync() + { + var value = new Product {Name = "test"}; + await (AssertTotalParametersAsync( + db.Products.Where(o => o.Name == value.Name.Trim() && o.Name != value.Name.Trim()), + 2)); + } + + [Test] + public async Task UsingStaticMethodObjectPropertyParameterTwiceAsync() + { + var value = new Product {Name = "test"}; + await (AssertTotalParametersAsync( + db.Products.Where(o => o.Name == string.Copy(value.Name) && o.Name != string.Copy(value.Name)), + 2)); + } + + [Test] + public async Task UsingObjectPropertyParameterWithSecondLevelClosureAsync() + { + var value = new Product {Name = "test"}; + Expression> predicate = o => o.Name == value.Name && o.Name != value.Name; + await (AssertTotalParametersAsync( + db.Products.Where(predicate), + 1)); + } + + [Test] + public async Task UsingObjectPropertyParameterWithThirdLevelClosureAsync() + { + var value = new Product {Name = "test"}; + Expression> orderLinePredicate = o => o.Order.ShippedTo == value.Name && o.Order.ShippedTo != value.Name; + Expression> predicate = o => o.Name == value.Name && o.OrderLines.AsQueryable().Any(orderLinePredicate); + await (AssertTotalParametersAsync( + db.Products.Where(predicate), + 1)); + } + + [Test] + public async Task UsingParameterInDMLInsertIntoFourTimesAsync() + { + var value = "test"; + await (AssertTotalParametersAsync( + QueryMode.Insert, + db.Customers.Where(c => c.CustomerId == value), + x => new Customer {CustomerId = value, ContactName = value, CompanyName = value}, + 4)); + } + + [Test] + public async Task UsingFourParametersInDMLInsertIntoAsync() + { + var value = "test"; + var value2 = "test"; + var value3 = "test"; + var value4 = "test"; + await (AssertTotalParametersAsync( + QueryMode.Insert, + db.Customers.Where(c => c.CustomerId == value3), + x => new Customer {CustomerId = value4, ContactName = value2, CompanyName = value}, + 4)); + } + + [Test] + public async Task UsingParameterInDMLUpdateThreeTimesAsync() + { + var value = "test"; + await (AssertTotalParametersAsync( + QueryMode.Update, + db.Customers.Where(c => c.CustomerId == value), + x => new Customer {ContactName = value, CompanyName = value}, + 3)); + } + + [Test] + public async Task UsingThreeParametersInDMLUpdateAsync() + { + var value = "test"; + var value2 = "test"; + var value3 = "test"; + await (AssertTotalParametersAsync( + QueryMode.Update, + db.Customers.Where(c => c.CustomerId == value3), + x => new Customer { ContactName = value2, CompanyName = value }, + 3)); + } + + [Test] + public async Task UsingParameterInDMLDeleteTwiceAsync() + { + var value = "test"; + await (AssertTotalParametersAsync( + QueryMode.Delete, + db.Customers.Where(c => c.CustomerId == value && c.CompanyName == value), + 2)); + } + + [Test] + public async Task UsingTwoParametersInDMLDeleteAsync() + { + var value = "test"; + var value2 = "test"; + await (AssertTotalParametersAsync( + QueryMode.Delete, + db.Customers.Where(c => c.CustomerId == value && c.CompanyName == value2), + 2)); + } + + private async Task AssertTotalParametersAsync(IQueryable query, int parameterNumber, int? linqParameterNumber = null, CancellationToken cancellationToken = default(CancellationToken)) + { + using (var sqlSpy = new SqlLogSpy()) + { + // In case of arrays linqParameterNumber and parameterNumber will be different + Assert.That( + GetLinqExpression(query).ParameterValuesByName.Count, + Is.EqualTo(linqParameterNumber ?? parameterNumber), + "Linq expression has different number of parameters"); + + var queryPlanCacheType = typeof(QueryPlanCache); + var cache = (SoftLimitMRUCache) + queryPlanCacheType + .GetField("planCache", BindingFlags.Instance | BindingFlags.NonPublic) + .GetValue(Sfi.QueryPlanCache); + cache.Clear(); + + await (query.ToListAsync(cancellationToken)); + + // In case of arrays two query plans will be stored, one with an one without expended parameters + Assert.That(cache, Has.Count.EqualTo(linqParameterNumber.HasValue ? 2 : 1), "Query should be cacheable"); + + AssertParameters(sqlSpy, parameterNumber); + } + } + + private static Task AssertTotalParametersAsync(QueryMode queryMode, IQueryable query, int parameterNumber, CancellationToken cancellationToken = default(CancellationToken)) + { + return AssertTotalParametersAsync(queryMode, query, null, parameterNumber, cancellationToken); + } + + private static async Task AssertTotalParametersAsync(QueryMode queryMode, IQueryable query, Expression> expression, int parameterNumber, CancellationToken cancellationToken = default(CancellationToken)) + { + var provider = query.Provider as INhQueryProvider; + Assert.That(provider, Is.Not.Null); + + var dmlExpression = expression != null + ? DmlExpressionRewriter.PrepareExpression(query.Expression, expression) + : query.Expression; + + using (var sqlSpy = new SqlLogSpy()) + { + Assert.That(await (provider.ExecuteDmlAsync(queryMode, dmlExpression, cancellationToken)), Is.EqualTo(0), "The DML query updated the data"); // Avoid updating the data + AssertParameters(sqlSpy, parameterNumber); + } + } + + private static void AssertParameters(SqlLogSpy sqlSpy, int parameterNumber) + { + var sqlParameters = sqlSpy.GetWholeLog().Split(';')[1]; + var matches = Regex.Matches(sqlParameters, @"([\d\w]+)[\s]+\=", RegexOptions.IgnoreCase); + + // Due to ODBC drivers not supporting parameter names, we have to do a distinct of parameter names. + var distinctParameters = matches.OfType().Select(m => m.Groups[1].Value.Trim()).Distinct().ToList(); + Assert.That(distinctParameters, Has.Count.EqualTo(parameterNumber)); + } + + private NhLinqExpression GetLinqExpression(QueryMode queryMode, IQueryable query, Expression> expression) + { + return GetLinqExpression(queryMode, DmlExpressionRewriter.PrepareExpression(query.Expression, expression)); + } + + private NhLinqExpression GetLinqExpression(QueryMode queryMode, IQueryable query) + { + return GetLinqExpression(queryMode, query.Expression); + } + + private NhLinqExpression GetLinqExpression(IQueryable query) + { + return GetLinqExpression(QueryMode.Select, query.Expression); + } + + private NhLinqExpression GetLinqExpression(QueryMode queryMode, Expression expression) + { + return queryMode == QueryMode.Select + ? new NhLinqExpression(expression, Sfi) + : new NhLinqDmlExpression(queryMode, expression, Sfi); + } + } +} diff --git a/src/NHibernate.Test/Linq/ConstantTest.cs b/src/NHibernate.Test/Linq/ConstantTest.cs index 1bb6769d6ad..6b693ddbc4a 100644 --- a/src/NHibernate.Test/Linq/ConstantTest.cs +++ b/src/NHibernate.Test/Linq/ConstantTest.cs @@ -215,10 +215,14 @@ public void ConstantInWhereDoesNotCauseManyKeys() var q2 = (from c in db.Customers where c.CustomerId == "ANATR" select c); - var parameters1 = ExpressionParameterVisitor.Visit(q1.Expression, Sfi); - var k1 = ExpressionKeyVisitor.Visit(q1.Expression, parameters1); - var parameters2 = ExpressionParameterVisitor.Visit(q2.Expression, Sfi); - var k2 = ExpressionKeyVisitor.Visit(q2.Expression, parameters2); + var preTransformParameters = new PreTransformationParameters(QueryMode.Select, Sfi); + var preTransformResult = NhRelinqQueryParser.PreTransform(q1.Expression, preTransformParameters); + var expression = ExpressionParameterVisitor.Visit(preTransformResult, out var parameters1); + var k1 = ExpressionKeyVisitor.Visit(expression, parameters1); + + var preTransformResult2 = NhRelinqQueryParser.PreTransform(q2.Expression, preTransformParameters); + var expression2 = ExpressionParameterVisitor.Visit(preTransformResult2, out var parameters2); + var k2 = ExpressionKeyVisitor.Visit(expression2, parameters2); Assert.That(parameters1, Has.Count.GreaterThan(0), "parameters1"); Assert.That(parameters2, Has.Count.GreaterThan(0), "parameters2"); diff --git a/src/NHibernate.Test/Linq/ParameterTests.cs b/src/NHibernate.Test/Linq/ParameterTests.cs new file mode 100644 index 00000000000..920fa565129 --- /dev/null +++ b/src/NHibernate.Test/Linq/ParameterTests.cs @@ -0,0 +1,459 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Text.RegularExpressions; +using NHibernate.DomainModel.Northwind.Entities; +using NHibernate.Engine.Query; +using NHibernate.Linq; +using NHibernate.Util; +using NUnit.Framework; + +namespace NHibernate.Test.Linq +{ + [TestFixture] + public class ParameterTests : LinqTestCase + { + [Test] + public void UsingArrayParameterTwice() + { + var ids = new[] {11008, 11019, 11039}; + AssertTotalParameters( + db.Orders.Where(o => ids.Contains(o.OrderId) && ids.Contains(o.OrderId)), + ids.Length, + 1); + } + + [Test] + public void UsingTwoArrayParameters() + { + var ids = new[] {11008, 11019, 11039}; + var ids2 = new[] {11008, 11019, 11039}; + AssertTotalParameters( + db.Orders.Where(o => ids.Contains(o.OrderId) && ids2.Contains(o.OrderId)), + ids.Length + ids2.Length, + 2); + } + + [Test] + public void UsingListParameterTwice() + { + var ids = new List {11008, 11019, 11039}; + AssertTotalParameters( + db.Orders.Where(o => ids.Contains(o.OrderId) && ids.Contains(o.OrderId)), + ids.Count, + 1); + } + + [Test] + public void UsingTwoListParameters() + { + var ids = new List {11008, 11019, 11039}; + var ids2 = new List {11008, 11019, 11039}; + AssertTotalParameters( + db.Orders.Where(o => ids.Contains(o.OrderId) && ids2.Contains(o.OrderId)), + ids.Count + ids2.Count, + 2); + } + + [Test] + public void UsingEntityParameterTwice() + { + var order = db.Orders.First(); + AssertTotalParameters( + db.Orders.Where(o => o == order && o != order), + 1); + } + + [Test] + public void UsingTwoEntityParameters() + { + var order = db.Orders.First(); + var order2 = db.Orders.First(); + AssertTotalParameters( + db.Orders.Where(o => o == order && o != order2), + 2); + } + + [Test] + public void UsingValueTypeParameterTwice() + { + var value = 1; + AssertTotalParameters( + db.Orders.Where(o => o.OrderId == value && o.OrderId != value), + 1); + } + + [Test] + public void ValidateMixingTwoParametersCacheKeys() + { + var value = 1; + var value2 = 1; + var expression1 = GetLinqExpression(db.Orders.Where(o => o.OrderId == value && o.OrderId != value)); + var expression2 = GetLinqExpression(db.Orders.Where(o => o.OrderId == value && o.OrderId != value2)); + var expression3 = GetLinqExpression(db.Orders.Where(o => o.OrderId == value2 && o.OrderId != value)); + var expression4 = GetLinqExpression(db.Orders.Where(o => o.OrderId == value2 && o.OrderId != value2)); + + Assert.That(expression1.Key, Is.Not.EqualTo(expression2.Key)); + Assert.That(expression1.Key, Is.Not.EqualTo(expression3.Key)); + Assert.That(expression1.Key, Is.EqualTo(expression4.Key)); + + Assert.That(expression2.Key, Is.EqualTo(expression3.Key)); + Assert.That(expression2.Key, Is.Not.EqualTo(expression4.Key)); + + Assert.That(expression3.Key, Is.Not.EqualTo(expression4.Key)); + } + + [Test] + public void UsingNegateValueTypeParameterTwice() + { + var value = 1; + AssertTotalParameters( + db.Orders.Where(o => o.OrderId == -value && o.OrderId != -value), + 1); + } + + [Test] + public void UsingNegateValueTypeParameter() + { + var value = 1; + AssertTotalParameters( + db.Orders.Where(o => o.OrderId == value && o.OrderId != -value), + 1); + } + + [Test] + public void UsingValueTypeParameterInArray() + { + var id = 11008; + AssertTotalParameters( + db.Orders.Where(o => new[] {id, 11019}.Contains(o.OrderId) && new[] {id, 11019}.Contains(o.OrderId)), + 4, + 2); + } + + [Test] + public void UsingTwoValueTypeParameters() + { + var value = 1; + var value2 = 1; + AssertTotalParameters( + db.Orders.Where(o => o.OrderId == value && o.OrderId != value2), + 2); + } + + [Test] + public void UsingStringParameterTwice() + { + var value = "test"; + AssertTotalParameters( + db.Products.Where(o => o.Name == value && o.Name != value), + 1); + } + + [Test] + public void UsingTwoStringParameters() + { + var value = "test"; + var value2 = "test"; + AssertTotalParameters( + db.Products.Where(o => o.Name == value && o.Name != value2), + 2); + } + + [Test] + public void UsingObjectPropertyParameterTwice() + { + var value = new Product {Name = "test"}; + AssertTotalParameters( + db.Products.Where(o => o.Name == value.Name && o.Name != value.Name), + 1); + } + + [Test] + public void UsingTwoObjectPropertyParameters() + { + var value = new Product {Name = "test"}; + var value2 = new Product {Name = "test"}; + AssertTotalParameters( + db.Products.Where(o => o.Name == value.Name && o.Name != value2.Name), + 2); + } + + [Test] + public void UsingParameterInWhereSkipTake() + { + var value3 = 1; + var q1 = db.Products.Where(o => o.ProductId < value3).Take(value3).Skip(value3); + AssertTotalParameters(q1, 3); + } + + [Test] + public void UsingParameterInTwoWhere() + { + var value3 = 1; + var q1 = db.Products.Where(o => o.ProductId < value3).Where(o => o.ProductId < value3); + AssertTotalParameters(q1, 1); + } + + [Test] + public void UsingObjectNestedPropertyParameterTwice() + { + var value = new Employee {Superior = new Employee {Superior = new Employee {FirstName = "test"}}}; + AssertTotalParameters( + db.Employees.Where(o => o.FirstName == value.Superior.Superior.FirstName && o.FirstName != value.Superior.Superior.FirstName), + 1); + } + + [Test] + public void UsingDifferentObjectNestedPropertyParameter() + { + var value = new Employee {Superior = new Employee {FirstName = "test", Superior = new Employee {FirstName = "test"}}}; + AssertTotalParameters( + db.Employees.Where(o => o.FirstName == value.Superior.FirstName && o.FirstName != value.Superior.Superior.FirstName), + 2); + } + + [Test] + public void UsingMethodObjectPropertyParameterTwice() + { + var value = new Product {Name = "test"}; + AssertTotalParameters( + db.Products.Where(o => o.Name == value.Name.Trim() && o.Name != value.Name.Trim()), + 2); + } + + [Test] + public void UsingStaticMethodObjectPropertyParameterTwice() + { + var value = new Product {Name = "test"}; + AssertTotalParameters( + db.Products.Where(o => o.Name == string.Copy(value.Name) && o.Name != string.Copy(value.Name)), + 2); + } + + [Test] + public void UsingObjectPropertyParameterWithSecondLevelClosure() + { + var value = new Product {Name = "test"}; + Expression> predicate = o => o.Name == value.Name && o.Name != value.Name; + AssertTotalParameters( + db.Products.Where(predicate), + 1); + } + + [Test] + public void UsingObjectPropertyParameterWithThirdLevelClosure() + { + var value = new Product {Name = "test"}; + Expression> orderLinePredicate = o => o.Order.ShippedTo == value.Name && o.Order.ShippedTo != value.Name; + Expression> predicate = o => o.Name == value.Name && o.OrderLines.AsQueryable().Any(orderLinePredicate); + AssertTotalParameters( + db.Products.Where(predicate), + 1); + } + + [Test] + public void UsingParameterInDMLInsertIntoFourTimes() + { + var value = "test"; + AssertTotalParameters( + QueryMode.Insert, + db.Customers.Where(c => c.CustomerId == value), + x => new Customer {CustomerId = value, ContactName = value, CompanyName = value}, + 4); + } + + [Test] + public void UsingFourParametersInDMLInsertInto() + { + var value = "test"; + var value2 = "test"; + var value3 = "test"; + var value4 = "test"; + AssertTotalParameters( + QueryMode.Insert, + db.Customers.Where(c => c.CustomerId == value3), + x => new Customer {CustomerId = value4, ContactName = value2, CompanyName = value}, + 4); + } + + [Test] + public void DMLInsertIntoShouldHaveSameCacheKeys() + { + var value = "test"; + var value2 = "test"; + var value3 = "test"; + var value4 = "test"; + var expression1 = GetLinqExpression( + QueryMode.Insert, + db.Customers.Where(c => c.CustomerId == value), + x => new Customer {CustomerId = value, ContactName = value, CompanyName = value}); + var expression2 = GetLinqExpression( + QueryMode.Insert, + db.Customers.Where(c => c.CustomerId == value3), + x => new Customer {CustomerId = value4, ContactName = value2, CompanyName = value}); + + Assert.That(expression1.Key, Is.EqualTo(expression2.Key)); + } + + [Test] + public void UsingParameterInDMLUpdateThreeTimes() + { + var value = "test"; + AssertTotalParameters( + QueryMode.Update, + db.Customers.Where(c => c.CustomerId == value), + x => new Customer {ContactName = value, CompanyName = value}, + 3); + } + + [Test] + public void UsingThreeParametersInDMLUpdate() + { + var value = "test"; + var value2 = "test"; + var value3 = "test"; + AssertTotalParameters( + QueryMode.Update, + db.Customers.Where(c => c.CustomerId == value3), + x => new Customer { ContactName = value2, CompanyName = value }, + 3); + } + + [TestCase(QueryMode.Update)] + [TestCase(QueryMode.UpdateVersioned)] + public void DMLUpdateIntoShouldHaveSameCacheKeys(QueryMode queryMode) + { + var value = "test"; + var value2 = "test"; + var value3 = "test"; + var expression1 = GetLinqExpression( + queryMode, + db.Customers.Where(c => c.CustomerId == value), + x => new Customer {ContactName = value, CompanyName = value}); + var expression2 = GetLinqExpression( + queryMode, + db.Customers.Where(c => c.CustomerId == value3), + x => new Customer {ContactName = value2, CompanyName = value}); + + Assert.That(expression1.Key, Is.EqualTo(expression2.Key)); + } + + [Test] + public void UsingParameterInDMLDeleteTwice() + { + var value = "test"; + AssertTotalParameters( + QueryMode.Delete, + db.Customers.Where(c => c.CustomerId == value && c.CompanyName == value), + 2); + } + + [Test] + public void UsingTwoParametersInDMLDelete() + { + var value = "test"; + var value2 = "test"; + AssertTotalParameters( + QueryMode.Delete, + db.Customers.Where(c => c.CustomerId == value && c.CompanyName == value2), + 2); + } + + [Test] + public void DMLDeleteShouldHaveSameCacheKeys() + { + var value = "test"; + var value2 = "test"; + var expression1 = GetLinqExpression( + QueryMode.Delete, + db.Customers.Where(c => c.CustomerId == value && c.CompanyName == value)); + var expression2 = GetLinqExpression( + QueryMode.Delete, + db.Customers.Where(c => c.CustomerId == value && c.CompanyName == value2)); + + Assert.That(expression1.Key, Is.EqualTo(expression2.Key)); + } + + private void AssertTotalParameters(IQueryable query, int parameterNumber, int? linqParameterNumber = null) + { + using (var sqlSpy = new SqlLogSpy()) + { + // In case of arrays linqParameterNumber and parameterNumber will be different + Assert.That( + GetLinqExpression(query).ParameterValuesByName.Count, + Is.EqualTo(linqParameterNumber ?? parameterNumber), + "Linq expression has different number of parameters"); + + var queryPlanCacheType = typeof(QueryPlanCache); + var cache = (SoftLimitMRUCache) + queryPlanCacheType + .GetField("planCache", BindingFlags.Instance | BindingFlags.NonPublic) + .GetValue(Sfi.QueryPlanCache); + cache.Clear(); + + query.ToList(); + + // In case of arrays two query plans will be stored, one with an one without expended parameters + Assert.That(cache, Has.Count.EqualTo(linqParameterNumber.HasValue ? 2 : 1), "Query should be cacheable"); + + AssertParameters(sqlSpy, parameterNumber); + } + } + + private static void AssertTotalParameters(QueryMode queryMode, IQueryable query, int parameterNumber) + { + AssertTotalParameters(queryMode, query, null, parameterNumber); + } + + private static void AssertTotalParameters(QueryMode queryMode, IQueryable query, Expression> expression, int parameterNumber) + { + var provider = query.Provider as INhQueryProvider; + Assert.That(provider, Is.Not.Null); + + var dmlExpression = expression != null + ? DmlExpressionRewriter.PrepareExpression(query.Expression, expression) + : query.Expression; + + using (var sqlSpy = new SqlLogSpy()) + { + Assert.That(provider.ExecuteDml(queryMode, dmlExpression), Is.EqualTo(0), "The DML query updated the data"); // Avoid updating the data + AssertParameters(sqlSpy, parameterNumber); + } + } + + private static void AssertParameters(SqlLogSpy sqlSpy, int parameterNumber) + { + var sqlParameters = sqlSpy.GetWholeLog().Split(';')[1]; + var matches = Regex.Matches(sqlParameters, @"([\d\w]+)[\s]+\=", RegexOptions.IgnoreCase); + + // Due to ODBC drivers not supporting parameter names, we have to do a distinct of parameter names. + var distinctParameters = matches.OfType().Select(m => m.Groups[1].Value.Trim()).Distinct().ToList(); + Assert.That(distinctParameters, Has.Count.EqualTo(parameterNumber)); + } + + private NhLinqExpression GetLinqExpression(QueryMode queryMode, IQueryable query, Expression> expression) + { + return GetLinqExpression(queryMode, DmlExpressionRewriter.PrepareExpression(query.Expression, expression)); + } + + private NhLinqExpression GetLinqExpression(QueryMode queryMode, IQueryable query) + { + return GetLinqExpression(queryMode, query.Expression); + } + + private NhLinqExpression GetLinqExpression(IQueryable query) + { + return GetLinqExpression(QueryMode.Select, query.Expression); + } + + private NhLinqExpression GetLinqExpression(QueryMode queryMode, Expression expression) + { + return queryMode == QueryMode.Select + ? new NhLinqExpression(expression, Sfi) + : new NhLinqDmlExpression(queryMode, expression, Sfi); + } + } +} diff --git a/src/NHibernate.Test/Linq/TryGetMappedTests.cs b/src/NHibernate.Test/Linq/TryGetMappedTests.cs index b65aa43f701..20610d32bad 100644 --- a/src/NHibernate.Test/Linq/TryGetMappedTests.cs +++ b/src/NHibernate.Test/Linq/TryGetMappedTests.cs @@ -773,8 +773,8 @@ private void AssertResult( expectedComponentType = expectedComponentType ?? (o => o == null); var expression = query.Expression; - NhRelinqQueryParser.PreTransform(expression, Sfi); - var constantToParameterMap = ExpressionParameterVisitor.Visit(expression, Sfi); + var preTransformResult = NhRelinqQueryParser.PreTransform(expression, new PreTransformationParameters(QueryMode.Select, Sfi)); + expression = ExpressionParameterVisitor.Visit(preTransformResult, out var constantToParameterMap); var queryModel = NhRelinqQueryParser.Parse(expression); var requiredHqlParameters = new List(); var visitorParameters = new VisitorParameters( diff --git a/src/NHibernate/Linq/ExpressionTransformers/RemoveCharToIntConversion.cs b/src/NHibernate/Linq/ExpressionTransformers/RemoveCharToIntConversion.cs index 3e0c19d4ca1..aabdea69916 100644 --- a/src/NHibernate/Linq/ExpressionTransformers/RemoveCharToIntConversion.cs +++ b/src/NHibernate/Linq/ExpressionTransformers/RemoveCharToIntConversion.cs @@ -35,6 +35,19 @@ public Expression Transform(BinaryExpression expression) if (!lhsIsConvertExpression && !rhsIsConvertExpression) return expression; + // Variables are not converted to constants (E.g: o.CharProperty == charVariable) + if (lhsIsConvertExpression && rhsIsConvertExpression) + { + var lhsConvertExpression = (UnaryExpression) lhs; + var rhsConvertExpression = (UnaryExpression) rhs; + if (!IsConvertCharToInt(lhsConvertExpression) || !IsConvertCharToInt(rhsConvertExpression)) + { + return expression; + } + + return Expression.MakeBinary(expression.NodeType, lhsConvertExpression.Operand, rhsConvertExpression.Operand); + } + var lhsIsConstantExpression = IsConstantExpression(lhs); var rhsIsConstantExpression = IsConstantExpression(rhs); @@ -43,7 +56,7 @@ public Expression Transform(BinaryExpression expression) var convertExpression = lhsIsConvertExpression ? (UnaryExpression)lhs : (UnaryExpression)rhs; var constantExpression = lhsIsConstantExpression ? (ConstantExpression)lhs : (ConstantExpression)rhs; - if (convertExpression.Type == typeof(int) && convertExpression.Operand.Type == typeof(char) && constantExpression.Type == typeof(int)) + if (IsConvertCharToInt(convertExpression) && constantExpression.Type == typeof(int)) { var constant = Expression.Constant(Convert.ToChar((int)constantExpression.Value)); @@ -56,6 +69,11 @@ public Expression Transform(BinaryExpression expression) return expression; } + private static bool IsConvertCharToInt(UnaryExpression expression) + { + return expression.Type == typeof(int) && expression.Operand.Type == typeof(char); + } + private static bool IsConvertExpression(Expression expression) { return (expression.NodeType == ExpressionType.Convert); @@ -71,4 +89,4 @@ public ExpressionType[] SupportedExpressionTypes get { return _supportedExpressionTypes; } } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/NhLinqDmlExpression.cs b/src/NHibernate/Linq/NhLinqDmlExpression.cs index 1c5bd7bb20c..b246e86caef 100644 --- a/src/NHibernate/Linq/NhLinqDmlExpression.cs +++ b/src/NHibernate/Linq/NhLinqDmlExpression.cs @@ -5,18 +5,15 @@ namespace NHibernate.Linq { public class NhLinqDmlExpression : NhLinqExpression { - protected override QueryMode QueryMode { get; } - /// /// Entity type to insert or update when the expression is a DML. /// protected override System.Type TargetType => typeof(T); public NhLinqDmlExpression(QueryMode queryMode, Expression expression, ISessionFactoryImplementor sessionFactory) - : base(expression, sessionFactory) + : base(queryMode, expression, sessionFactory) { Key = $"{queryMode.ToString().ToUpperInvariant()} {Key}"; - QueryMode = queryMode; } } } diff --git a/src/NHibernate/Linq/NhLinqExpression.cs b/src/NHibernate/Linq/NhLinqExpression.cs index 918b56a37f8..817bfe459e2 100644 --- a/src/NHibernate/Linq/NhLinqExpression.cs +++ b/src/NHibernate/Linq/NhLinqExpression.cs @@ -32,14 +32,23 @@ public class NhLinqExpression : IQueryExpression, ICacheableQueryExpression public ExpressionToHqlTranslationResults ExpressionToHqlTranslationResults { get; private set; } - protected virtual QueryMode QueryMode => QueryMode.Select; + protected virtual QueryMode QueryMode { get; } private readonly Expression _expression; private readonly IDictionary _constantToParameterMap; public NhLinqExpression(Expression expression, ISessionFactoryImplementor sessionFactory) + : this(QueryMode.Select, expression, sessionFactory) { - _expression = NhRelinqQueryParser.PreTransform(expression, sessionFactory); + } + + internal NhLinqExpression(QueryMode queryMode, Expression expression, ISessionFactoryImplementor sessionFactory) + { + QueryMode = queryMode; + var preTransformResult = NhRelinqQueryParser.PreTransform( + expression, + new PreTransformationParameters(queryMode, sessionFactory)); + _expression = preTransformResult.Expression; // We want logging to be as close as possible to the original expression sent from the // application. But if we log before partial evaluation done in PreTransform, the log won't @@ -47,9 +56,9 @@ public NhLinqExpression(Expression expression, ISessionFactoryImplementor sessio // referenced from the main query. LinqLogging.LogExpression("Expression (partially evaluated)", _expression); - _constantToParameterMap = ExpressionParameterVisitor.Visit(ref _expression, sessionFactory); + _expression = ExpressionParameterVisitor.Visit(preTransformResult, out _constantToParameterMap); - ParameterValuesByName = _constantToParameterMap.Values.ToDictionary(p => p.Name, + ParameterValuesByName = _constantToParameterMap.Values.Distinct().ToDictionary(p => p.Name, p => System.Tuple.Create(p.Value, p.Type)); Key = ExpressionKeyVisitor.Visit(_expression, _constantToParameterMap); diff --git a/src/NHibernate/Linq/NhRelinqQueryParser.cs b/src/NHibernate/Linq/NhRelinqQueryParser.cs index bc3e89c598d..56406823ad2 100644 --- a/src/NHibernate/Linq/NhRelinqQueryParser.cs +++ b/src/NHibernate/Linq/NhRelinqQueryParser.cs @@ -7,6 +7,7 @@ using NHibernate.Engine; using NHibernate.Linq.ExpressionTransformers; using NHibernate.Linq.Visitors; +using NHibernate.Param; using NHibernate.Util; using Remotion.Linq; using Remotion.Linq.EagerFetching.Parsing; @@ -53,10 +54,12 @@ static NhRelinqQueryParser() /// /// The expression to transform. /// The transformed expression. - [Obsolete("Use overload with an additional sessionFactory parameter")] + [Obsolete("Use overload with PreTransformationParameters parameter")] public static Expression PreTransform(Expression expression) { - return PreTransform(expression, null); + // In order to keep the old behavior use a DML query mode to skip detecting variables, + // which will then generate parameters for each constant expression + return PreTransform(expression, new PreTransformationParameters(QueryMode.Delete, null)).Expression; } /// @@ -64,13 +67,20 @@ public static Expression PreTransform(Expression expression) /// expression key computing and parsing. /// /// The expression to transform. - /// The session factory. - /// The transformed expression. - public static Expression PreTransform(Expression expression, ISessionFactoryImplementor sessionFactory) + /// The parameters used in the transformation process. + /// that contains the transformed expression. + public static PreTransformationResult PreTransform(Expression expression, PreTransformationParameters parameters) { - var partiallyEvaluatedExpression = - NhPartialEvaluatingExpressionVisitor.EvaluateIndependentSubtrees(expression, sessionFactory); - return PreProcessor.Process(partiallyEvaluatedExpression); + parameters.EvaluatableExpressionFilter = new NhEvaluatableExpressionFilter(parameters.SessionFactory); + parameters.QueryVariables = new Dictionary(); + + var partiallyEvaluatedExpression = NhPartialEvaluatingExpressionVisitor + .EvaluateIndependentSubtrees(expression, parameters); + + return new PreTransformationResult( + PreProcessor.Process(partiallyEvaluatedExpression), + parameters.SessionFactory, + parameters.QueryVariables); } public static QueryModel Parse(Expression expression) diff --git a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs index a44ac2fd398..45134248a51 100644 --- a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs +++ b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs @@ -17,6 +17,8 @@ namespace NHibernate.Linq.Visitors public class ExpressionParameterVisitor : RelinqExpressionVisitor { private readonly Dictionary _parameters = new Dictionary(); + private readonly Dictionary _variableParameters = new Dictionary(); + private readonly IDictionary _queryVariables; private readonly ISessionFactoryImplementor _sessionFactory; private static readonly MethodInfo QueryableSkipDefinition = @@ -34,25 +36,40 @@ public class ExpressionParameterVisitor : RelinqExpressionVisitor EnumerableSkipDefinition, EnumerableTakeDefinition }; + // Since v5.3 + [Obsolete("Please use overload with preTransformationResult parameter instead.")] public ExpressionParameterVisitor(ISessionFactoryImplementor sessionFactory) { _sessionFactory = sessionFactory; } - public static IDictionary Visit(Expression expression, ISessionFactoryImplementor sessionFactory) + public ExpressionParameterVisitor(PreTransformationResult preTransformationResult) { - return Visit(ref expression, sessionFactory); + _sessionFactory = preTransformationResult.SessionFactory; + _queryVariables = preTransformationResult.QueryVariables; } - internal static IDictionary Visit(ref Expression expression, ISessionFactoryImplementor sessionFactory) + // Since v5.3 + [Obsolete("Please use overload with preTransformationResult parameter instead.")] + public static IDictionary Visit(Expression expression, ISessionFactoryImplementor sessionFactory) { var visitor = new ExpressionParameterVisitor(sessionFactory); - - expression = visitor.Visit(expression); + visitor.Visit(expression); return visitor._parameters; } + public static Expression Visit( + PreTransformationResult preTransformationResult, + out IDictionary parameters) + { + var visitor = new ExpressionParameterVisitor(preTransformationResult); + var expression = visitor.Visit(preTransformationResult.Expression); + parameters = visitor._parameters; + + return expression; + } + protected override Expression VisitMethodCall(MethodCallExpression expression) { if (expression.Method.Name == nameof(LinqExtensionMethods.MappedAs) && expression.Method.DeclaringType == typeof(LinqExtensionMethods)) @@ -122,7 +139,23 @@ protected override Expression VisitConstant(ConstantExpression expression) // comes up, it would be nice to combine the HQL parameter type determination code // and the Expression information. - _parameters.Add(expression, new NamedParameter("p" + (_parameters.Count + 1), value, type)); + NamedParameter parameter = null; + if (_queryVariables != null && + _queryVariables.TryGetValue(expression, out var variable) && + !_variableParameters.TryGetValue(variable, out parameter)) + { + parameter = new NamedParameter("p" + (_parameters.Count + 1), value, type); + _variableParameters.Add(variable, parameter); + } + + if (parameter == null) + { + parameter = new NamedParameter("p" + (_parameters.Count + 1), value, type); + } + + _parameters.Add(expression, parameter); + + return base.VisitConstant(expression); } return base.VisitConstant(expression); diff --git a/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs index 5cfb22fa178..45ac8ffcca5 100644 --- a/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs @@ -1,3 +1,20 @@ +// Copyright (c) rubicon IT GmbH, www.rubicon.eu +// +// See the NOTICE file distributed with this work for additional information +// regarding copyright ownership. rubicon licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may not use this +// file except in compliance with the License. You may obtain a copy of the +// License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. +// + using System; using System.Linq; using System.Linq.Expressions; @@ -5,47 +22,173 @@ using NHibernate.Engine; using NHibernate.Linq.Functions; using NHibernate.Util; -using Remotion.Linq.Clauses.Expressions; using Remotion.Linq.Parsing; -using Remotion.Linq.Parsing.ExpressionVisitors; using Remotion.Linq.Parsing.ExpressionVisitors.TreeEvaluation; namespace NHibernate.Linq.Visitors { - internal class NhPartialEvaluatingExpressionVisitor : RelinqExpressionVisitor, IPartialEvaluationExceptionExpressionVisitor + // Copied from Relinq and added logic for detecting and linking variables with evaluated constant expressions + /// + /// Takes an expression tree and first analyzes it for evaluatable subtrees (using ), i.e. + /// subtrees that can be pre-evaluated before actually generating the query. Examples for evaluatable subtrees are operations on constant + /// values (constant folding), access to closure variables (variables used by the LINQ query that are defined in an outer scope), or method + /// calls on known objects or their members. In a second step, it replaces all of the evaluatable subtrees (top-down and non-recursive) by + /// their evaluated counterparts. + /// + /// + /// This visitor visits each tree node at most twice: once via the for analysis and once + /// again to replace nodes if possible (unless the parent node has already been replaced). + /// + internal sealed class NhPartialEvaluatingExpressionVisitor : RelinqExpressionVisitor { - private readonly ISessionFactoryImplementor _sessionFactory; + #region Relinq adjusted code - internal NhPartialEvaluatingExpressionVisitor(ISessionFactoryImplementor sessionFactory) + /// + /// Takes an expression tree and finds and evaluates all its evaluatable subtrees. + /// + public static Expression EvaluateIndependentSubtrees( + Expression expressionTree, + PreTransformationParameters preTransformationParameters) { - _sessionFactory = sessionFactory; + var partialEvaluationInfo = EvaluatableTreeFindingExpressionVisitor.Analyze( + expressionTree, + preTransformationParameters.EvaluatableExpressionFilter); + var visitor = new NhPartialEvaluatingExpressionVisitor(partialEvaluationInfo, preTransformationParameters); + + return visitor.Visit(expressionTree); } + // _partialEvaluationInfo contains a list of the expressions that are safe to be evaluated. + private readonly PartialEvaluationInfo _partialEvaluationInfo; + private readonly PreTransformationParameters _preTransformationParameters; + + private NhPartialEvaluatingExpressionVisitor( + PartialEvaluationInfo partialEvaluationInfo, + PreTransformationParameters preTransformationParameters) + { + _partialEvaluationInfo = partialEvaluationInfo; + _preTransformationParameters = preTransformationParameters; + } + + public override Expression Visit(Expression expression) + { + // Only evaluate expressions which do not use any of the surrounding parameter expressions. Don't evaluate + // lambda expressions (even if you could), we want to analyze those later on. + if (expression == null) + return null; + + if (expression.NodeType == ExpressionType.Lambda || !_partialEvaluationInfo.IsEvaluatableExpression(expression)) + return base.Visit(expression); + + Expression evaluatedExpression; + try + { + evaluatedExpression = EvaluateSubtree(expression); + } + catch (Exception ex) + { + // Evaluation caused an exception. Skip evaluation of this expression and proceed as if it weren't evaluable. + var baseVisitedExpression = base.Visit(expression); + + throw new HibernateException($"Evaluation failure on {baseVisitedExpression}", ex); + } + + if (evaluatedExpression != expression) + { + evaluatedExpression = EvaluateIndependentSubtrees(evaluatedExpression, _preTransformationParameters); + } + + #region NH additions + + // When having multiple level closure, we have to evaluate each closure independently + if (evaluatedExpression is ConstantExpression constantExpression) + { + evaluatedExpression = VisitConstant(constantExpression); + } + + // Variables in expressions are never a constant, they are encapsulated as fields of a compiler generated class. + if (expression.NodeType != ExpressionType.Constant && + _preTransformationParameters.MinimizeParameters && + evaluatedExpression is ConstantExpression variableConstant && + !_preTransformationParameters.QueryVariables.ContainsKey(variableConstant) && + ExpressionsHelper.IsVariable(expression, out var path, out var closureContext)) + { + _preTransformationParameters.QueryVariables.Add(variableConstant, new QueryVariable(path, closureContext)); + } + + #endregion + + return evaluatedExpression; + } + + /// + /// Evaluates an evaluatable subtree, i.e. an independent expression tree that is compilable and executable + /// without any data being passed in. The result of the evaluation is returned as a ; if the subtree + /// is already a , no evaluation is performed. + /// + /// The subtree to be evaluated. + /// A holding the result of the evaluation. + private Expression EvaluateSubtree(Expression subtree) + { + if (subtree.NodeType == ExpressionType.Constant) + { + var constantExpression = (ConstantExpression) subtree; + var valueAsIQueryable = constantExpression.Value as IQueryable; + if (valueAsIQueryable != null && valueAsIQueryable.Expression != constantExpression) + return valueAsIQueryable.Expression; + + return constantExpression; + } + else + { + Expression> lambdaWithoutParameters = Expression.Lambda>(Expression.Convert(subtree, typeof(object))); + var compiledLambda = lambdaWithoutParameters.Compile(); + + object value = compiledLambda(); + return Expression.Constant(value, subtree.Type); + } + } + + #endregion + protected override Expression VisitConstant(ConstantExpression expression) { if (expression.Value is Expression value) { - return EvaluateIndependentSubtrees(value, _sessionFactory); + return EvaluateIndependentSubtrees(value, _preTransformationParameters); } - return base.VisitConstant(expression); } + } - public static Expression EvaluateIndependentSubtrees( - Expression expression, - ISessionFactoryImplementor sessionFactory) + internal struct QueryVariable : IEquatable + { + public QueryVariable(string path, object closureContext) + { + Path = path; + ClosureContext = closureContext; + } + + public string Path { get; } + + public object ClosureContext { get; } + + public override bool Equals(object obj) + { + return obj is QueryVariable other && Equals(other); + } + + public override int GetHashCode() { - var evaluatedExpression = PartialEvaluatingExpressionVisitor.EvaluateIndependentSubtrees( - expression, - new NhEvaluatableExpressionFilter(sessionFactory)); - return new NhPartialEvaluatingExpressionVisitor(sessionFactory).Visit(evaluatedExpression); + unchecked + { + return (Path.GetHashCode() * 397) ^ ClosureContext.GetHashCode(); + } } - public Expression VisitPartialEvaluationException(PartialEvaluationExceptionExpression partialEvaluationExceptionExpression) + public bool Equals(QueryVariable other) { - throw new HibernateException( - $"Evaluation failure on {partialEvaluationExceptionExpression.EvaluatedExpression}", - partialEvaluationExceptionExpression.Exception); + return Path == other.Path && ReferenceEquals(ClosureContext, other.ClosureContext); } } @@ -68,6 +211,11 @@ public override bool IsEvaluatableConstant(ConstantExpression node) return base.IsEvaluatableConstant(node); } + public override bool IsEvaluatableUnary(UnaryExpression node) + { + return !ExpressionsHelper.IsVariable(node.Operand, out _, out _); + } + public override bool IsEvaluatableMember(MemberExpression node) { if (node == null) diff --git a/src/NHibernate/Linq/Visitors/PreTransformationParameters.cs b/src/NHibernate/Linq/Visitors/PreTransformationParameters.cs new file mode 100644 index 00000000000..3f582eef531 --- /dev/null +++ b/src/NHibernate/Linq/Visitors/PreTransformationParameters.cs @@ -0,0 +1,51 @@ +using System.Collections.Generic; +using System.Linq.Expressions; +using NHibernate.Engine; +using Remotion.Linq.Parsing.ExpressionVisitors.TreeEvaluation; + +namespace NHibernate.Linq.Visitors +{ + /// + /// Contains the information needed by to perform an early transformation. + /// + public class PreTransformationParameters + { + /// + /// The default constructor. + /// + /// The query mode of the expression to pre-transform. + /// The session factory used in the pre-transform process. + public PreTransformationParameters(QueryMode queryMode, ISessionFactoryImplementor sessionFactory) + { + QueryMode = queryMode; + SessionFactory = sessionFactory; + // Skip detecting variables for DML queries as HQL does not support reusing parameters for them. + MinimizeParameters = QueryMode == QueryMode.Select; + } + + /// + /// The query mode of the expression to pre-transform. + /// + public QueryMode QueryMode { get; } + + /// + /// The session factory used in the pre-transform process. + /// + public ISessionFactoryImplementor SessionFactory { get; } + + /// + /// Whether to minimize the number of parameters for variables. + /// + public bool MinimizeParameters { get; set; } + + /// + /// The filter which decides whether a part of the expression will be pre-evalauted or not. + /// + internal IEvaluatableExpressionFilter EvaluatableExpressionFilter { get; set; } + + /// + /// A dictionary of that were evaluated from variables. + /// + internal IDictionary QueryVariables { get; set; } + } +} diff --git a/src/NHibernate/Linq/Visitors/PreTransformationResult.cs b/src/NHibernate/Linq/Visitors/PreTransformationResult.cs new file mode 100644 index 00000000000..6f55ddc7bba --- /dev/null +++ b/src/NHibernate/Linq/Visitors/PreTransformationResult.cs @@ -0,0 +1,37 @@ +using System.Collections.Generic; +using System.Linq.Expressions; +using NHibernate.Engine; + +namespace NHibernate.Linq.Visitors +{ + /// + /// The result of method. + /// + public class PreTransformationResult + { + internal PreTransformationResult( + Expression expression, + ISessionFactoryImplementor sessionFactory, + IDictionary queryVariables) + { + Expression = expression; + SessionFactory = sessionFactory; + QueryVariables = queryVariables; + } + + /// + /// The transformed expression. + /// + public Expression Expression { get; } + + /// + /// The session factory used in the pre-transform process. + /// + public ISessionFactoryImplementor SessionFactory { get; } + + /// + /// A dictionary of that were evaluated from variables. + /// + internal IDictionary QueryVariables { get; } + } +} diff --git a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessContains.cs b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessContains.cs index 17fc7850425..169b1211eb5 100644 --- a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessContains.cs +++ b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessContains.cs @@ -59,8 +59,9 @@ private static HqlAlias GetFromAlias(HqlTreeNode node) private static bool IsEmptyList(HqlParameter source, VisitorParameters parameters) { var parameterName = source.NodesPreOrder.Single(n => n is HqlIdent).AstNode.Text; - var parameterValue = parameters.ConstantToParameterMap.Single(p => p.Value.Name == parameterName).Key.Value; + // Multiple constants may be linked to the same parameter, take the first matching parameter + var parameterValue = parameters.ConstantToParameterMap.First(p => p.Value.Name == parameterName).Key.Value; return !((IEnumerable)parameterValue).Cast().Any(); } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Util/ExpressionsHelper.cs b/src/NHibernate/Util/ExpressionsHelper.cs index 376a42ccda0..92476759ce7 100644 --- a/src/NHibernate/Util/ExpressionsHelper.cs +++ b/src/NHibernate/Util/ExpressionsHelper.cs @@ -3,6 +3,7 @@ using System.Linq; using System.Linq.Expressions; using System.Reflection; +using System.Runtime.CompilerServices; using NHibernate.Engine; using NHibernate.Linq; using NHibernate.Linq.Expressions; @@ -28,6 +29,49 @@ public static MemberInfo DecodeMemberAccessExpression(Expressi return ((MemberExpression)expression.Body).Member; } + /// + /// Check whether the given expression represent a variable. + /// + /// The expression to check. + /// The path of the variable. + /// The closure context where the variable is stored. + /// Whether the expression represents a variable. + internal static bool IsVariable(Expression expression, out string path, out object closureContext) + { + Expression childExpression; + string currentPath; + switch (expression) + { + case MemberExpression memberExpression: + childExpression = memberExpression.Expression; + currentPath = memberExpression.Member.Name; + break; + case ConstantExpression constantExpression: + path = null; + if (constantExpression.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) && + Attribute.IsDefined(constantExpression.Type, typeof(CompilerGeneratedAttribute), inherit: true)) + { + closureContext = constantExpression.Value; + return true; + } + + closureContext = null; + return false; + default: + path = null; + closureContext = null; + return false; + } + + if (!IsVariable(childExpression, out path, out closureContext)) + { + return false; + } + + path = path != null ? $"{currentPath}_{path}" : currentPath; + return true; + } + /// /// Get the mapped type for the given expression. ///