Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/NHibernate.Test/Async/Linq/EnumTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,14 @@ public async Task CanQueryOnEnumStoredAsString_Small_1Async()
Assert.AreEqual(expectedCount, query.Count);
}

[Test]
public async Task CanQueryWithContainsOnEnumStoredAsString_Small_1Async()
{
var values = new[] { EnumStoredAsString.Small, EnumStoredAsString.Medium };
var query = await (db.Users.Where(x => values.Contains(x.Enum1)).ToListAsync());
Assert.AreEqual(3, query.Count);
}

[Test]
public async Task ConditionalNavigationPropertyAsync()
{
Expand Down
8 changes: 8 additions & 0 deletions src/NHibernate.Test/Linq/EnumTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,14 @@ public void CanQueryOnEnumStoredAsString(EnumStoredAsString type, int expectedCo
Assert.AreEqual(expectedCount, query.Count);
}

[Test]
public void CanQueryWithContainsOnEnumStoredAsString_Small_1()
{
var values = new[] { EnumStoredAsString.Small, EnumStoredAsString.Medium };
var query = db.Users.Where(x => values.Contains(x.Enum1)).ToList();
Assert.AreEqual(3, query.Count);
}

[Test]
public void ConditionalNavigationProperty()
{
Expand Down
16 changes: 16 additions & 0 deletions src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,22 @@ public void EqualStringEnumTest()
);
}

[Test]
public void ContainsStringEnumTest()
{
var values = new[] {EnumStoredAsString.Small};
AssertResults(
new Dictionary<string, Predicate<IType>>
{
{"value(NHibernate.DomainModel.Northwind.Entities.EnumStoredAsString[])", o => o is EnumStoredAsStringType}
},
db.Users.Where(o => values.Contains(o.Enum1)),
db.Users.Where(o => values.Contains(o.NullableEnum1.Value)),
db.Users.Where(o => values.Contains(o.Name == o.Name ? o.Enum1 : o.NullableEnum1.Value)),
db.Timesheets.Where(o => o.Users.Any(u => values.Contains(u.Enum1)))
);
}

[Test]
public void EqualStringEnumTestWithFetch()
{
Expand Down
32 changes: 28 additions & 4 deletions src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
using System.Collections.Generic;
using System.Dynamic;
using System.Linq;
using System.Linq.Expressions;
using NHibernate.Engine;
using NHibernate.Param;
using NHibernate.Type;
using NHibernate.Util;
using Remotion.Linq;
using Remotion.Linq.Clauses;
using Remotion.Linq.Clauses.Expressions;
using Remotion.Linq.Clauses.ResultOperators;
using Remotion.Linq.Parsing;

namespace NHibernate.Linq.Visitors
Expand Down Expand Up @@ -219,14 +222,35 @@ protected override Expression VisitConstant(ConstantExpression node)
return node;
}

public override Expression Visit(Expression node)
protected override Expression VisitSubQuery(SubQueryExpression node)
{
if (node is SubQueryExpression subQueryExpression)
// ReLinq wraps all ResultOperatorExpressionNodeBase into a SubQueryExpression. In case of
// ContainsResultOperator where the constant expression is dislocated from the related expression,
// we have to manually link the related expressions.
var containsOperator = node.QueryModel.ResultOperators.OfType<ContainsResultOperator>().FirstOrDefault();
if (containsOperator != null &&
node.QueryModel.SelectClause.Selector is QuerySourceReferenceExpression querySourceReference &&
querySourceReference.ReferencedQuerySource is MainFromClause mainFromClause &&
mainFromClause.FromExpression is ConstantExpression constantExpression)
{
subQueryExpression.QueryModel.TransformExpressions(Visit);
VisitConstant(constantExpression);
AddRelatedExpression(constantExpression, Unwrap(Visit(containsOperator.Item)));
// Copy all found MemberExpressions to the constant expression
// (e.g. values.Contains(o.Name != o.Name2 ? o.Enum1 : o.Enum2) -> copy o.Enum1 and o.Enum2)
if (RelatedExpressions.TryGetValue(containsOperator.Item, out var set))
{
foreach (var nestedMemberExpression in set)
{
AddRelatedExpression(constantExpression, nestedMemberExpression);
}
}
}
else
{
node.QueryModel.TransformExpressions(Visit);
}

return base.Visit(node);
return node;
}

private void VisitAssign(Expression leftNode, Expression rightNode)
Expand Down