diff --git a/src/NHibernate.Test/Async/Linq/ParameterTests.cs b/src/NHibernate.Test/Async/Linq/ParameterTests.cs
new file mode 100644
index 00000000000..4fbebe3e78b
--- /dev/null
+++ b/src/NHibernate.Test/Async/Linq/ParameterTests.cs
@@ -0,0 +1,398 @@
+//------------------------------------------------------------------------------
+//
+// This code was generated by AsyncGenerator.
+//
+// Changes to this file may cause incorrect behavior and will be lost if
+// the code is regenerated.
+//
+//------------------------------------------------------------------------------
+
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Linq.Expressions;
+using System.Reflection;
+using System.Text.RegularExpressions;
+using NHibernate.DomainModel.Northwind.Entities;
+using NHibernate.Engine.Query;
+using NHibernate.Linq;
+using NHibernate.Util;
+using NUnit.Framework;
+
+namespace NHibernate.Test.Linq
+{
+ using System.Threading.Tasks;
+ using System.Threading;
+ [TestFixture]
+ public class ParameterTestsAsync : LinqTestCase
+ {
+ [Test]
+ public async Task UsingArrayParameterTwiceAsync()
+ {
+ var ids = new[] {11008, 11019, 11039};
+ await (AssertTotalParametersAsync(
+ db.Orders.Where(o => ids.Contains(o.OrderId) && ids.Contains(o.OrderId)),
+ ids.Length,
+ 1));
+ }
+
+ [Test]
+ public async Task UsingTwoArrayParametersAsync()
+ {
+ var ids = new[] {11008, 11019, 11039};
+ var ids2 = new[] {11008, 11019, 11039};
+ await (AssertTotalParametersAsync(
+ db.Orders.Where(o => ids.Contains(o.OrderId) && ids2.Contains(o.OrderId)),
+ ids.Length + ids2.Length,
+ 2));
+ }
+
+ [Test]
+ public async Task UsingListParameterTwiceAsync()
+ {
+ var ids = new List {11008, 11019, 11039};
+ await (AssertTotalParametersAsync(
+ db.Orders.Where(o => ids.Contains(o.OrderId) && ids.Contains(o.OrderId)),
+ ids.Count,
+ 1));
+ }
+
+ [Test]
+ public async Task UsingTwoListParametersAsync()
+ {
+ var ids = new List {11008, 11019, 11039};
+ var ids2 = new List {11008, 11019, 11039};
+ await (AssertTotalParametersAsync(
+ db.Orders.Where(o => ids.Contains(o.OrderId) && ids2.Contains(o.OrderId)),
+ ids.Count + ids2.Count,
+ 2));
+ }
+
+ [Test]
+ public async Task UsingEntityParameterTwiceAsync()
+ {
+ var order = await (db.Orders.FirstAsync());
+ await (AssertTotalParametersAsync(
+ db.Orders.Where(o => o == order && o != order),
+ 1));
+ }
+
+ [Test]
+ public async Task UsingTwoEntityParametersAsync()
+ {
+ var order = await (db.Orders.FirstAsync());
+ var order2 = await (db.Orders.FirstAsync());
+ await (AssertTotalParametersAsync(
+ db.Orders.Where(o => o == order && o != order2),
+ 2));
+ }
+
+ [Test]
+ public async Task UsingValueTypeParameterTwiceAsync()
+ {
+ var value = 1;
+ await (AssertTotalParametersAsync(
+ db.Orders.Where(o => o.OrderId == value && o.OrderId != value),
+ 1));
+ }
+
+ [Test]
+ public async Task UsingNegateValueTypeParameterTwiceAsync()
+ {
+ var value = 1;
+ await (AssertTotalParametersAsync(
+ db.Orders.Where(o => o.OrderId == -value && o.OrderId != -value),
+ 1));
+ }
+
+ [Test]
+ public async Task UsingNegateValueTypeParameterAsync()
+ {
+ var value = 1;
+ await (AssertTotalParametersAsync(
+ db.Orders.Where(o => o.OrderId == value && o.OrderId != -value),
+ 1));
+ }
+
+ [Test]
+ public async Task UsingValueTypeParameterInArrayAsync()
+ {
+ var id = 11008;
+ await (AssertTotalParametersAsync(
+ db.Orders.Where(o => new[] {id, 11019}.Contains(o.OrderId) && new[] {id, 11019}.Contains(o.OrderId)),
+ 4,
+ 2));
+ }
+
+ [Test]
+ public async Task UsingTwoValueTypeParametersAsync()
+ {
+ var value = 1;
+ var value2 = 1;
+ await (AssertTotalParametersAsync(
+ db.Orders.Where(o => o.OrderId == value && o.OrderId != value2),
+ 2));
+ }
+
+ [Test]
+ public async Task UsingStringParameterTwiceAsync()
+ {
+ var value = "test";
+ await (AssertTotalParametersAsync(
+ db.Products.Where(o => o.Name == value && o.Name != value),
+ 1));
+ }
+
+ [Test]
+ public async Task UsingTwoStringParametersAsync()
+ {
+ var value = "test";
+ var value2 = "test";
+ await (AssertTotalParametersAsync(
+ db.Products.Where(o => o.Name == value && o.Name != value2),
+ 2));
+ }
+
+ [Test]
+ public async Task UsingObjectPropertyParameterTwiceAsync()
+ {
+ var value = new Product {Name = "test"};
+ await (AssertTotalParametersAsync(
+ db.Products.Where(o => o.Name == value.Name && o.Name != value.Name),
+ 1));
+ }
+
+ [Test]
+ public async Task UsingTwoObjectPropertyParametersAsync()
+ {
+ var value = new Product {Name = "test"};
+ var value2 = new Product {Name = "test"};
+ await (AssertTotalParametersAsync(
+ db.Products.Where(o => o.Name == value.Name && o.Name != value2.Name),
+ 2));
+ }
+
+ [Test]
+ public async Task UsingParameterInWhereSkipTakeAsync()
+ {
+ var value3 = 1;
+ var q1 = db.Products.Where(o => o.ProductId < value3).Take(value3).Skip(value3);
+ await (AssertTotalParametersAsync(q1, 3));
+ }
+
+ [Test]
+ public async Task UsingParameterInTwoWhereAsync()
+ {
+ var value3 = 1;
+ var q1 = db.Products.Where(o => o.ProductId < value3).Where(o => o.ProductId < value3);
+ await (AssertTotalParametersAsync(q1, 1));
+ }
+
+ [Test]
+ public async Task UsingObjectNestedPropertyParameterTwiceAsync()
+ {
+ var value = new Employee {Superior = new Employee {Superior = new Employee {FirstName = "test"}}};
+ await (AssertTotalParametersAsync(
+ db.Employees.Where(o => o.FirstName == value.Superior.Superior.FirstName && o.FirstName != value.Superior.Superior.FirstName),
+ 1));
+ }
+
+ [Test]
+ public async Task UsingDifferentObjectNestedPropertyParameterAsync()
+ {
+ var value = new Employee {Superior = new Employee {FirstName = "test", Superior = new Employee {FirstName = "test"}}};
+ await (AssertTotalParametersAsync(
+ db.Employees.Where(o => o.FirstName == value.Superior.FirstName && o.FirstName != value.Superior.Superior.FirstName),
+ 2));
+ }
+
+ [Test]
+ public async Task UsingMethodObjectPropertyParameterTwiceAsync()
+ {
+ var value = new Product {Name = "test"};
+ await (AssertTotalParametersAsync(
+ db.Products.Where(o => o.Name == value.Name.Trim() && o.Name != value.Name.Trim()),
+ 2));
+ }
+
+ [Test]
+ public async Task UsingStaticMethodObjectPropertyParameterTwiceAsync()
+ {
+ var value = new Product {Name = "test"};
+ await (AssertTotalParametersAsync(
+ db.Products.Where(o => o.Name == string.Copy(value.Name) && o.Name != string.Copy(value.Name)),
+ 2));
+ }
+
+ [Test]
+ public async Task UsingObjectPropertyParameterWithSecondLevelClosureAsync()
+ {
+ var value = new Product {Name = "test"};
+ Expression> predicate = o => o.Name == value.Name && o.Name != value.Name;
+ await (AssertTotalParametersAsync(
+ db.Products.Where(predicate),
+ 1));
+ }
+
+ [Test]
+ public async Task UsingObjectPropertyParameterWithThirdLevelClosureAsync()
+ {
+ var value = new Product {Name = "test"};
+ Expression> orderLinePredicate = o => o.Order.ShippedTo == value.Name && o.Order.ShippedTo != value.Name;
+ Expression> predicate = o => o.Name == value.Name && o.OrderLines.AsQueryable().Any(orderLinePredicate);
+ await (AssertTotalParametersAsync(
+ db.Products.Where(predicate),
+ 1));
+ }
+
+ [Test]
+ public async Task UsingParameterInDMLInsertIntoFourTimesAsync()
+ {
+ var value = "test";
+ await (AssertTotalParametersAsync(
+ QueryMode.Insert,
+ db.Customers.Where(c => c.CustomerId == value),
+ x => new Customer {CustomerId = value, ContactName = value, CompanyName = value},
+ 4));
+ }
+
+ [Test]
+ public async Task UsingFourParametersInDMLInsertIntoAsync()
+ {
+ var value = "test";
+ var value2 = "test";
+ var value3 = "test";
+ var value4 = "test";
+ await (AssertTotalParametersAsync(
+ QueryMode.Insert,
+ db.Customers.Where(c => c.CustomerId == value3),
+ x => new Customer {CustomerId = value4, ContactName = value2, CompanyName = value},
+ 4));
+ }
+
+ [Test]
+ public async Task UsingParameterInDMLUpdateThreeTimesAsync()
+ {
+ var value = "test";
+ await (AssertTotalParametersAsync(
+ QueryMode.Update,
+ db.Customers.Where(c => c.CustomerId == value),
+ x => new Customer {ContactName = value, CompanyName = value},
+ 3));
+ }
+
+ [Test]
+ public async Task UsingThreeParametersInDMLUpdateAsync()
+ {
+ var value = "test";
+ var value2 = "test";
+ var value3 = "test";
+ await (AssertTotalParametersAsync(
+ QueryMode.Update,
+ db.Customers.Where(c => c.CustomerId == value3),
+ x => new Customer { ContactName = value2, CompanyName = value },
+ 3));
+ }
+
+ [Test]
+ public async Task UsingParameterInDMLDeleteTwiceAsync()
+ {
+ var value = "test";
+ await (AssertTotalParametersAsync(
+ QueryMode.Delete,
+ db.Customers.Where(c => c.CustomerId == value && c.CompanyName == value),
+ 2));
+ }
+
+ [Test]
+ public async Task UsingTwoParametersInDMLDeleteAsync()
+ {
+ var value = "test";
+ var value2 = "test";
+ await (AssertTotalParametersAsync(
+ QueryMode.Delete,
+ db.Customers.Where(c => c.CustomerId == value && c.CompanyName == value2),
+ 2));
+ }
+
+ private async Task AssertTotalParametersAsync(IQueryable query, int parameterNumber, int? linqParameterNumber = null, CancellationToken cancellationToken = default(CancellationToken))
+ {
+ using (var sqlSpy = new SqlLogSpy())
+ {
+ // In case of arrays linqParameterNumber and parameterNumber will be different
+ Assert.That(
+ GetLinqExpression(query).ParameterValuesByName.Count,
+ Is.EqualTo(linqParameterNumber ?? parameterNumber),
+ "Linq expression has different number of parameters");
+
+ var queryPlanCacheType = typeof(QueryPlanCache);
+ var cache = (SoftLimitMRUCache)
+ queryPlanCacheType
+ .GetField("planCache", BindingFlags.Instance | BindingFlags.NonPublic)
+ .GetValue(Sfi.QueryPlanCache);
+ cache.Clear();
+
+ await (query.ToListAsync(cancellationToken));
+
+ // In case of arrays two query plans will be stored, one with an one without expended parameters
+ Assert.That(cache, Has.Count.EqualTo(linqParameterNumber.HasValue ? 2 : 1), "Query should be cacheable");
+
+ AssertParameters(sqlSpy, parameterNumber);
+ }
+ }
+
+ private static Task AssertTotalParametersAsync(QueryMode queryMode, IQueryable query, int parameterNumber, CancellationToken cancellationToken = default(CancellationToken))
+ {
+ return AssertTotalParametersAsync(queryMode, query, null, parameterNumber, cancellationToken);
+ }
+
+ private static async Task AssertTotalParametersAsync(QueryMode queryMode, IQueryable query, Expression> expression, int parameterNumber, CancellationToken cancellationToken = default(CancellationToken))
+ {
+ var provider = query.Provider as INhQueryProvider;
+ Assert.That(provider, Is.Not.Null);
+
+ var dmlExpression = expression != null
+ ? DmlExpressionRewriter.PrepareExpression(query.Expression, expression)
+ : query.Expression;
+
+ using (var sqlSpy = new SqlLogSpy())
+ {
+ Assert.That(await (provider.ExecuteDmlAsync(queryMode, dmlExpression, cancellationToken)), Is.EqualTo(0), "The DML query updated the data"); // Avoid updating the data
+ AssertParameters(sqlSpy, parameterNumber);
+ }
+ }
+
+ private static void AssertParameters(SqlLogSpy sqlSpy, int parameterNumber)
+ {
+ var sqlParameters = sqlSpy.GetWholeLog().Split(';')[1];
+ var matches = Regex.Matches(sqlParameters, @"([\d\w]+)[\s]+\=", RegexOptions.IgnoreCase);
+
+ // Due to ODBC drivers not supporting parameter names, we have to do a distinct of parameter names.
+ var distinctParameters = matches.OfType().Select(m => m.Groups[1].Value.Trim()).Distinct().ToList();
+ Assert.That(distinctParameters, Has.Count.EqualTo(parameterNumber));
+ }
+
+ private NhLinqExpression GetLinqExpression(QueryMode queryMode, IQueryable query, Expression> expression)
+ {
+ return GetLinqExpression(queryMode, DmlExpressionRewriter.PrepareExpression(query.Expression, expression));
+ }
+
+ private NhLinqExpression GetLinqExpression(QueryMode queryMode, IQueryable query)
+ {
+ return GetLinqExpression(queryMode, query.Expression);
+ }
+
+ private NhLinqExpression GetLinqExpression(IQueryable query)
+ {
+ return GetLinqExpression(QueryMode.Select, query.Expression);
+ }
+
+ private NhLinqExpression GetLinqExpression(QueryMode queryMode, Expression expression)
+ {
+ return queryMode == QueryMode.Select
+ ? new NhLinqExpression(expression, Sfi)
+ : new NhLinqDmlExpression(queryMode, expression, Sfi);
+ }
+ }
+}
diff --git a/src/NHibernate.Test/Linq/ConstantTest.cs b/src/NHibernate.Test/Linq/ConstantTest.cs
index 1bb6769d6ad..6b693ddbc4a 100644
--- a/src/NHibernate.Test/Linq/ConstantTest.cs
+++ b/src/NHibernate.Test/Linq/ConstantTest.cs
@@ -215,10 +215,14 @@ public void ConstantInWhereDoesNotCauseManyKeys()
var q2 = (from c in db.Customers
where c.CustomerId == "ANATR"
select c);
- var parameters1 = ExpressionParameterVisitor.Visit(q1.Expression, Sfi);
- var k1 = ExpressionKeyVisitor.Visit(q1.Expression, parameters1);
- var parameters2 = ExpressionParameterVisitor.Visit(q2.Expression, Sfi);
- var k2 = ExpressionKeyVisitor.Visit(q2.Expression, parameters2);
+ var preTransformParameters = new PreTransformationParameters(QueryMode.Select, Sfi);
+ var preTransformResult = NhRelinqQueryParser.PreTransform(q1.Expression, preTransformParameters);
+ var expression = ExpressionParameterVisitor.Visit(preTransformResult, out var parameters1);
+ var k1 = ExpressionKeyVisitor.Visit(expression, parameters1);
+
+ var preTransformResult2 = NhRelinqQueryParser.PreTransform(q2.Expression, preTransformParameters);
+ var expression2 = ExpressionParameterVisitor.Visit(preTransformResult2, out var parameters2);
+ var k2 = ExpressionKeyVisitor.Visit(expression2, parameters2);
Assert.That(parameters1, Has.Count.GreaterThan(0), "parameters1");
Assert.That(parameters2, Has.Count.GreaterThan(0), "parameters2");
diff --git a/src/NHibernate.Test/Linq/ParameterTests.cs b/src/NHibernate.Test/Linq/ParameterTests.cs
new file mode 100644
index 00000000000..920fa565129
--- /dev/null
+++ b/src/NHibernate.Test/Linq/ParameterTests.cs
@@ -0,0 +1,459 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Linq.Expressions;
+using System.Reflection;
+using System.Text.RegularExpressions;
+using NHibernate.DomainModel.Northwind.Entities;
+using NHibernate.Engine.Query;
+using NHibernate.Linq;
+using NHibernate.Util;
+using NUnit.Framework;
+
+namespace NHibernate.Test.Linq
+{
+ [TestFixture]
+ public class ParameterTests : LinqTestCase
+ {
+ [Test]
+ public void UsingArrayParameterTwice()
+ {
+ var ids = new[] {11008, 11019, 11039};
+ AssertTotalParameters(
+ db.Orders.Where(o => ids.Contains(o.OrderId) && ids.Contains(o.OrderId)),
+ ids.Length,
+ 1);
+ }
+
+ [Test]
+ public void UsingTwoArrayParameters()
+ {
+ var ids = new[] {11008, 11019, 11039};
+ var ids2 = new[] {11008, 11019, 11039};
+ AssertTotalParameters(
+ db.Orders.Where(o => ids.Contains(o.OrderId) && ids2.Contains(o.OrderId)),
+ ids.Length + ids2.Length,
+ 2);
+ }
+
+ [Test]
+ public void UsingListParameterTwice()
+ {
+ var ids = new List {11008, 11019, 11039};
+ AssertTotalParameters(
+ db.Orders.Where(o => ids.Contains(o.OrderId) && ids.Contains(o.OrderId)),
+ ids.Count,
+ 1);
+ }
+
+ [Test]
+ public void UsingTwoListParameters()
+ {
+ var ids = new List {11008, 11019, 11039};
+ var ids2 = new List {11008, 11019, 11039};
+ AssertTotalParameters(
+ db.Orders.Where(o => ids.Contains(o.OrderId) && ids2.Contains(o.OrderId)),
+ ids.Count + ids2.Count,
+ 2);
+ }
+
+ [Test]
+ public void UsingEntityParameterTwice()
+ {
+ var order = db.Orders.First();
+ AssertTotalParameters(
+ db.Orders.Where(o => o == order && o != order),
+ 1);
+ }
+
+ [Test]
+ public void UsingTwoEntityParameters()
+ {
+ var order = db.Orders.First();
+ var order2 = db.Orders.First();
+ AssertTotalParameters(
+ db.Orders.Where(o => o == order && o != order2),
+ 2);
+ }
+
+ [Test]
+ public void UsingValueTypeParameterTwice()
+ {
+ var value = 1;
+ AssertTotalParameters(
+ db.Orders.Where(o => o.OrderId == value && o.OrderId != value),
+ 1);
+ }
+
+ [Test]
+ public void ValidateMixingTwoParametersCacheKeys()
+ {
+ var value = 1;
+ var value2 = 1;
+ var expression1 = GetLinqExpression(db.Orders.Where(o => o.OrderId == value && o.OrderId != value));
+ var expression2 = GetLinqExpression(db.Orders.Where(o => o.OrderId == value && o.OrderId != value2));
+ var expression3 = GetLinqExpression(db.Orders.Where(o => o.OrderId == value2 && o.OrderId != value));
+ var expression4 = GetLinqExpression(db.Orders.Where(o => o.OrderId == value2 && o.OrderId != value2));
+
+ Assert.That(expression1.Key, Is.Not.EqualTo(expression2.Key));
+ Assert.That(expression1.Key, Is.Not.EqualTo(expression3.Key));
+ Assert.That(expression1.Key, Is.EqualTo(expression4.Key));
+
+ Assert.That(expression2.Key, Is.EqualTo(expression3.Key));
+ Assert.That(expression2.Key, Is.Not.EqualTo(expression4.Key));
+
+ Assert.That(expression3.Key, Is.Not.EqualTo(expression4.Key));
+ }
+
+ [Test]
+ public void UsingNegateValueTypeParameterTwice()
+ {
+ var value = 1;
+ AssertTotalParameters(
+ db.Orders.Where(o => o.OrderId == -value && o.OrderId != -value),
+ 1);
+ }
+
+ [Test]
+ public void UsingNegateValueTypeParameter()
+ {
+ var value = 1;
+ AssertTotalParameters(
+ db.Orders.Where(o => o.OrderId == value && o.OrderId != -value),
+ 1);
+ }
+
+ [Test]
+ public void UsingValueTypeParameterInArray()
+ {
+ var id = 11008;
+ AssertTotalParameters(
+ db.Orders.Where(o => new[] {id, 11019}.Contains(o.OrderId) && new[] {id, 11019}.Contains(o.OrderId)),
+ 4,
+ 2);
+ }
+
+ [Test]
+ public void UsingTwoValueTypeParameters()
+ {
+ var value = 1;
+ var value2 = 1;
+ AssertTotalParameters(
+ db.Orders.Where(o => o.OrderId == value && o.OrderId != value2),
+ 2);
+ }
+
+ [Test]
+ public void UsingStringParameterTwice()
+ {
+ var value = "test";
+ AssertTotalParameters(
+ db.Products.Where(o => o.Name == value && o.Name != value),
+ 1);
+ }
+
+ [Test]
+ public void UsingTwoStringParameters()
+ {
+ var value = "test";
+ var value2 = "test";
+ AssertTotalParameters(
+ db.Products.Where(o => o.Name == value && o.Name != value2),
+ 2);
+ }
+
+ [Test]
+ public void UsingObjectPropertyParameterTwice()
+ {
+ var value = new Product {Name = "test"};
+ AssertTotalParameters(
+ db.Products.Where(o => o.Name == value.Name && o.Name != value.Name),
+ 1);
+ }
+
+ [Test]
+ public void UsingTwoObjectPropertyParameters()
+ {
+ var value = new Product {Name = "test"};
+ var value2 = new Product {Name = "test"};
+ AssertTotalParameters(
+ db.Products.Where(o => o.Name == value.Name && o.Name != value2.Name),
+ 2);
+ }
+
+ [Test]
+ public void UsingParameterInWhereSkipTake()
+ {
+ var value3 = 1;
+ var q1 = db.Products.Where(o => o.ProductId < value3).Take(value3).Skip(value3);
+ AssertTotalParameters(q1, 3);
+ }
+
+ [Test]
+ public void UsingParameterInTwoWhere()
+ {
+ var value3 = 1;
+ var q1 = db.Products.Where(o => o.ProductId < value3).Where(o => o.ProductId < value3);
+ AssertTotalParameters(q1, 1);
+ }
+
+ [Test]
+ public void UsingObjectNestedPropertyParameterTwice()
+ {
+ var value = new Employee {Superior = new Employee {Superior = new Employee {FirstName = "test"}}};
+ AssertTotalParameters(
+ db.Employees.Where(o => o.FirstName == value.Superior.Superior.FirstName && o.FirstName != value.Superior.Superior.FirstName),
+ 1);
+ }
+
+ [Test]
+ public void UsingDifferentObjectNestedPropertyParameter()
+ {
+ var value = new Employee {Superior = new Employee {FirstName = "test", Superior = new Employee {FirstName = "test"}}};
+ AssertTotalParameters(
+ db.Employees.Where(o => o.FirstName == value.Superior.FirstName && o.FirstName != value.Superior.Superior.FirstName),
+ 2);
+ }
+
+ [Test]
+ public void UsingMethodObjectPropertyParameterTwice()
+ {
+ var value = new Product {Name = "test"};
+ AssertTotalParameters(
+ db.Products.Where(o => o.Name == value.Name.Trim() && o.Name != value.Name.Trim()),
+ 2);
+ }
+
+ [Test]
+ public void UsingStaticMethodObjectPropertyParameterTwice()
+ {
+ var value = new Product {Name = "test"};
+ AssertTotalParameters(
+ db.Products.Where(o => o.Name == string.Copy(value.Name) && o.Name != string.Copy(value.Name)),
+ 2);
+ }
+
+ [Test]
+ public void UsingObjectPropertyParameterWithSecondLevelClosure()
+ {
+ var value = new Product {Name = "test"};
+ Expression> predicate = o => o.Name == value.Name && o.Name != value.Name;
+ AssertTotalParameters(
+ db.Products.Where(predicate),
+ 1);
+ }
+
+ [Test]
+ public void UsingObjectPropertyParameterWithThirdLevelClosure()
+ {
+ var value = new Product {Name = "test"};
+ Expression> orderLinePredicate = o => o.Order.ShippedTo == value.Name && o.Order.ShippedTo != value.Name;
+ Expression> predicate = o => o.Name == value.Name && o.OrderLines.AsQueryable().Any(orderLinePredicate);
+ AssertTotalParameters(
+ db.Products.Where(predicate),
+ 1);
+ }
+
+ [Test]
+ public void UsingParameterInDMLInsertIntoFourTimes()
+ {
+ var value = "test";
+ AssertTotalParameters(
+ QueryMode.Insert,
+ db.Customers.Where(c => c.CustomerId == value),
+ x => new Customer {CustomerId = value, ContactName = value, CompanyName = value},
+ 4);
+ }
+
+ [Test]
+ public void UsingFourParametersInDMLInsertInto()
+ {
+ var value = "test";
+ var value2 = "test";
+ var value3 = "test";
+ var value4 = "test";
+ AssertTotalParameters(
+ QueryMode.Insert,
+ db.Customers.Where(c => c.CustomerId == value3),
+ x => new Customer {CustomerId = value4, ContactName = value2, CompanyName = value},
+ 4);
+ }
+
+ [Test]
+ public void DMLInsertIntoShouldHaveSameCacheKeys()
+ {
+ var value = "test";
+ var value2 = "test";
+ var value3 = "test";
+ var value4 = "test";
+ var expression1 = GetLinqExpression(
+ QueryMode.Insert,
+ db.Customers.Where(c => c.CustomerId == value),
+ x => new Customer {CustomerId = value, ContactName = value, CompanyName = value});
+ var expression2 = GetLinqExpression(
+ QueryMode.Insert,
+ db.Customers.Where(c => c.CustomerId == value3),
+ x => new Customer {CustomerId = value4, ContactName = value2, CompanyName = value});
+
+ Assert.That(expression1.Key, Is.EqualTo(expression2.Key));
+ }
+
+ [Test]
+ public void UsingParameterInDMLUpdateThreeTimes()
+ {
+ var value = "test";
+ AssertTotalParameters(
+ QueryMode.Update,
+ db.Customers.Where(c => c.CustomerId == value),
+ x => new Customer {ContactName = value, CompanyName = value},
+ 3);
+ }
+
+ [Test]
+ public void UsingThreeParametersInDMLUpdate()
+ {
+ var value = "test";
+ var value2 = "test";
+ var value3 = "test";
+ AssertTotalParameters(
+ QueryMode.Update,
+ db.Customers.Where(c => c.CustomerId == value3),
+ x => new Customer { ContactName = value2, CompanyName = value },
+ 3);
+ }
+
+ [TestCase(QueryMode.Update)]
+ [TestCase(QueryMode.UpdateVersioned)]
+ public void DMLUpdateIntoShouldHaveSameCacheKeys(QueryMode queryMode)
+ {
+ var value = "test";
+ var value2 = "test";
+ var value3 = "test";
+ var expression1 = GetLinqExpression(
+ queryMode,
+ db.Customers.Where(c => c.CustomerId == value),
+ x => new Customer {ContactName = value, CompanyName = value});
+ var expression2 = GetLinqExpression(
+ queryMode,
+ db.Customers.Where(c => c.CustomerId == value3),
+ x => new Customer {ContactName = value2, CompanyName = value});
+
+ Assert.That(expression1.Key, Is.EqualTo(expression2.Key));
+ }
+
+ [Test]
+ public void UsingParameterInDMLDeleteTwice()
+ {
+ var value = "test";
+ AssertTotalParameters(
+ QueryMode.Delete,
+ db.Customers.Where(c => c.CustomerId == value && c.CompanyName == value),
+ 2);
+ }
+
+ [Test]
+ public void UsingTwoParametersInDMLDelete()
+ {
+ var value = "test";
+ var value2 = "test";
+ AssertTotalParameters(
+ QueryMode.Delete,
+ db.Customers.Where(c => c.CustomerId == value && c.CompanyName == value2),
+ 2);
+ }
+
+ [Test]
+ public void DMLDeleteShouldHaveSameCacheKeys()
+ {
+ var value = "test";
+ var value2 = "test";
+ var expression1 = GetLinqExpression(
+ QueryMode.Delete,
+ db.Customers.Where(c => c.CustomerId == value && c.CompanyName == value));
+ var expression2 = GetLinqExpression(
+ QueryMode.Delete,
+ db.Customers.Where(c => c.CustomerId == value && c.CompanyName == value2));
+
+ Assert.That(expression1.Key, Is.EqualTo(expression2.Key));
+ }
+
+ private void AssertTotalParameters(IQueryable query, int parameterNumber, int? linqParameterNumber = null)
+ {
+ using (var sqlSpy = new SqlLogSpy())
+ {
+ // In case of arrays linqParameterNumber and parameterNumber will be different
+ Assert.That(
+ GetLinqExpression(query).ParameterValuesByName.Count,
+ Is.EqualTo(linqParameterNumber ?? parameterNumber),
+ "Linq expression has different number of parameters");
+
+ var queryPlanCacheType = typeof(QueryPlanCache);
+ var cache = (SoftLimitMRUCache)
+ queryPlanCacheType
+ .GetField("planCache", BindingFlags.Instance | BindingFlags.NonPublic)
+ .GetValue(Sfi.QueryPlanCache);
+ cache.Clear();
+
+ query.ToList();
+
+ // In case of arrays two query plans will be stored, one with an one without expended parameters
+ Assert.That(cache, Has.Count.EqualTo(linqParameterNumber.HasValue ? 2 : 1), "Query should be cacheable");
+
+ AssertParameters(sqlSpy, parameterNumber);
+ }
+ }
+
+ private static void AssertTotalParameters(QueryMode queryMode, IQueryable query, int parameterNumber)
+ {
+ AssertTotalParameters(queryMode, query, null, parameterNumber);
+ }
+
+ private static void AssertTotalParameters(QueryMode queryMode, IQueryable query, Expression> expression, int parameterNumber)
+ {
+ var provider = query.Provider as INhQueryProvider;
+ Assert.That(provider, Is.Not.Null);
+
+ var dmlExpression = expression != null
+ ? DmlExpressionRewriter.PrepareExpression(query.Expression, expression)
+ : query.Expression;
+
+ using (var sqlSpy = new SqlLogSpy())
+ {
+ Assert.That(provider.ExecuteDml(queryMode, dmlExpression), Is.EqualTo(0), "The DML query updated the data"); // Avoid updating the data
+ AssertParameters(sqlSpy, parameterNumber);
+ }
+ }
+
+ private static void AssertParameters(SqlLogSpy sqlSpy, int parameterNumber)
+ {
+ var sqlParameters = sqlSpy.GetWholeLog().Split(';')[1];
+ var matches = Regex.Matches(sqlParameters, @"([\d\w]+)[\s]+\=", RegexOptions.IgnoreCase);
+
+ // Due to ODBC drivers not supporting parameter names, we have to do a distinct of parameter names.
+ var distinctParameters = matches.OfType().Select(m => m.Groups[1].Value.Trim()).Distinct().ToList();
+ Assert.That(distinctParameters, Has.Count.EqualTo(parameterNumber));
+ }
+
+ private NhLinqExpression GetLinqExpression(QueryMode queryMode, IQueryable query, Expression> expression)
+ {
+ return GetLinqExpression(queryMode, DmlExpressionRewriter.PrepareExpression(query.Expression, expression));
+ }
+
+ private NhLinqExpression GetLinqExpression(QueryMode queryMode, IQueryable query)
+ {
+ return GetLinqExpression(queryMode, query.Expression);
+ }
+
+ private NhLinqExpression GetLinqExpression(IQueryable query)
+ {
+ return GetLinqExpression(QueryMode.Select, query.Expression);
+ }
+
+ private NhLinqExpression GetLinqExpression(QueryMode queryMode, Expression expression)
+ {
+ return queryMode == QueryMode.Select
+ ? new NhLinqExpression(expression, Sfi)
+ : new NhLinqDmlExpression(queryMode, expression, Sfi);
+ }
+ }
+}
diff --git a/src/NHibernate.Test/Linq/TryGetMappedTests.cs b/src/NHibernate.Test/Linq/TryGetMappedTests.cs
index b65aa43f701..20610d32bad 100644
--- a/src/NHibernate.Test/Linq/TryGetMappedTests.cs
+++ b/src/NHibernate.Test/Linq/TryGetMappedTests.cs
@@ -773,8 +773,8 @@ private void AssertResult(
expectedComponentType = expectedComponentType ?? (o => o == null);
var expression = query.Expression;
- NhRelinqQueryParser.PreTransform(expression, Sfi);
- var constantToParameterMap = ExpressionParameterVisitor.Visit(expression, Sfi);
+ var preTransformResult = NhRelinqQueryParser.PreTransform(expression, new PreTransformationParameters(QueryMode.Select, Sfi));
+ expression = ExpressionParameterVisitor.Visit(preTransformResult, out var constantToParameterMap);
var queryModel = NhRelinqQueryParser.Parse(expression);
var requiredHqlParameters = new List();
var visitorParameters = new VisitorParameters(
diff --git a/src/NHibernate/Linq/ExpressionTransformers/RemoveCharToIntConversion.cs b/src/NHibernate/Linq/ExpressionTransformers/RemoveCharToIntConversion.cs
index 3e0c19d4ca1..aabdea69916 100644
--- a/src/NHibernate/Linq/ExpressionTransformers/RemoveCharToIntConversion.cs
+++ b/src/NHibernate/Linq/ExpressionTransformers/RemoveCharToIntConversion.cs
@@ -35,6 +35,19 @@ public Expression Transform(BinaryExpression expression)
if (!lhsIsConvertExpression && !rhsIsConvertExpression) return expression;
+ // Variables are not converted to constants (E.g: o.CharProperty == charVariable)
+ if (lhsIsConvertExpression && rhsIsConvertExpression)
+ {
+ var lhsConvertExpression = (UnaryExpression) lhs;
+ var rhsConvertExpression = (UnaryExpression) rhs;
+ if (!IsConvertCharToInt(lhsConvertExpression) || !IsConvertCharToInt(rhsConvertExpression))
+ {
+ return expression;
+ }
+
+ return Expression.MakeBinary(expression.NodeType, lhsConvertExpression.Operand, rhsConvertExpression.Operand);
+ }
+
var lhsIsConstantExpression = IsConstantExpression(lhs);
var rhsIsConstantExpression = IsConstantExpression(rhs);
@@ -43,7 +56,7 @@ public Expression Transform(BinaryExpression expression)
var convertExpression = lhsIsConvertExpression ? (UnaryExpression)lhs : (UnaryExpression)rhs;
var constantExpression = lhsIsConstantExpression ? (ConstantExpression)lhs : (ConstantExpression)rhs;
- if (convertExpression.Type == typeof(int) && convertExpression.Operand.Type == typeof(char) && constantExpression.Type == typeof(int))
+ if (IsConvertCharToInt(convertExpression) && constantExpression.Type == typeof(int))
{
var constant = Expression.Constant(Convert.ToChar((int)constantExpression.Value));
@@ -56,6 +69,11 @@ public Expression Transform(BinaryExpression expression)
return expression;
}
+ private static bool IsConvertCharToInt(UnaryExpression expression)
+ {
+ return expression.Type == typeof(int) && expression.Operand.Type == typeof(char);
+ }
+
private static bool IsConvertExpression(Expression expression)
{
return (expression.NodeType == ExpressionType.Convert);
@@ -71,4 +89,4 @@ public ExpressionType[] SupportedExpressionTypes
get { return _supportedExpressionTypes; }
}
}
-}
\ No newline at end of file
+}
diff --git a/src/NHibernate/Linq/NhLinqDmlExpression.cs b/src/NHibernate/Linq/NhLinqDmlExpression.cs
index 1c5bd7bb20c..b246e86caef 100644
--- a/src/NHibernate/Linq/NhLinqDmlExpression.cs
+++ b/src/NHibernate/Linq/NhLinqDmlExpression.cs
@@ -5,18 +5,15 @@ namespace NHibernate.Linq
{
public class NhLinqDmlExpression : NhLinqExpression
{
- protected override QueryMode QueryMode { get; }
-
///
/// Entity type to insert or update when the expression is a DML.
///
protected override System.Type TargetType => typeof(T);
public NhLinqDmlExpression(QueryMode queryMode, Expression expression, ISessionFactoryImplementor sessionFactory)
- : base(expression, sessionFactory)
+ : base(queryMode, expression, sessionFactory)
{
Key = $"{queryMode.ToString().ToUpperInvariant()} {Key}";
- QueryMode = queryMode;
}
}
}
diff --git a/src/NHibernate/Linq/NhLinqExpression.cs b/src/NHibernate/Linq/NhLinqExpression.cs
index 918b56a37f8..817bfe459e2 100644
--- a/src/NHibernate/Linq/NhLinqExpression.cs
+++ b/src/NHibernate/Linq/NhLinqExpression.cs
@@ -32,14 +32,23 @@ public class NhLinqExpression : IQueryExpression, ICacheableQueryExpression
public ExpressionToHqlTranslationResults ExpressionToHqlTranslationResults { get; private set; }
- protected virtual QueryMode QueryMode => QueryMode.Select;
+ protected virtual QueryMode QueryMode { get; }
private readonly Expression _expression;
private readonly IDictionary _constantToParameterMap;
public NhLinqExpression(Expression expression, ISessionFactoryImplementor sessionFactory)
+ : this(QueryMode.Select, expression, sessionFactory)
{
- _expression = NhRelinqQueryParser.PreTransform(expression, sessionFactory);
+ }
+
+ internal NhLinqExpression(QueryMode queryMode, Expression expression, ISessionFactoryImplementor sessionFactory)
+ {
+ QueryMode = queryMode;
+ var preTransformResult = NhRelinqQueryParser.PreTransform(
+ expression,
+ new PreTransformationParameters(queryMode, sessionFactory));
+ _expression = preTransformResult.Expression;
// We want logging to be as close as possible to the original expression sent from the
// application. But if we log before partial evaluation done in PreTransform, the log won't
@@ -47,9 +56,9 @@ public NhLinqExpression(Expression expression, ISessionFactoryImplementor sessio
// referenced from the main query.
LinqLogging.LogExpression("Expression (partially evaluated)", _expression);
- _constantToParameterMap = ExpressionParameterVisitor.Visit(ref _expression, sessionFactory);
+ _expression = ExpressionParameterVisitor.Visit(preTransformResult, out _constantToParameterMap);
- ParameterValuesByName = _constantToParameterMap.Values.ToDictionary(p => p.Name,
+ ParameterValuesByName = _constantToParameterMap.Values.Distinct().ToDictionary(p => p.Name,
p => System.Tuple.Create(p.Value, p.Type));
Key = ExpressionKeyVisitor.Visit(_expression, _constantToParameterMap);
diff --git a/src/NHibernate/Linq/NhRelinqQueryParser.cs b/src/NHibernate/Linq/NhRelinqQueryParser.cs
index bc3e89c598d..56406823ad2 100644
--- a/src/NHibernate/Linq/NhRelinqQueryParser.cs
+++ b/src/NHibernate/Linq/NhRelinqQueryParser.cs
@@ -7,6 +7,7 @@
using NHibernate.Engine;
using NHibernate.Linq.ExpressionTransformers;
using NHibernate.Linq.Visitors;
+using NHibernate.Param;
using NHibernate.Util;
using Remotion.Linq;
using Remotion.Linq.EagerFetching.Parsing;
@@ -53,10 +54,12 @@ static NhRelinqQueryParser()
///
/// The expression to transform.
/// The transformed expression.
- [Obsolete("Use overload with an additional sessionFactory parameter")]
+ [Obsolete("Use overload with PreTransformationParameters parameter")]
public static Expression PreTransform(Expression expression)
{
- return PreTransform(expression, null);
+ // In order to keep the old behavior use a DML query mode to skip detecting variables,
+ // which will then generate parameters for each constant expression
+ return PreTransform(expression, new PreTransformationParameters(QueryMode.Delete, null)).Expression;
}
///
@@ -64,13 +67,20 @@ public static Expression PreTransform(Expression expression)
/// expression key computing and parsing.
///
/// The expression to transform.
- /// The session factory.
- /// The transformed expression.
- public static Expression PreTransform(Expression expression, ISessionFactoryImplementor sessionFactory)
+ /// The parameters used in the transformation process.
+ /// that contains the transformed expression.
+ public static PreTransformationResult PreTransform(Expression expression, PreTransformationParameters parameters)
{
- var partiallyEvaluatedExpression =
- NhPartialEvaluatingExpressionVisitor.EvaluateIndependentSubtrees(expression, sessionFactory);
- return PreProcessor.Process(partiallyEvaluatedExpression);
+ parameters.EvaluatableExpressionFilter = new NhEvaluatableExpressionFilter(parameters.SessionFactory);
+ parameters.QueryVariables = new Dictionary();
+
+ var partiallyEvaluatedExpression = NhPartialEvaluatingExpressionVisitor
+ .EvaluateIndependentSubtrees(expression, parameters);
+
+ return new PreTransformationResult(
+ PreProcessor.Process(partiallyEvaluatedExpression),
+ parameters.SessionFactory,
+ parameters.QueryVariables);
}
public static QueryModel Parse(Expression expression)
diff --git a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs
index a44ac2fd398..45134248a51 100644
--- a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs
+++ b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs
@@ -17,6 +17,8 @@ namespace NHibernate.Linq.Visitors
public class ExpressionParameterVisitor : RelinqExpressionVisitor
{
private readonly Dictionary _parameters = new Dictionary();
+ private readonly Dictionary _variableParameters = new Dictionary();
+ private readonly IDictionary _queryVariables;
private readonly ISessionFactoryImplementor _sessionFactory;
private static readonly MethodInfo QueryableSkipDefinition =
@@ -34,25 +36,40 @@ public class ExpressionParameterVisitor : RelinqExpressionVisitor
EnumerableSkipDefinition, EnumerableTakeDefinition
};
+ // Since v5.3
+ [Obsolete("Please use overload with preTransformationResult parameter instead.")]
public ExpressionParameterVisitor(ISessionFactoryImplementor sessionFactory)
{
_sessionFactory = sessionFactory;
}
- public static IDictionary Visit(Expression expression, ISessionFactoryImplementor sessionFactory)
+ public ExpressionParameterVisitor(PreTransformationResult preTransformationResult)
{
- return Visit(ref expression, sessionFactory);
+ _sessionFactory = preTransformationResult.SessionFactory;
+ _queryVariables = preTransformationResult.QueryVariables;
}
- internal static IDictionary Visit(ref Expression expression, ISessionFactoryImplementor sessionFactory)
+ // Since v5.3
+ [Obsolete("Please use overload with preTransformationResult parameter instead.")]
+ public static IDictionary Visit(Expression expression, ISessionFactoryImplementor sessionFactory)
{
var visitor = new ExpressionParameterVisitor(sessionFactory);
-
- expression = visitor.Visit(expression);
+ visitor.Visit(expression);
return visitor._parameters;
}
+ public static Expression Visit(
+ PreTransformationResult preTransformationResult,
+ out IDictionary parameters)
+ {
+ var visitor = new ExpressionParameterVisitor(preTransformationResult);
+ var expression = visitor.Visit(preTransformationResult.Expression);
+ parameters = visitor._parameters;
+
+ return expression;
+ }
+
protected override Expression VisitMethodCall(MethodCallExpression expression)
{
if (expression.Method.Name == nameof(LinqExtensionMethods.MappedAs) && expression.Method.DeclaringType == typeof(LinqExtensionMethods))
@@ -122,7 +139,23 @@ protected override Expression VisitConstant(ConstantExpression expression)
// comes up, it would be nice to combine the HQL parameter type determination code
// and the Expression information.
- _parameters.Add(expression, new NamedParameter("p" + (_parameters.Count + 1), value, type));
+ NamedParameter parameter = null;
+ if (_queryVariables != null &&
+ _queryVariables.TryGetValue(expression, out var variable) &&
+ !_variableParameters.TryGetValue(variable, out parameter))
+ {
+ parameter = new NamedParameter("p" + (_parameters.Count + 1), value, type);
+ _variableParameters.Add(variable, parameter);
+ }
+
+ if (parameter == null)
+ {
+ parameter = new NamedParameter("p" + (_parameters.Count + 1), value, type);
+ }
+
+ _parameters.Add(expression, parameter);
+
+ return base.VisitConstant(expression);
}
return base.VisitConstant(expression);
diff --git a/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs
index 5cfb22fa178..45ac8ffcca5 100644
--- a/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs
+++ b/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs
@@ -1,3 +1,20 @@
+// Copyright (c) rubicon IT GmbH, www.rubicon.eu
+//
+// See the NOTICE file distributed with this work for additional information
+// regarding copyright ownership. rubicon licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may not use this
+// file except in compliance with the License. You may obtain a copy of the
+// License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations
+// under the License.
+//
+
using System;
using System.Linq;
using System.Linq.Expressions;
@@ -5,47 +22,173 @@
using NHibernate.Engine;
using NHibernate.Linq.Functions;
using NHibernate.Util;
-using Remotion.Linq.Clauses.Expressions;
using Remotion.Linq.Parsing;
-using Remotion.Linq.Parsing.ExpressionVisitors;
using Remotion.Linq.Parsing.ExpressionVisitors.TreeEvaluation;
namespace NHibernate.Linq.Visitors
{
- internal class NhPartialEvaluatingExpressionVisitor : RelinqExpressionVisitor, IPartialEvaluationExceptionExpressionVisitor
+ // Copied from Relinq and added logic for detecting and linking variables with evaluated constant expressions
+ ///
+ /// Takes an expression tree and first analyzes it for evaluatable subtrees (using ), i.e.
+ /// subtrees that can be pre-evaluated before actually generating the query. Examples for evaluatable subtrees are operations on constant
+ /// values (constant folding), access to closure variables (variables used by the LINQ query that are defined in an outer scope), or method
+ /// calls on known objects or their members. In a second step, it replaces all of the evaluatable subtrees (top-down and non-recursive) by
+ /// their evaluated counterparts.
+ ///
+ ///
+ /// This visitor visits each tree node at most twice: once via the for analysis and once
+ /// again to replace nodes if possible (unless the parent node has already been replaced).
+ ///
+ internal sealed class NhPartialEvaluatingExpressionVisitor : RelinqExpressionVisitor
{
- private readonly ISessionFactoryImplementor _sessionFactory;
+ #region Relinq adjusted code
- internal NhPartialEvaluatingExpressionVisitor(ISessionFactoryImplementor sessionFactory)
+ ///
+ /// Takes an expression tree and finds and evaluates all its evaluatable subtrees.
+ ///
+ public static Expression EvaluateIndependentSubtrees(
+ Expression expressionTree,
+ PreTransformationParameters preTransformationParameters)
{
- _sessionFactory = sessionFactory;
+ var partialEvaluationInfo = EvaluatableTreeFindingExpressionVisitor.Analyze(
+ expressionTree,
+ preTransformationParameters.EvaluatableExpressionFilter);
+ var visitor = new NhPartialEvaluatingExpressionVisitor(partialEvaluationInfo, preTransformationParameters);
+
+ return visitor.Visit(expressionTree);
}
+ // _partialEvaluationInfo contains a list of the expressions that are safe to be evaluated.
+ private readonly PartialEvaluationInfo _partialEvaluationInfo;
+ private readonly PreTransformationParameters _preTransformationParameters;
+
+ private NhPartialEvaluatingExpressionVisitor(
+ PartialEvaluationInfo partialEvaluationInfo,
+ PreTransformationParameters preTransformationParameters)
+ {
+ _partialEvaluationInfo = partialEvaluationInfo;
+ _preTransformationParameters = preTransformationParameters;
+ }
+
+ public override Expression Visit(Expression expression)
+ {
+ // Only evaluate expressions which do not use any of the surrounding parameter expressions. Don't evaluate
+ // lambda expressions (even if you could), we want to analyze those later on.
+ if (expression == null)
+ return null;
+
+ if (expression.NodeType == ExpressionType.Lambda || !_partialEvaluationInfo.IsEvaluatableExpression(expression))
+ return base.Visit(expression);
+
+ Expression evaluatedExpression;
+ try
+ {
+ evaluatedExpression = EvaluateSubtree(expression);
+ }
+ catch (Exception ex)
+ {
+ // Evaluation caused an exception. Skip evaluation of this expression and proceed as if it weren't evaluable.
+ var baseVisitedExpression = base.Visit(expression);
+
+ throw new HibernateException($"Evaluation failure on {baseVisitedExpression}", ex);
+ }
+
+ if (evaluatedExpression != expression)
+ {
+ evaluatedExpression = EvaluateIndependentSubtrees(evaluatedExpression, _preTransformationParameters);
+ }
+
+ #region NH additions
+
+ // When having multiple level closure, we have to evaluate each closure independently
+ if (evaluatedExpression is ConstantExpression constantExpression)
+ {
+ evaluatedExpression = VisitConstant(constantExpression);
+ }
+
+ // Variables in expressions are never a constant, they are encapsulated as fields of a compiler generated class.
+ if (expression.NodeType != ExpressionType.Constant &&
+ _preTransformationParameters.MinimizeParameters &&
+ evaluatedExpression is ConstantExpression variableConstant &&
+ !_preTransformationParameters.QueryVariables.ContainsKey(variableConstant) &&
+ ExpressionsHelper.IsVariable(expression, out var path, out var closureContext))
+ {
+ _preTransformationParameters.QueryVariables.Add(variableConstant, new QueryVariable(path, closureContext));
+ }
+
+ #endregion
+
+ return evaluatedExpression;
+ }
+
+ ///
+ /// Evaluates an evaluatable subtree, i.e. an independent expression tree that is compilable and executable
+ /// without any data being passed in. The result of the evaluation is returned as a ; if the subtree
+ /// is already a , no evaluation is performed.
+ ///
+ /// The subtree to be evaluated.
+ /// A holding the result of the evaluation.
+ private Expression EvaluateSubtree(Expression subtree)
+ {
+ if (subtree.NodeType == ExpressionType.Constant)
+ {
+ var constantExpression = (ConstantExpression) subtree;
+ var valueAsIQueryable = constantExpression.Value as IQueryable;
+ if (valueAsIQueryable != null && valueAsIQueryable.Expression != constantExpression)
+ return valueAsIQueryable.Expression;
+
+ return constantExpression;
+ }
+ else
+ {
+ Expression> lambdaWithoutParameters = Expression.Lambda>(Expression.Convert(subtree, typeof(object)));
+ var compiledLambda = lambdaWithoutParameters.Compile();
+
+ object value = compiledLambda();
+ return Expression.Constant(value, subtree.Type);
+ }
+ }
+
+ #endregion
+
protected override Expression VisitConstant(ConstantExpression expression)
{
if (expression.Value is Expression value)
{
- return EvaluateIndependentSubtrees(value, _sessionFactory);
+ return EvaluateIndependentSubtrees(value, _preTransformationParameters);
}
-
return base.VisitConstant(expression);
}
+ }
- public static Expression EvaluateIndependentSubtrees(
- Expression expression,
- ISessionFactoryImplementor sessionFactory)
+ internal struct QueryVariable : IEquatable
+ {
+ public QueryVariable(string path, object closureContext)
+ {
+ Path = path;
+ ClosureContext = closureContext;
+ }
+
+ public string Path { get; }
+
+ public object ClosureContext { get; }
+
+ public override bool Equals(object obj)
+ {
+ return obj is QueryVariable other && Equals(other);
+ }
+
+ public override int GetHashCode()
{
- var evaluatedExpression = PartialEvaluatingExpressionVisitor.EvaluateIndependentSubtrees(
- expression,
- new NhEvaluatableExpressionFilter(sessionFactory));
- return new NhPartialEvaluatingExpressionVisitor(sessionFactory).Visit(evaluatedExpression);
+ unchecked
+ {
+ return (Path.GetHashCode() * 397) ^ ClosureContext.GetHashCode();
+ }
}
- public Expression VisitPartialEvaluationException(PartialEvaluationExceptionExpression partialEvaluationExceptionExpression)
+ public bool Equals(QueryVariable other)
{
- throw new HibernateException(
- $"Evaluation failure on {partialEvaluationExceptionExpression.EvaluatedExpression}",
- partialEvaluationExceptionExpression.Exception);
+ return Path == other.Path && ReferenceEquals(ClosureContext, other.ClosureContext);
}
}
@@ -68,6 +211,11 @@ public override bool IsEvaluatableConstant(ConstantExpression node)
return base.IsEvaluatableConstant(node);
}
+ public override bool IsEvaluatableUnary(UnaryExpression node)
+ {
+ return !ExpressionsHelper.IsVariable(node.Operand, out _, out _);
+ }
+
public override bool IsEvaluatableMember(MemberExpression node)
{
if (node == null)
diff --git a/src/NHibernate/Linq/Visitors/PreTransformationParameters.cs b/src/NHibernate/Linq/Visitors/PreTransformationParameters.cs
new file mode 100644
index 00000000000..3f582eef531
--- /dev/null
+++ b/src/NHibernate/Linq/Visitors/PreTransformationParameters.cs
@@ -0,0 +1,51 @@
+using System.Collections.Generic;
+using System.Linq.Expressions;
+using NHibernate.Engine;
+using Remotion.Linq.Parsing.ExpressionVisitors.TreeEvaluation;
+
+namespace NHibernate.Linq.Visitors
+{
+ ///
+ /// Contains the information needed by to perform an early transformation.
+ ///
+ public class PreTransformationParameters
+ {
+ ///
+ /// The default constructor.
+ ///
+ /// The query mode of the expression to pre-transform.
+ /// The session factory used in the pre-transform process.
+ public PreTransformationParameters(QueryMode queryMode, ISessionFactoryImplementor sessionFactory)
+ {
+ QueryMode = queryMode;
+ SessionFactory = sessionFactory;
+ // Skip detecting variables for DML queries as HQL does not support reusing parameters for them.
+ MinimizeParameters = QueryMode == QueryMode.Select;
+ }
+
+ ///
+ /// The query mode of the expression to pre-transform.
+ ///
+ public QueryMode QueryMode { get; }
+
+ ///
+ /// The session factory used in the pre-transform process.
+ ///
+ public ISessionFactoryImplementor SessionFactory { get; }
+
+ ///
+ /// Whether to minimize the number of parameters for variables.
+ ///
+ public bool MinimizeParameters { get; set; }
+
+ ///
+ /// The filter which decides whether a part of the expression will be pre-evalauted or not.
+ ///
+ internal IEvaluatableExpressionFilter EvaluatableExpressionFilter { get; set; }
+
+ ///
+ /// A dictionary of that were evaluated from variables.
+ ///
+ internal IDictionary QueryVariables { get; set; }
+ }
+}
diff --git a/src/NHibernate/Linq/Visitors/PreTransformationResult.cs b/src/NHibernate/Linq/Visitors/PreTransformationResult.cs
new file mode 100644
index 00000000000..6f55ddc7bba
--- /dev/null
+++ b/src/NHibernate/Linq/Visitors/PreTransformationResult.cs
@@ -0,0 +1,37 @@
+using System.Collections.Generic;
+using System.Linq.Expressions;
+using NHibernate.Engine;
+
+namespace NHibernate.Linq.Visitors
+{
+ ///
+ /// The result of method.
+ ///
+ public class PreTransformationResult
+ {
+ internal PreTransformationResult(
+ Expression expression,
+ ISessionFactoryImplementor sessionFactory,
+ IDictionary queryVariables)
+ {
+ Expression = expression;
+ SessionFactory = sessionFactory;
+ QueryVariables = queryVariables;
+ }
+
+ ///
+ /// The transformed expression.
+ ///
+ public Expression Expression { get; }
+
+ ///
+ /// The session factory used in the pre-transform process.
+ ///
+ public ISessionFactoryImplementor SessionFactory { get; }
+
+ ///
+ /// A dictionary of that were evaluated from variables.
+ ///
+ internal IDictionary QueryVariables { get; }
+ }
+}
diff --git a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessContains.cs b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessContains.cs
index 17fc7850425..169b1211eb5 100644
--- a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessContains.cs
+++ b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessContains.cs
@@ -59,8 +59,9 @@ private static HqlAlias GetFromAlias(HqlTreeNode node)
private static bool IsEmptyList(HqlParameter source, VisitorParameters parameters)
{
var parameterName = source.NodesPreOrder.Single(n => n is HqlIdent).AstNode.Text;
- var parameterValue = parameters.ConstantToParameterMap.Single(p => p.Value.Name == parameterName).Key.Value;
+ // Multiple constants may be linked to the same parameter, take the first matching parameter
+ var parameterValue = parameters.ConstantToParameterMap.First(p => p.Value.Name == parameterName).Key.Value;
return !((IEnumerable)parameterValue).Cast