Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/NHibernate.Test/Async/Linq/EnumTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,14 @@ public async Task CanQueryOnEnumStoredAsString_Small_1Async()
Assert.AreEqual(expectedCount, query.Count);
}

[Test]
public async Task CanQueryWithContainsOnEnumStoredAsString_Small_1Async()
{
var values = new[] { EnumStoredAsString.Small, EnumStoredAsString.Medium };
var query = await (db.Users.Where(x => values.Contains(x.Enum1)).ToListAsync());
Assert.AreEqual(3, query.Count);
}

[Test]
public async Task ConditionalNavigationPropertyAsync()
{
Expand Down
8 changes: 8 additions & 0 deletions src/NHibernate.Test/Linq/EnumTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,14 @@ public void CanQueryOnEnumStoredAsString(EnumStoredAsString type, int expectedCo
Assert.AreEqual(expectedCount, query.Count);
}

[Test]
public void CanQueryWithContainsOnEnumStoredAsString_Small_1()
{
var values = new[] { EnumStoredAsString.Small, EnumStoredAsString.Medium };
var query = db.Users.Where(x => values.Contains(x.Enum1)).ToList();
Assert.AreEqual(3, query.Count);
}

[Test]
public void ConditionalNavigationProperty()
{
Expand Down
16 changes: 16 additions & 0 deletions src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,22 @@ public void EqualStringEnumTest()
);
}

[Test]
public void ContainsStringEnumTest()
{
var values = new[] {EnumStoredAsString.Small};
AssertResults(
new Dictionary<string, Predicate<IType>>
{
{"value(NHibernate.DomainModel.Northwind.Entities.EnumStoredAsString[])", o => o is EnumStoredAsStringType}
},
db.Users.Where(o => values.Contains(o.Enum1)),
db.Users.Where(o => values.Contains(o.NullableEnum1.Value)),
db.Users.Where(o => values.Contains(o.Name == o.Name ? o.Enum1 : o.NullableEnum1.Value)),
db.Timesheets.Where(o => o.Users.Any(u => values.Contains(u.Enum1)))
);
}

[Test]
public void EqualStringEnumTestWithFetch()
{
Expand Down
32 changes: 28 additions & 4 deletions src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
using System.Collections.Generic;
using System.Dynamic;
using System.Linq;
using System.Linq.Expressions;
using NHibernate.Engine;
using NHibernate.Param;
using NHibernate.Type;
using NHibernate.Util;
using Remotion.Linq;
using Remotion.Linq.Clauses;
using Remotion.Linq.Clauses.Expressions;
using Remotion.Linq.Clauses.ResultOperators;
using Remotion.Linq.Parsing;

namespace NHibernate.Linq.Visitors
Expand Down Expand Up @@ -219,14 +222,35 @@ protected override Expression VisitConstant(ConstantExpression node)
return node;
}

public override Expression Visit(Expression node)
protected override Expression VisitSubQuery(SubQueryExpression node)
{
if (node is SubQueryExpression subQueryExpression)
// ReLinq wraps all ResultOperatorExpressionNodeBase into a SubQueryExpression. In case of
// ContainsResultOperator where the constant expression is dislocated from the related expression,
// we have to manually link the related expressions.
if (node.QueryModel.ResultOperators.Count == 1 &&
node.QueryModel.ResultOperators[0] is ContainsResultOperator containsOperator &&
node.QueryModel.SelectClause.Selector is QuerySourceReferenceExpression querySourceReference &&
querySourceReference.ReferencedQuerySource is MainFromClause mainFromClause &&
mainFromClause.FromExpression is ConstantExpression constantExpression)
{
subQueryExpression.QueryModel.TransformExpressions(Visit);
VisitConstant(constantExpression);
AddRelatedExpression(constantExpression, Unwrap(Visit(containsOperator.Item)));
// Copy all found MemberExpressions to the constant expression
// (e.g. values.Contains(o.Name != o.Name2 ? o.Enum1 : o.Enum2) -> copy o.Enum1 and o.Enum2)
if (RelatedExpressions.TryGetValue(containsOperator.Item, out var set))
{
foreach (var nestedMemberExpression in set)
{
AddRelatedExpression(constantExpression, nestedMemberExpression);
}
}
}
else
{
node.QueryModel.TransformExpressions(Visit);
}

return base.Visit(node);
return node;
}

private void VisitAssign(Expression leftNode, Expression rightNode)
Expand Down