diff --git a/src/NHibernate.Test/Async/Linq/ByMethod/GroupByHavingTests.cs b/src/NHibernate.Test/Async/Linq/ByMethod/GroupByHavingTests.cs index 1921fdf0bff..c7b618793e4 100644 --- a/src/NHibernate.Test/Async/Linq/ByMethod/GroupByHavingTests.cs +++ b/src/NHibernate.Test/Async/Linq/ByMethod/GroupByHavingTests.cs @@ -10,6 +10,7 @@ using System; using System.Linq; +using NHibernate.DomainModel.Northwind.Entities; using NUnit.Framework; using NHibernate.Linq; @@ -147,5 +148,45 @@ public async Task SingleKeyGroupAndCountWithHavingClauseAsync() var hornRow = orderCounts.Single(row => row.CompanyName == "Around the Horn"); Assert.That(hornRow.OrderCount, Is.EqualTo(13)); } + + [Test] + public async Task HavingWithStringEnumParameterAsync() + { + await (db.Users + .GroupBy(p => p.Enum1) + .Where(g => g.Key == EnumStoredAsString.Large) + .Select(g => g.Count()) + .ToListAsync()); + await (db.Users + .GroupBy(p => new StringEnumGroup {Enum = p.Enum1}) + .Where(g => g.Key.Enum == EnumStoredAsString.Large) + .Select(g => g.Count()) + .ToListAsync()); + await (db.Users + .GroupBy(p => new[] {p.Enum1}) + .Where(g => g.Key[0] == EnumStoredAsString.Large) + .Select(g => g.Count()) + .ToListAsync()); + await (db.Users + .GroupBy(p => new {p.Enum1}) + .Where(g => g.Key.Enum1 == EnumStoredAsString.Large) + .Select(g => g.Count()) + .ToListAsync()); + await (db.Users + .GroupBy(p => new {Test = new {Test2 = p.Enum1}}) + .Where(g => g.Key.Test.Test2 == EnumStoredAsString.Large) + .Select(g => g.Count()) + .ToListAsync()); + await (db.Users + .GroupBy(p => new {Test = new[] {p.Enum1}}) + .Where(g => g.Key.Test[0] == EnumStoredAsString.Large) + .Select(g => g.Count()) + .ToListAsync()); + } + + private class StringEnumGroup + { + public EnumStoredAsString Enum { get; set; } + } } } diff --git a/src/NHibernate.Test/Async/Linq/ByMethod/GroupByTests.cs b/src/NHibernate.Test/Async/Linq/ByMethod/GroupByTests.cs index a672dd18381..f23b8f2d14d 100644 --- a/src/NHibernate.Test/Async/Linq/ByMethod/GroupByTests.cs +++ b/src/NHibernate.Test/Async/Linq/ByMethod/GroupByTests.cs @@ -382,6 +382,40 @@ public async Task GroupByAndAnyAsync() Assert.That(namesAreNotEmpty, Is.True); } + [Test] + public async Task GroupByWithStringEnumParameterAsync() + { + await (db.Users + .GroupBy(p => p.Enum1) + .Select(g => g.Key == EnumStoredAsString.Large ? g.Sum(o => o.Id) : 0) + .ToListAsync()); + await (db.Users + .GroupBy(p => new StringEnumGroup {Enum = p.Enum1}) + .Select(g => g.Key.Enum == EnumStoredAsString.Large ? g.Sum(o => o.Id) : 0) + .ToListAsync()); + await (db.Users + .GroupBy(p => new[] {p.Enum1}) + .Select(g => g.Key[0] == EnumStoredAsString.Large ? g.Sum(o => o.Id) : 0) + .ToListAsync()); + await (db.Users + .GroupBy(p => new {p.Enum1}) + .Select(g => g.Key.Enum1 == EnumStoredAsString.Large ? g.Sum(o => o.Id) : 0) + .ToListAsync()); + await (db.Users + .GroupBy(p => new {Test = new {Test2 = p.Enum1}}) + .Select(g => g.Key.Test.Test2 == EnumStoredAsString.Large ? g.Sum(o => o.Id) : 0) + .ToListAsync()); + await (db.Users + .GroupBy(p => new {Test = new[] {p.Enum1}}) + .Select(g => g.Key.Test[0] == EnumStoredAsString.Large ? g.Sum(o => o.Id) : 0) + .ToListAsync()); + } + + private class StringEnumGroup + { + public EnumStoredAsString Enum { get; set; } + } + [Test] public async Task SelectFirstElementFromProductsGroupedByUnitPriceAsync() { diff --git a/src/NHibernate.Test/Linq/ByMethod/GroupByHavingTests.cs b/src/NHibernate.Test/Linq/ByMethod/GroupByHavingTests.cs index 818d7593722..7e6258bfcaf 100644 --- a/src/NHibernate.Test/Linq/ByMethod/GroupByHavingTests.cs +++ b/src/NHibernate.Test/Linq/ByMethod/GroupByHavingTests.cs @@ -1,5 +1,6 @@ using System; using System.Linq; +using NHibernate.DomainModel.Northwind.Entities; using NUnit.Framework; namespace NHibernate.Test.Linq.ByMethod @@ -135,5 +136,45 @@ public void SingleKeyGroupAndCountWithHavingClause() var hornRow = orderCounts.Single(row => row.CompanyName == "Around the Horn"); Assert.That(hornRow.OrderCount, Is.EqualTo(13)); } + + [Test] + public void HavingWithStringEnumParameter() + { + db.Users + .GroupBy(p => p.Enum1) + .Where(g => g.Key == EnumStoredAsString.Large) + .Select(g => g.Count()) + .ToList(); + db.Users + .GroupBy(p => new StringEnumGroup {Enum = p.Enum1}) + .Where(g => g.Key.Enum == EnumStoredAsString.Large) + .Select(g => g.Count()) + .ToList(); + db.Users + .GroupBy(p => new[] {p.Enum1}) + .Where(g => g.Key[0] == EnumStoredAsString.Large) + .Select(g => g.Count()) + .ToList(); + db.Users + .GroupBy(p => new {p.Enum1}) + .Where(g => g.Key.Enum1 == EnumStoredAsString.Large) + .Select(g => g.Count()) + .ToList(); + db.Users + .GroupBy(p => new {Test = new {Test2 = p.Enum1}}) + .Where(g => g.Key.Test.Test2 == EnumStoredAsString.Large) + .Select(g => g.Count()) + .ToList(); + db.Users + .GroupBy(p => new {Test = new[] {p.Enum1}}) + .Where(g => g.Key.Test[0] == EnumStoredAsString.Large) + .Select(g => g.Count()) + .ToList(); + } + + private class StringEnumGroup + { + public EnumStoredAsString Enum { get; set; } + } } } diff --git a/src/NHibernate.Test/Linq/ByMethod/GroupByTests.cs b/src/NHibernate.Test/Linq/ByMethod/GroupByTests.cs index 626dc619692..2b5bab7bcab 100644 --- a/src/NHibernate.Test/Linq/ByMethod/GroupByTests.cs +++ b/src/NHibernate.Test/Linq/ByMethod/GroupByTests.cs @@ -371,6 +371,40 @@ public void GroupByAndAny() Assert.That(namesAreNotEmpty, Is.True); } + [Test] + public void GroupByWithStringEnumParameter() + { + db.Users + .GroupBy(p => p.Enum1) + .Select(g => g.Key == EnumStoredAsString.Large ? g.Sum(o => o.Id) : 0) + .ToList(); + db.Users + .GroupBy(p => new StringEnumGroup {Enum = p.Enum1}) + .Select(g => g.Key.Enum == EnumStoredAsString.Large ? g.Sum(o => o.Id) : 0) + .ToList(); + db.Users + .GroupBy(p => new[] {p.Enum1}) + .Select(g => g.Key[0] == EnumStoredAsString.Large ? g.Sum(o => o.Id) : 0) + .ToList(); + db.Users + .GroupBy(p => new {p.Enum1}) + .Select(g => g.Key.Enum1 == EnumStoredAsString.Large ? g.Sum(o => o.Id) : 0) + .ToList(); + db.Users + .GroupBy(p => new {Test = new {Test2 = p.Enum1}}) + .Select(g => g.Key.Test.Test2 == EnumStoredAsString.Large ? g.Sum(o => o.Id) : 0) + .ToList(); + db.Users + .GroupBy(p => new {Test = new[] {p.Enum1}}) + .Select(g => g.Key.Test[0] == EnumStoredAsString.Large ? g.Sum(o => o.Id) : 0) + .ToList(); + } + + private class StringEnumGroup + { + public EnumStoredAsString Enum { get; set; } + } + [Test] public void SelectFirstElementFromProductsGroupedByUnitPrice() { diff --git a/src/NHibernate/Linq/ExpressionExtensions.cs b/src/NHibernate/Linq/ExpressionExtensions.cs index 8c84ccfe302..4b9dccdd993 100644 --- a/src/NHibernate/Linq/ExpressionExtensions.cs +++ b/src/NHibernate/Linq/ExpressionExtensions.cs @@ -15,6 +15,23 @@ public static bool IsGroupingKey(this MemberExpression expression) expression.Member.DeclaringType.IsGenericType && expression.Member.DeclaringType.GetGenericTypeDefinition() == typeof(IGrouping<,>); } + internal static bool TryGetGroupResultOperator(this MemberExpression keyExpression, out GroupResultOperator groupBy) + { + if (keyExpression.IsGroupingKey() && + keyExpression.Expression is QuerySourceReferenceExpression querySource && + querySource.ReferencedQuerySource is MainFromClause fromClause && + fromClause.FromExpression is SubQueryExpression query) + { + groupBy = query.QueryModel.ResultOperators + .OfType() + .FirstOrDefault(o => o.KeySelector.Type == keyExpression.Type); + return groupBy != null; + } + + groupBy = null; + return false; + } + public static bool IsGroupingKeyOf(this MemberExpression expression,GroupResultOperator groupBy) { if (!expression.IsGroupingKey()) diff --git a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs index 9cab3d6cd26..7e7d5834272 100644 --- a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs +++ b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs @@ -343,6 +343,7 @@ private void VisitAssign(Expression leftNode, Expression rightNode) private void AddRelatedExpression(Expression node, Expression left, Expression right) { if (left.NodeType == ExpressionType.MemberAccess || + left.NodeType == ExpressionType.ArrayIndex || // e.g. group.Key[0] == variable IsDynamicMember(left) || left is QuerySourceReferenceExpression) { diff --git a/src/NHibernate/Util/ExpressionsHelper.cs b/src/NHibernate/Util/ExpressionsHelper.cs index ee6388a9fa2..221b8031266 100644 --- a/src/NHibernate/Util/ExpressionsHelper.cs +++ b/src/NHibernate/Util/ExpressionsHelper.cs @@ -17,7 +17,7 @@ using Remotion.Linq.Clauses; using Remotion.Linq.Clauses.Expressions; using Remotion.Linq.Clauses.ResultOperators; -using Remotion.Linq.Parsing; +using TransparentIdentifierRemovingExpressionVisitor = NHibernate.Linq.Visitors.TransparentIdentifierRemovingExpressionVisitor; namespace NHibernate.Util { @@ -594,6 +594,43 @@ private static IType GetType( : TypeFactory.GetDefaultTypeFor(member.ConvertType); // (long)q.OneToMany[0] } + private class GroupingKeyFlattener : NhExpressionVisitor + { + private bool _flattened; + + public static Expression FlattenGroupingKey(Expression expression) + { + var visitor = new GroupingKeyFlattener(); + expression = visitor.Visit(expression); + if (visitor._flattened) + { + expression = TransparentIdentifierRemovingExpressionVisitor.ReplaceTransparentIdentifiers(expression); + // When the grouping key is an array we have to unwrap it (e.g. group.Key[0] == variable) + if (expression.NodeType == ExpressionType.ArrayIndex && + expression is BinaryExpression binaryExpression && + binaryExpression.Left is NewArrayExpression newArray && + binaryExpression.Right is ConstantExpression indexExpression && + indexExpression.Value is int index) + { + return newArray.Expressions[index]; + } + } + + return expression; + } + + protected override Expression VisitMember(MemberExpression node) + { + if (node.TryGetGroupResultOperator(out var groupBy)) + { + _flattened = true; + return groupBy.KeySelector; + } + + return base.VisitMember(node); + } + } + private class MemberMetadataExtractor : NhExpressionVisitor { private readonly List _childrenResults = new List(); @@ -639,6 +676,8 @@ private static bool TryGetAllMemberMetadata( bool hasIndexer, out MemberMetadataResult results) { + expression = GroupingKeyFlattener.FlattenGroupingKey(expression); + var extractor = new MemberMetadataExtractor(memberPaths, convertType, hasIndexer); extractor.Accept(expression); results = extractor._entityName != null || extractor._childrenResults.Count > 0