diff --git a/src/NHibernate.Test/Async/Linq/WhereTests.cs b/src/NHibernate.Test/Async/Linq/WhereTests.cs index 5d8a936b342..aceda381352 100644 --- a/src/NHibernate.Test/Async/Linq/WhereTests.cs +++ b/src/NHibernate.Test/Async/Linq/WhereTests.cs @@ -15,6 +15,7 @@ using System.Linq; using System.Linq.Expressions; using log4net.Core; +using NHibernate.Dialect; using NHibernate.Engine.Query; using NHibernate.Linq; using NHibernate.DomainModel.Northwind.Entities; @@ -647,6 +648,9 @@ where sheet.Users.Contains(user) [Test] public async Task TimesheetsWithEnumerableContainsOnSelectAsync() { + if (Dialect is MsSqlCeDialect) + Assert.Ignore("Dialect is not supported"); + var value = (EnumStoredAsInt32) 1000; var query = await ((from sheet in db.Timesheets where sheet.Users.Select(x => x.NullableEnum2 ?? value).Contains(value) @@ -655,6 +659,24 @@ where sheet.Users.Select(x => x.NullableEnum2 ?? value).Contains(value) Assert.That(query.Count, Is.EqualTo(1)); } + [Test] + public async Task ContainsSubqueryWithCoalesceStringEnumSelectAsync() + { + if (Dialect is MsSqlCeDialect || Dialect is SQLiteDialect) + Assert.Ignore("Dialect is not supported"); + + var results = + await (db.Timesheets.Where( + o => + o.Users + .Where(u => u.Id != 0.MappedAs(NHibernateUtil.Int32)) + .Select(u => u.Name == u.Name ? u.Enum1 : u.NullableEnum1.Value) + .Contains(EnumStoredAsString.Small)) + .ToListAsync()); + + Assert.That(results.Count, Is.EqualTo(1)); + } + [Test] public async Task SearchOnObjectTypeWithExtensionMethodAsync() { diff --git a/src/NHibernate.Test/Linq/WhereTests.cs b/src/NHibernate.Test/Linq/WhereTests.cs index 0b4b5da6575..5fffc56c052 100644 --- a/src/NHibernate.Test/Linq/WhereTests.cs +++ b/src/NHibernate.Test/Linq/WhereTests.cs @@ -5,6 +5,7 @@ using System.Linq; using System.Linq.Expressions; using log4net.Core; +using NHibernate.Dialect; using NHibernate.Engine.Query; using NHibernate.Linq; using NHibernate.DomainModel.Northwind.Entities; @@ -648,6 +649,9 @@ where sheet.Users.Contains(user) [Test] public void TimesheetsWithEnumerableContainsOnSelect() { + if (Dialect is MsSqlCeDialect) + Assert.Ignore("Dialect is not supported"); + var value = (EnumStoredAsInt32) 1000; var query = (from sheet in db.Timesheets where sheet.Users.Select(x => x.NullableEnum2 ?? value).Contains(value) @@ -656,6 +660,24 @@ where sheet.Users.Select(x => x.NullableEnum2 ?? value).Contains(value) Assert.That(query.Count, Is.EqualTo(1)); } + [Test] + public void ContainsSubqueryWithCoalesceStringEnumSelect() + { + if (Dialect is MsSqlCeDialect || Dialect is SQLiteDialect) + Assert.Ignore("Dialect is not supported"); + + var results = + db.Timesheets.Where( + o => + o.Users + .Where(u => u.Id != 0.MappedAs(NHibernateUtil.Int32)) + .Select(u => u.Name == u.Name ? u.Enum1 : u.NullableEnum1.Value) + .Contains(EnumStoredAsString.Small)) + .ToList(); + + Assert.That(results.Count, Is.EqualTo(1)); + } + [Test] public void SearchOnObjectTypeWithExtensionMethod() { diff --git a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs index eeb458568d0..3f5b37eab03 100644 --- a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs +++ b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs @@ -288,42 +288,35 @@ protected override Expression VisitConstant(ConstantExpression node) protected override Expression VisitSubQuery(SubQueryExpression node) { - if (!TryLinkContainsMethod(node.QueryModel)) - { - node.QueryModel.TransformExpressions(Visit); - } + TryLinkContainsMethod(node.QueryModel); + node.QueryModel.TransformExpressions(Visit); return node; } - private bool TryLinkContainsMethod(QueryModel queryModel) + private void TryLinkContainsMethod(QueryModel queryModel) { // ReLinq wraps all ResultOperatorExpressionNodeBase into a SubQueryExpression. In case of // ContainsResultOperator where the constant expression is dislocated from the related expression, // we have to manually link the related expressions. if (queryModel.ResultOperators.Count != 1 || - !(queryModel.ResultOperators[0] is ContainsResultOperator containsOperator) || - !(queryModel.SelectClause.Selector is QuerySourceReferenceExpression querySourceReference) || - !(querySourceReference.ReferencedQuerySource is MainFromClause mainFromClause)) + !(queryModel.ResultOperators[0] is ContainsResultOperator containsOperator)) { - return false; + return; } - var left = UnwrapUnary(Visit(mainFromClause.FromExpression)); + Expression selector = + queryModel.SelectClause.Selector is QuerySourceReferenceExpression { ReferencedQuerySource: MainFromClause mainFromClause } + ? mainFromClause.FromExpression + : queryModel.SelectClause.Selector; + + var left = UnwrapUnary(Visit(selector)); var right = UnwrapUnary(Visit(containsOperator.Item)); - // The constant is on the left side (e.g. db.Users.Where(o => users.Contains(o))) - // The constant is on the right side (e.g. db.Customers.Where(o => o.Orders.Contains(item))) - if (left.NodeType != ExpressionType.Constant && right.NodeType != ExpressionType.Constant) - { - return false; - } // Copy all found MemberExpressions to the constant expression // (e.g. values.Contains(o.Name != o.Name2 ? o.Enum1 : o.Enum2) -> copy o.Enum1 and o.Enum2) AddRelatedExpression(null, left, right); AddRelatedExpression(null, right, left); - - return true; } private void VisitAssign(Expression leftNode, Expression rightNode)