Skip to content

Commit 5a1df86

Browse files
committed
NH-3816 - Nominate GroupBy key expressions directly and resolve only those to the database
* Similar intent to the original solution for NH-3797 * Nominations are done via a custom expression type so that join re-writing can still occur on the nominated expressions
1 parent 9e6d66b commit 5a1df86

12 files changed

+256
-43
lines changed

src/NHibernate.Test/Linq/JoinTests.cs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,11 +273,13 @@ public void OrderLinesWithSelectingCustomerIdInCaseShouldProduceOneJoin()
273273
}
274274
}
275275

276-
[Test(Description = "NH-3801")]
276+
[Test(Description = "NH-3801"), Ignore("This is an ideal case, but not possible without better join detection")]
277277
public void OrderLinesWithSelectingCustomerInCaseShouldProduceOneJoin()
278278
{
279279
using (var spy = new SqlLogSpy())
280280
{
281+
// Without nominating the conditional to the select clause (and placing it in SQL)
282+
// [l.Order.Customer] will be selected in its entirety, creating a second join
281283
(from l in db.OrderLines
282284
select new { CustomerKnown = l.Order.Customer == null ? 0 : 1, l.Order.OrderDate }).ToList();
283285

@@ -299,11 +301,13 @@ public void OrderLinesWithSelectingCustomerNameInCaseShouldProduceTwoJoins()
299301
}
300302
}
301303

