diff --git a/src/NHibernate.Test/Async/Linq/WhereTests.cs b/src/NHibernate.Test/Async/Linq/WhereTests.cs index 56d183e49c9..5d8a936b342 100644 --- a/src/NHibernate.Test/Async/Linq/WhereTests.cs +++ b/src/NHibernate.Test/Async/Linq/WhereTests.cs @@ -644,6 +644,17 @@ where sheet.Users.Contains(user) Assert.That(query.Count, Is.EqualTo(2)); } + [Test] + public async Task TimesheetsWithEnumerableContainsOnSelectAsync() + { + var value = (EnumStoredAsInt32) 1000; + var query = await ((from sheet in db.Timesheets + where sheet.Users.Select(x => x.NullableEnum2 ?? value).Contains(value) + select sheet).ToListAsync()); + + Assert.That(query.Count, Is.EqualTo(1)); + } + [Test] public async Task SearchOnObjectTypeWithExtensionMethodAsync() { diff --git a/src/NHibernate.Test/Linq/WhereTests.cs b/src/NHibernate.Test/Linq/WhereTests.cs index 02dc58b34b7..0b4b5da6575 100644 --- a/src/NHibernate.Test/Linq/WhereTests.cs +++ b/src/NHibernate.Test/Linq/WhereTests.cs @@ -645,6 +645,17 @@ where sheet.Users.Contains(user) Assert.That(query.Count, Is.EqualTo(2)); } + [Test] + public void TimesheetsWithEnumerableContainsOnSelect() + { + var value = (EnumStoredAsInt32) 1000; + var query = (from sheet in db.Timesheets + where sheet.Users.Select(x => x.NullableEnum2 ?? value).Contains(value) + select sheet).ToList(); + + Assert.That(query.Count, Is.EqualTo(1)); + } + [Test] public void SearchOnObjectTypeWithExtensionMethod() { diff --git a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs index 3c95d90c9ab..c9f2a054bb1 100644 --- a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs +++ b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs @@ -114,6 +114,7 @@ public static ExpressionToHqlTranslationResults GenerateHqlQuery(QueryModel quer private readonly NhLinqExpressionReturnType? _rootReturnType; private static readonly ResultOperatorMap ResultOperatorMap; private bool _serverSide = true; + private readonly bool _root; public VisitorParameters VisitorParameters { get; } @@ -161,6 +162,7 @@ private QueryModelVisitor(VisitorParameters visitorParameters, bool root, QueryM _queryMode = root ? visitorParameters.RootQueryMode : QueryMode.Select; VisitorParameters = visitorParameters; Model = queryModel; + _root = root; _rootReturnType = root ? rootReturnType : null; _hqlTree = new IntermediateHqlTree(root, _queryMode); } @@ -467,19 +469,27 @@ public override void VisitSelectClause(SelectClause selectClause, QueryModel que } //This is a standard select query + _hqlTree.AddSelectClause(GetSelectClause(selectClause.Selector)); + + base.VisitSelectClause(selectClause, queryModel); + } + + private HqlSelect GetSelectClause(Expression selectClause) + { + if (!_root) + return _hqlTree.TreeBuilder.Select( + HqlGeneratorExpressionVisitor.Visit(selectClause, VisitorParameters).AsExpression()); var visitor = new SelectClauseVisitor(typeof(object[]), VisitorParameters); - visitor.VisitSelector(selectClause.Selector); + visitor.VisitSelector(selectClause); if (visitor.ProjectionExpression != null) { _hqlTree.AddItemTransformer(visitor.ProjectionExpression); } - _hqlTree.AddSelectClause(_hqlTree.TreeBuilder.Select(visitor.GetHqlNodes())); - - base.VisitSelectClause(selectClause, queryModel); + return _hqlTree.TreeBuilder.Select(visitor.GetHqlNodes()); } private void VisitInsertClause(Expression expression) @@ -527,6 +537,9 @@ private void VisitUpdateClause(Expression expression) private void VisitDeleteClause(Expression expression) { + if (!_root) + return; + // We only need to check there is no unexpected select, for avoiding silently ignoring them. var visitor = new SelectClauseVisitor(typeof(object[]), VisitorParameters); visitor.VisitSelector(expression);