Skip to content

Improve LINQ Contains subquery parameter detection #3274

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions src/NHibernate.Test/Async/Linq/WhereTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
using System.Linq;
using System.Linq.Expressions;
using log4net.Core;
using NHibernate.Dialect;
using NHibernate.Engine.Query;
using NHibernate.Linq;
using NHibernate.DomainModel.Northwind.Entities;
Expand Down Expand Up @@ -647,6 +648,9 @@ where sheet.Users.Contains(user)
[Test]
public async Task TimesheetsWithEnumerableContainsOnSelectAsync()
{
if (Dialect is MsSqlCeDialect)
Assert.Ignore("Dialect is not supported");

var value = (EnumStoredAsInt32) 1000;
var query = await ((from sheet in db.Timesheets
where sheet.Users.Select(x => x.NullableEnum2 ?? value).Contains(value)
Expand All @@ -655,6 +659,24 @@ where sheet.Users.Select(x => x.NullableEnum2 ?? value).Contains(value)
Assert.That(query.Count, Is.EqualTo(1));
}

[Test]
public async Task ContainsSubqueryWithCoalesceStringEnumSelectAsync()
{
if (Dialect is MsSqlCeDialect || Dialect is SQLiteDialect)
Assert.Ignore("Dialect is not supported");

var results =
await (db.Timesheets.Where(
o =>
o.Users
.Where(u => u.Id != 0.MappedAs(NHibernateUtil.Int32))
.Select(u => u.Name == u.Name ? u.Enum1 : u.NullableEnum1.Value)
.Contains(EnumStoredAsString.Small))
.ToListAsync());

Assert.That(results.Count, Is.EqualTo(1));
}

[Test]
public async Task SearchOnObjectTypeWithExtensionMethodAsync()
{
Expand Down
22 changes: 22 additions & 0 deletions src/NHibernate.Test/Linq/WhereTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Linq;
using System.Linq.Expressions;
using log4net.Core;
using NHibernate.Dialect;
using NHibernate.Engine.Query;
using NHibernate.Linq;
using NHibernate.DomainModel.Northwind.Entities;
Expand Down Expand Up @@ -648,6 +649,9 @@ where sheet.Users.Contains(user)
[Test]
public void TimesheetsWithEnumerableContainsOnSelect()
{
if (Dialect is MsSqlCeDialect)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixes SqlServerCe test for #3271

See: https://teamcity.jetbrains.com/buildConfiguration/bt1175/4111194?buildTab=tests&expandedTest=build%3A%28id%3A4111194%29%2Cid%3A29811

System.IndexOutOfRangeException : Index was outside the bounds of the array.
at System.Data.SqlServerCe.SqlCeCommand.GetQueryParameters(IntPtr pQpCommand)
at System.Data.SqlServerCe.SqlCeCommand.CreateDataBindings()
at System.Data.SqlServerCe.SqlCeCommand.CompileQueryPlan()

Assert.Ignore("Dialect is not supported");

var value = (EnumStoredAsInt32) 1000;
var query = (from sheet in db.Timesheets
where sheet.Users.Select(x => x.NullableEnum2 ?? value).Contains(value)
Expand All @@ -656,6 +660,24 @@ where sheet.Users.Select(x => x.NullableEnum2 ?? value).Contains(value)
Assert.That(query.Count, Is.EqualTo(1));
}

[Test]
public void ContainsSubqueryWithCoalesceStringEnumSelect()
{
if (Dialect is MsSqlCeDialect || Dialect is SQLiteDialect)
Assert.Ignore("Dialect is not supported");

var results =
db.Timesheets.Where(
o =>
o.Users
.Where(u => u.Id != 0.MappedAs(NHibernateUtil.Int32))
.Select(u => u.Name == u.Name ? u.Enum1 : u.NullableEnum1.Value)
.Contains(EnumStoredAsString.Small))
.ToList();

Assert.That(results.Count, Is.EqualTo(1));
}

[Test]
public void SearchOnObjectTypeWithExtensionMethod()
{
Expand Down
29 changes: 11 additions & 18 deletions src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -288,42 +288,35 @@ protected override Expression VisitConstant(ConstantExpression node)

protected override Expression VisitSubQuery(SubQueryExpression node)
{
if (!TryLinkContainsMethod(node.QueryModel))
{
node.QueryModel.TransformExpressions(Visit);
}
TryLinkContainsMethod(node.QueryModel);
node.QueryModel.TransformExpressions(Visit);

return node;
}

private bool TryLinkContainsMethod(QueryModel queryModel)
private void TryLinkContainsMethod(QueryModel queryModel)
{
// ReLinq wraps all ResultOperatorExpressionNodeBase into a SubQueryExpression. In case of
// ContainsResultOperator where the constant expression is dislocated from the related expression,
// we have to manually link the related expressions.
if (queryModel.ResultOperators.Count != 1 ||
!(queryModel.ResultOperators[0] is ContainsResultOperator containsOperator) ||
!(queryModel.SelectClause.Selector is QuerySourceReferenceExpression querySourceReference) ||
!(querySourceReference.ReferencedQuerySource is MainFromClause mainFromClause))
!(queryModel.ResultOperators[0] is ContainsResultOperator containsOperator))
{
return false;
return;
}

var left = UnwrapUnary(Visit(mainFromClause.FromExpression));
Expression selector =
queryModel.SelectClause.Selector is QuerySourceReferenceExpression { ReferencedQuerySource: MainFromClause mainFromClause }
? mainFromClause.FromExpression
: queryModel.SelectClause.Selector;

var left = UnwrapUnary(Visit(selector));
var right = UnwrapUnary(Visit(containsOperator.Item));
// The constant is on the left side (e.g. db.Users.Where(o => users.Contains(o)))
// The constant is on the right side (e.g. db.Customers.Where(o => o.Orders.Contains(item)))
if (left.NodeType != ExpressionType.Constant && right.NodeType != ExpressionType.Constant)
{
return false;
}

// Copy all found MemberExpressions to the constant expression
// (e.g. values.Contains(o.Name != o.Name2 ? o.Enum1 : o.Enum2) -> copy o.Enum1 and o.Enum2)
AddRelatedExpression(null, left, right);
AddRelatedExpression(null, right, left);

return true;
}

private void VisitAssign(Expression leftNode, Expression rightNode)
Expand Down