302-
[Test(Description = "NH-3801")]
304+
[Test(Description = "NH-3801"), Ignore("This is an ideal case, but not possible without better join detection")]
303305
public void OrderLinesWithSelectingCustomerNameInCaseShouldProduceTwoJoinsAlternate()
304306
{
305307
using (var spy = new SqlLogSpy())
306308
{
309+
// Without nominating the conditional to the select clause (and placing it in SQL)
310+
// [l.Order.Customer] will be selected in its entirety, creating a second join
307311
(from l in db.OrderLines
308312
select new { CustomerKnown = l.Order.Customer == null ? "unknown" : l.Order.Customer.CompanyName, l.Order.OrderDate }).ToList();
309313

src/NHibernate.Test/Linq/SelectionTests.cs

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,76 @@ public void CanSelectCollection()
387387
Assert.AreEqual(5, orders[0].Count);
388388
}
389389

390+
[Test]
391+
public void CanSelectConditionalKnownTypes()
392+
{
393+
var moreThanTwoOrderLinesBool = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? true : false }).ToList();
394+
Assert.That(moreThanTwoOrderLinesBool.Count(x => x.HasMoreThanTwo == true), Is.EqualTo(410));
395+
396+
var moreThanTwoOrderLinesNBool = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? true : (bool?)null }).ToList();
397+
Assert.That(moreThanTwoOrderLinesNBool.Count(x => x.HasMoreThanTwo == true), Is.EqualTo(410));
398+
399+
var moreThanTwoOrderLinesShort = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? (short)1 : (short)0 }).ToList();
400+
Assert.That(moreThanTwoOrderLinesShort.Count(x => x.HasMoreThanTwo == 1), Is.EqualTo(410));
401+
402+
var moreThanTwoOrderLinesNShort = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? (short?)1 : (short?)null }).ToList();
403+
Assert.That(moreThanTwoOrderLinesNShort.Count(x => x.HasMoreThanTwo == 1), Is.EqualTo(410));
404+
405+
var moreThanTwoOrderLinesInt = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1 : 0 }).ToList();
406+
Assert.That(moreThanTwoOrderLinesInt.Count(x => x.HasMoreThanTwo == 1), Is.EqualTo(410));
407+
408+
var moreThanTwoOrderLinesNInt = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1 : (int?)null }).ToList();
409+
Assert.That(moreThanTwoOrderLinesNInt.Count(x => x.HasMoreThanTwo == 1), Is.EqualTo(410));
410+
411+
var moreThanTwoOrderLinesDecimal = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1m : 0m }).ToList();
412+
Assert.That(moreThanTwoOrderLinesDecimal.Count(x => x.HasMoreThanTwo == 1m), Is.EqualTo(410));
413+
414+
var moreThanTwoOrderLinesNDecimal = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1m : (decimal?)null }).ToList();
415+
Assert.That(moreThanTwoOrderLinesNDecimal.Count(x => x.HasMoreThanTwo == 1m), Is.EqualTo(410));
416+
417+
var moreThanTwoOrderLinesSingle = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1f : 0f }).ToList();
418+
Assert.That(moreThanTwoOrderLinesSingle.Count(x => x.HasMoreThanTwo == 1f), Is.EqualTo(410));
419+
420+
var moreThanTwoOrderLinesNSingle = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1f : (float?)null }).ToList();
421+
Assert.That(moreThanTwoOrderLinesNSingle.Count(x => x.HasMoreThanTwo == 1f), Is.EqualTo(410));
422+
423+
var moreThanTwoOrderLinesDouble = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1d : 0d }).ToList();
424+
Assert.That(moreThanTwoOrderLinesDouble.Count(x => x.HasMoreThanTwo == 1d), Is.EqualTo(410));
425+
426+
var moreThanTwoOrderLinesNDouble = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1d : (double?)null }).ToList();
427+
Assert.That(moreThanTwoOrderLinesNDouble.Count(x => x.HasMoreThanTwo == 1d), Is.EqualTo(410));
428+
429+
var moreThanTwoOrderLinesString = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? "yes" : "no" }).ToList();
430+
Assert.That(moreThanTwoOrderLinesString.Count(x => x.HasMoreThanTwo == "yes"), Is.EqualTo(410));
431+
432+
var now = DateTime.Now.Date;
433+
var moreThanTwoOrderLinesDateTime = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? o.OrderDate.Value : now }).ToList();
434+
Assert.That(moreThanTwoOrderLinesDateTime.Count(x => x.HasMoreThanTwo != now), Is.EqualTo(410));
435+
436+
var moreThanTwoOrderLinesNDateTime = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? o.OrderDate : null }).ToList();
437+
Assert.That(moreThanTwoOrderLinesNDateTime.Count(x => x.HasMoreThanTwo != null), Is.EqualTo(410));
438+
439+
var moreThanTwoOrderLinesGuid = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? o.Shipper.Reference : Guid.Empty }).ToList();
440+
Assert.That(moreThanTwoOrderLinesGuid.Count(x => x.HasMoreThanTwo != Guid.Empty), Is.EqualTo(410));
441+
442+
var moreThanTwoOrderLinesNGuid = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? o.Shipper.Reference : (Guid?)null }).ToList();
443+
Assert.That(moreThanTwoOrderLinesNGuid.Count(x => x.HasMoreThanTwo != null), Is.EqualTo(410));
444+
}
445+
446+
[Test]
447+
public void CanSelectConditionalEntity()
448+
{
449+
var fatherInsteadOfChild = db.Animals.Select(a => a.Father.SerialNumber == "5678" ? a.Father : a).ToList();
450+
Assert.That(fatherInsteadOfChild, Has.Exactly(2).With.Property("SerialNumber").EqualTo("5678"));
451+
}
452+
453+
[Test]
454+
public void CanSelectConditionalObject()
455+
{
456+
var fatherIsKnown = db.Animals.Select(a => new { a.SerialNumber, Superior = a.Father.SerialNumber, FatherIsKnown = a.Father.SerialNumber == "5678" ? (object)true : (object)false }).ToList();
457+
Assert.That(fatherIsKnown, Has.Exactly(1).With.Property("FatherIsKnown").True);
458+
}
459+
390460
public class Wrapper<T>
391461
{
392462
public T item;

src/NHibernate/Linq/Expressions/NhExpressionType.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ public enum NhExpressionType
99
Count,
1010
Distinct,
1111
New,
12-
Star
12+
Star,
13+
Nominator
1314
}
1415
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
using System.Linq.Expressions;
2+
using Remotion.Linq.Clauses.Expressions;
3+
using Remotion.Linq.Parsing;
4+
5+
namespace NHibernate.Linq.Expressions
6+
{
7+
/// <summary>
8+
/// Represents an expression that has been nominated for direct inclusion in the SELECT clause.
9+
/// This bypasses the standard nomination process and assumes that the expression can be converted
10+
/// directly to SQL.
11+
/// </summary>
12+
/// <remarks>
13+
/// Used in the nomination of GroupBy key expressions to ensure that matching select clauses
14+
/// are generated the same way.
15+
/// </remarks>
16+
internal class NhNominatedExpression : ExtensionExpression
17+
{
18+
public Expression Expression { get; private set; }
19+
20+
public NhNominatedExpression(Expression expression) : base(expression.Type, (ExpressionType)NhExpressionType.Nominator)
21+
{
22+
Expression = expression;
23+
}
24+
25+
protected override Expression VisitChildren(ExpressionTreeVisitor visitor)
26+
{
27+
var newExpression = visitor.VisitExpression(Expression);
28+
29+
return newExpression != Expression
30+
? new NhNominatedExpression(newExpression)
31+
: this;
32+
}
33+
}
34+
}

