diff --git a/src/NHibernate.Test/Async/Linq/EnumTests.cs b/src/NHibernate.Test/Async/Linq/EnumTests.cs index 6e9355d294c..e08a2c90829 100644 --- a/src/NHibernate.Test/Async/Linq/EnumTests.cs +++ b/src/NHibernate.Test/Async/Linq/EnumTests.cs @@ -62,6 +62,14 @@ public async Task CanQueryOnEnumStoredAsString_Small_1Async() Assert.AreEqual(expectedCount, query.Count); } + [Test] + public async Task CanQueryWithContainsOnEnumStoredAsString_Small_1Async() + { + var values = new[] { EnumStoredAsString.Small, EnumStoredAsString.Medium }; + var query = await (db.Users.Where(x => values.Contains(x.Enum1)).ToListAsync()); + Assert.AreEqual(3, query.Count); + } + [Test] public async Task ConditionalNavigationPropertyAsync() { diff --git a/src/NHibernate.Test/Linq/EnumTests.cs b/src/NHibernate.Test/Linq/EnumTests.cs index aeea060b51e..7f312de7e42 100644 --- a/src/NHibernate.Test/Linq/EnumTests.cs +++ b/src/NHibernate.Test/Linq/EnumTests.cs @@ -49,6 +49,14 @@ public void CanQueryOnEnumStoredAsString(EnumStoredAsString type, int expectedCo Assert.AreEqual(expectedCount, query.Count); } + [Test] + public void CanQueryWithContainsOnEnumStoredAsString_Small_1() + { + var values = new[] { EnumStoredAsString.Small, EnumStoredAsString.Medium }; + var query = db.Users.Where(x => values.Contains(x.Enum1)).ToList(); + Assert.AreEqual(3, query.Count); + } + [Test] public void ConditionalNavigationProperty() { diff --git a/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs b/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs index 511e23f88cd..39cb2d22d74 100644 --- a/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs +++ b/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs @@ -84,6 +84,22 @@ public void EqualStringEnumTest() ); } + [Test] + public void ContainsStringEnumTest() + { + var values = new[] {EnumStoredAsString.Small}; + AssertResults( + new Dictionary> + { + {"value(NHibernate.DomainModel.Northwind.Entities.EnumStoredAsString[])", o => o is EnumStoredAsStringType} + }, + db.Users.Where(o => values.Contains(o.Enum1)), + db.Users.Where(o => values.Contains(o.NullableEnum1.Value)), + db.Users.Where(o => values.Contains(o.Name == o.Name ? o.Enum1 : o.NullableEnum1.Value)), + db.Timesheets.Where(o => o.Users.Any(u => values.Contains(u.Enum1))) + ); + } + [Test] public void EqualStringEnumTestWithFetch() { diff --git a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs index 34326640169..40f3ec0d3d3 100644 --- a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs +++ b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs @@ -1,12 +1,15 @@ using System.Collections.Generic; using System.Dynamic; +using System.Linq; using System.Linq.Expressions; using NHibernate.Engine; using NHibernate.Param; using NHibernate.Type; using NHibernate.Util; using Remotion.Linq; +using Remotion.Linq.Clauses; using Remotion.Linq.Clauses.Expressions; +using Remotion.Linq.Clauses.ResultOperators; using Remotion.Linq.Parsing; namespace NHibernate.Linq.Visitors @@ -219,14 +222,35 @@ protected override Expression VisitConstant(ConstantExpression node) return node; } - public override Expression Visit(Expression node) + protected override Expression VisitSubQuery(SubQueryExpression node) { - if (node is SubQueryExpression subQueryExpression) + // ReLinq wraps all ResultOperatorExpressionNodeBase into a SubQueryExpression. In case of + // ContainsResultOperator where the constant expression is dislocated from the related expression, + // we have to manually link the related expressions. + if (node.QueryModel.ResultOperators.Count == 1 && + node.QueryModel.ResultOperators[0] is ContainsResultOperator containsOperator && + node.QueryModel.SelectClause.Selector is QuerySourceReferenceExpression querySourceReference && + querySourceReference.ReferencedQuerySource is MainFromClause mainFromClause && + mainFromClause.FromExpression is ConstantExpression constantExpression) { - subQueryExpression.QueryModel.TransformExpressions(Visit); + VisitConstant(constantExpression); + AddRelatedExpression(constantExpression, Unwrap(Visit(containsOperator.Item))); + // Copy all found MemberExpressions to the constant expression + // (e.g. values.Contains(o.Name != o.Name2 ? o.Enum1 : o.Enum2) -> copy o.Enum1 and o.Enum2) + if (RelatedExpressions.TryGetValue(containsOperator.Item, out var set)) + { + foreach (var nestedMemberExpression in set) + { + AddRelatedExpression(constantExpression, nestedMemberExpression); + } + } + } + else + { + node.QueryModel.TransformExpressions(Visit); } - return base.Visit(node); + return node; } private void VisitAssign(Expression leftNode, Expression rightNode)