src/NHibernate/Linq/GroupBy/GroupBySelectClauseRewriter.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,13 @@ public static Expression ReWrite(Expression expression, GroupResultOperator grou
2222

2323
private readonly GroupResultOperator _groupBy;
2424
private readonly QueryModel _model;
25+
private readonly Expression _nominatedKeySelector;
2526

2627
private GroupBySelectClauseRewriter(GroupResultOperator groupBy, QueryModel model)
2728
{
2829
_groupBy = groupBy;
2930
_model = model;
31+
_nominatedKeySelector = GroupKeyNominator.Visit(groupBy);
3032
}
3133

3234
protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression)
@@ -53,7 +55,8 @@ protected override Expression VisitMemberExpression(MemberExpression expression)
5355

5456
if (expression.IsGroupingKeyOf(_groupBy))
5557
{
56-
return _groupBy.KeySelector;
58+
// If we have referenced the Key, then return the nominated key expression
59+
return _nominatedKeySelector;
5760
}
5861

5962
var elementSelector = _groupBy.ElementSelector;
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
using System.Linq;
2+
using System.Linq.Expressions;
3+
using NHibernate.Linq.Expressions;
4+
using Remotion.Linq.Clauses.Expressions;
5+
using Remotion.Linq.Clauses.ResultOperators;
6+
using Remotion.Linq.Parsing;
7+
8+
namespace NHibernate.Linq.GroupBy
9+
{
10+
/// <summary>
11+
/// This class nominates sub-expression trees on the GroupBy Key expression
12+
/// for inclusion in the Select clause.
13+
/// </summary>
14+
internal class GroupKeyNominator : ExpressionTreeVisitor
15+
{
16+
private GroupKeyNominator() { }
17+
18+
private bool _requiresRootNomination;
19+
private bool _transformed;
20+
private int _depth;
21+
22+
public static Expression Visit(GroupResultOperator groupBy)
23+
{
24+
return VisitInternal(groupBy.KeySelector);
25+
}
26+
27+
private static Expression VisitInternal(Expression expr)
28+
{
29+
return new GroupKeyNominator().VisitExpression(expr);
30+
}
31+
32+
public override Expression VisitExpression(Expression expression)
33+
{
34+
_depth++;
35+
var expr = base.VisitExpression(expression);
36+
_depth--;
37+
38+
// At the root expression, wrap it in the nominator expression if needed
39+
if (_depth == 0 && !_transformed && _requiresRootNomination)
40+
{
41+
expr = new NhNominatedExpression(expr);
42+
}
43+
return expr;
44+
}
45+
46+
protected override Expression VisitNewArrayExpression(NewArrayExpression expression)
47+
{
48+
_transformed = true;
49+
// Transform each initializer recursively (to allow for nested initializers)
50+
return Expression.NewArrayInit(expression.Type.GetElementType(), expression.Expressions.Select(VisitInternal));
51+
}
52+
53+
protected override Expression VisitNewExpression(NewExpression expression)
54+
{
55+
_transformed = true;
56+
// Transform each initializer recursively (to allow for nested initializers)
57+
return Expression.New(expression.Constructor, expression.Arguments.Select(VisitInternal), expression.Members);
58+
}
59+
60+
protected override Expression VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression)
61+
{
62+
// If the (sub)expression contains a QuerySourceReference, then the entire expression should be nominated
63+
_requiresRootNomination = true;
64+
return base.VisitQuerySourceReferenceExpression(expression);
65+
}
66+
67+
protected override Expression VisitSubQueryExpression(SubQueryExpression expression)
68+
{
69+
// If the (sub)expression contains a QuerySourceReference, then the entire expression should be nominated
70+
_requiresRootNomination = true;
71+
return base.VisitSubQueryExpression(expression);
72+
}
73+
}
74+
}

src/NHibernate/Linq/GroupResultOperatorExtensions.cs

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,19 @@ internal static class GroupResultOperatorExtensions
1212
{
1313
public static IEnumerable<Expression> ExtractKeyExpressions(this GroupResultOperator groupResult)
1414
{
15-
if (groupResult.KeySelector is NewExpression)
16-
return (groupResult.KeySelector as NewExpression).Arguments;
17-
if (groupResult.KeySelector is NewArrayExpression)
18-
return (groupResult.KeySelector as NewArrayExpression).Expressions;
19-
return new [] { groupResult.KeySelector };
15+
return groupResult.KeySelector.ExtractKeyExpressions();
16+
}
17+
18+
private static IEnumerable<Expression> ExtractKeyExpressions(this Expression expr)
19+
{
20+
// Recursively extract key expressions from nested initializers
21+
// --> new object[] { ((object)new object[] { x.A, x.B }), x.C }
22+
// --> x.A, x.B, x.C
23+
if (expr is NewExpression)
24+
return (expr as NewExpression).Arguments.SelectMany(ExtractKeyExpressions);
25+
if (expr is NewArrayExpression)
26+
return (expr as NewArrayExpression).Expressions.SelectMany(ExtractKeyExpressions);
27+
return new[] { expr };
2028
}
2129
}
2230
}

src/NHibernate/Linq/Visitors/HqlGeneratorExpressionTreeVisitor.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,8 @@ protected HqlTreeNode VisitExpression(Expression expression)
135135
return VisitNhDistinct((NhDistinctExpression) expression);
136136
case NhExpressionType.Star:
137137
return VisitNhStar((NhStarExpression) expression);
138+
case NhExpressionType.Nominator:
139+
return VisitExpression(((NhNominatedExpression) expression).Expression);
138140
//case NhExpressionType.New:
139141
// return VisitNhNew((NhNewExpression)expression);
140142
}

src/NHibernate/Linq/Visitors/MemberExpressionJoinDetector.cs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System.Collections;
22
using System.Collections.Generic;
33
using System.Linq.Expressions;
4+
using NHibernate.Linq.Expressions;
45
using NHibernate.Linq.ReWriters;
56
using Remotion.Linq.Clauses;
67
using Remotion.Linq.Clauses.Expressions;
@@ -19,6 +20,7 @@ internal class MemberExpressionJoinDetector : ExpressionTreeVisitor
1920
private readonly IJoiner _joiner;
2021

2122
private bool _requiresJoinForNonIdentifier;
23+
private bool _preventJoinsInConditionalTest;
2224
private bool _hasIdentifier;
2325
private int _memberExpressionDepth;
2426

@@ -61,7 +63,7 @@ protected override Expression VisitSubQueryExpression(SubQueryExpression express
6163
protected override Expression VisitConditionalExpression(ConditionalExpression expression)
6264
{
6365
var oldRequiresJoinForNonIdentifier = _requiresJoinForNonIdentifier;
64-
_requiresJoinForNonIdentifier = false;
66+
_requiresJoinForNonIdentifier = !_preventJoinsInConditionalTest && _requiresJoinForNonIdentifier;
6567
var newTest = VisitExpression(expression.Test);
6668
_requiresJoinForNonIdentifier = oldRequiresJoinForNonIdentifier;
6769
var newFalse = VisitExpression(expression.IfFalse);
@@ -71,8 +73,21 @@ protected override Expression VisitConditionalExpression(ConditionalExpression e
7173
return expression;
7274
}
7375

76+
protected override Expression VisitExtensionExpression(ExtensionExpression expression)
77+
{
78+
// Nominated expressions need to prevent joins on non-Identifier member expressions
79+
// (for the test expression of conditional expressions only)
80+
// Otherwise an extra join is created and the GroupBy and Select clauses will not match
81+
var old = _preventJoinsInConditionalTest;
82+
_preventJoinsInConditionalTest = (NhExpressionType)expression.NodeType == NhExpressionType.Nominator;
83+
var expr = base.VisitExtensionExpression(expression);
84+
_preventJoinsInConditionalTest = old;
85+
return expr;
86+
}
87+
7488
public void Transform(SelectClause selectClause)
7589
{
90+
// The select clause typically requires joins for non-Identifier member access
7691
_requiresJoinForNonIdentifier = true;
7792
selectClause.TransformExpressions(VisitExpression);
7893
_requiresJoinForNonIdentifier = false;

0 commit comments

Comments
 (0)