Skip to content

Commit 79124b9

Browse files
committed
Attempt to fix NH-3423, by simplifying expressions that compare the result of a construction expression with a null constant.
1 parent 2948d6d commit 79124b9

File tree

7 files changed

+145
-19
lines changed

7 files changed

+145
-19
lines changed

src/NHibernate.Test/Linq/WhereSubqueryTests.cs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,5 +636,45 @@ public void ProductsWithSubqueryReturningStringFirstOrDefaultEq()
636636

637637
Assert.That(result.Count, Is.EqualTo(13));
638638
}
639+
640+
641+
[Test(Description = "NH-3423")]
642+
public void NullComparedToNewExpressionInWhereClause()
643+
{
644+
// Construction will never be equal to null, so the ternary should be collapsed
645+
// to just the IfFalse expression. Without this collapsing, we cannot generate HQL.
646+
647+
var result = db.Products
648+
.Select(p => new {Name = p.Name, Pr2 = new {ReorderLevel = p.ReorderLevel}})
649+
.Where(pr1 => (pr1.Pr2 == null ? (int?) null : pr1.Pr2.ReorderLevel) > 6)
650+
.ToList();
651+
652+
Assert.That(result.Count, Is.EqualTo(45));
653+
}
654+
655+
private class Pr2
656+
{
657+
public int ReorderLevel { get; set; }
658+
}
659+
660+
private class Pr1
661+
{
662+
public string Name { get; set; }
663+
public Pr2 Pr2 { get; set; }
664+
}
665+
666+
[Test(Description = "NH-3423")]
667+
public void NullComparedToMemberInitExpressionInWhereClause()
668+
{
669+
// Construction will never be equal to null, so the ternary should be collapsed
670+
// to just the IfFalse expression. Without this collapsing, we cannot generate HQL.
671+
672+
var result = db.Products
673+
.Select(p => new Pr1 { Name = p.Name, Pr2 = new Pr2 { ReorderLevel = p.ReorderLevel } })
674+
.Where(pr1 => (pr1.Pr2 == null ? (int?)null : pr1.Pr2.ReorderLevel) > 6)
675+
.ToList();
676+
677+
Assert.That(result.Count, Is.EqualTo(45));
678+
}
639679
}
640680
}

src/NHibernate/Linq/Visitors/HqlGeneratorExpressionTreeVisitor.cs

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -273,20 +273,14 @@ protected HqlTreeNode VisitBinaryExpression(BinaryExpression expression)
273273
throw new InvalidOperationException();
274274
}
275275

276+
276277
private HqlTreeNode TranslateEqualityComparison(BinaryExpression expression, HqlExpression lhs, HqlExpression rhs, Func<HqlExpression, HqlTreeNode> applyNullComparison, Func<HqlExpression, HqlExpression, HqlTreeNode> applyRegularComparison)
277278
{
278279
// Check for nulls on left or right.
279-
if (expression.Right is ConstantExpression && expression.Right.Type.IsNullableOrReference() &&
280-
((ConstantExpression) expression.Right).Value == null)
281-
{
280+
if (VisitorUtil.IsNullConstant(expression.Right))
282281
rhs = null;
283-
}
284-
285-
if (expression.Left is ConstantExpression && expression.Left.Type.IsNullableOrReference() &&
286-
((ConstantExpression) expression.Left).Value == null)
287-
{
282+
if (VisitorUtil.IsNullConstant(expression.Left))
288283
lhs = null;
289-
}
290284

291285
// Need to check for boolean equality
292286
if (lhs is HqlBooleanExpression || rhs is HqlBooleanExpression)

src/NHibernate/Linq/Visitors/QueryModelVisitor.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,9 @@ public override void VisitSelectClause(SelectClause selectClause, QueryModel que
236236

237237
public override void VisitWhereClause(WhereClause whereClause, QueryModel queryModel, int index)
238238
{
239+
var visitor = new SimplifyConditionalVisitor();
240+
whereClause.Predicate = visitor.VisitExpression(whereClause.Predicate);
241+
239242
// Visit the predicate to build the query
240243
var expression = HqlGeneratorExpressionTreeVisitor.Visit(whereClause.Predicate, VisitorParameters).AsBooleanExpression();
241244
if (whereClause is NhHavingClause)
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
using System.Linq.Expressions;
2+
using NHibernate.Util;
3+
using Remotion.Linq.Parsing;
4+
5+
namespace NHibernate.Linq.Visitors
6+
{
7+
/// <summary>
8+
/// Some conditional expressions can be redured to just their IfTrue or IfFalse part.
9+
/// </summary>
10+
internal class SimplifyConditionalVisitor :ExpressionTreeVisitor
11+
{
12+
protected override Expression VisitConditionalExpression(ConditionalExpression expression)
13+
{
14+
var testExpression = VisitExpression(expression.Test);
15+
16+
bool testExprResult;
17+
if (VisitorUtil.IsBooleanConstant(testExpression, out testExprResult))
18+
{
19+
if (testExprResult)
20+
return VisitExpression(expression.IfTrue);
21+
22+
return VisitExpression(expression.IfFalse);
23+
}
24+
25+
return base.VisitConditionalExpression(expression);
26+
}
27+
28+
29+
protected override Expression VisitBinaryExpression(BinaryExpression expression)
30+
{
31+
// See NH-3423. Conditional expression where the test expression is a comparison
32+
// of a construction expression and null will happen in WCF DS.
33+
34+
if (IsConstructionToNullComparison(expression))
35+
{
36+
// The result of a construction operation is always non-null. So if it's being compared to
37+
// a null constant, we can simplify it to a boolean constant.
38+
if (expression.NodeType == ExpressionType.Equal)
39+
return Expression.Constant(false);
40+
41+
if (expression.NodeType == ExpressionType.NotEqual)
42+
return Expression.Constant(true);
43+
}
44+
45+
return base.VisitBinaryExpression(expression);
46+
}
47+
48+
49+
private static bool IsConstruction(Expression expression)
50+
{
51+
return expression is NewExpression || expression is MemberInitExpression;
52+
}
53+
54+
55+
private static bool IsConstructionToNullComparison(Expression expression)
56+
{
57+
var testExpression = expression as BinaryExpression;
58+
59+
if (testExpression != null)
60+
{
61+
if ((IsConstruction(testExpression.Left) && VisitorUtil.IsNullConstant(testExpression.Right))
62+
|| (IsConstruction(testExpression.Right) && VisitorUtil.IsNullConstant(testExpression.Left)))
63+
{
64+
return true;
65+
}
66+
}
67+
68+
return false;
69+
}
70+
}
71+
}

src/NHibernate/Linq/Visitors/VisitorUtil.cs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using System.Linq.Expressions;
44
using System.Collections;
55
using System.Reflection;
6+
using NHibernate.Util;
67

78
namespace NHibernate.Linq.Visitors
89
{
@@ -48,5 +49,27 @@ public static bool IsDynamicComponentDictionaryGetter(MethodCallExpression expre
4849
string memberName;
4950
return IsDynamicComponentDictionaryGetter(expression, sessionFactory, out memberName);
5051
}
52+
53+
54+
public static bool IsNullConstant(Expression expression)
55+
{
56+
return expression is ConstantExpression &&
57+
expression.Type.IsNullableOrReference() &&
58+
((ConstantExpression)expression).Value == null;
59+
}
60+
61+
62+
public static bool IsBooleanConstant(Expression expression, out bool value)
63+
{
64+
var constantExpr = expression as ConstantExpression;
65+
if (constantExpr != null && constantExpr.Type == typeof (bool))
66+
{
67+
value = (bool) constantExpr.Value;
68+
return true;
69+
}
70+
71+
value = false; // Dummy value.
72+
return false;
73+
}
5174
}
5275
}

src/NHibernate/Linq/Visitors/WhereJoinDetector.cs

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -131,14 +131,14 @@ protected override Expression VisitBinaryExpression(BinaryExpression expression)
131131
{
132132
HandleBinaryOperation((a, b) => a.OrElse(b));
133133
}
134-
else if (expression.NodeType == ExpressionType.NotEqual && IsNullConstantExpression(expression.Right))
134+
else if (expression.NodeType == ExpressionType.NotEqual && VisitorUtil.IsNullConstant(expression.Right))
135135
{
136136
// Discard result from right null. Left is visited first, so it's below right on the stack.
137137
_values.Pop();
138138

139139
HandleUnaryOperation(pvs => pvs.IsNotNull());
140140
}
141-
else if (expression.NodeType == ExpressionType.NotEqual && IsNullConstantExpression(expression.Left))
141+
else if (expression.NodeType == ExpressionType.NotEqual && VisitorUtil.IsNullConstant(expression.Left))
142142
{
143143
// Discard result from left null.
144144
var right = _values.Pop();
@@ -147,14 +147,14 @@ protected override Expression VisitBinaryExpression(BinaryExpression expression)
147147

148148
HandleUnaryOperation(pvs => pvs.IsNotNull());
149149
}
150-
else if (expression.NodeType == ExpressionType.Equal && IsNullConstantExpression(expression.Right))
150+
else if (expression.NodeType == ExpressionType.Equal && VisitorUtil.IsNullConstant(expression.Right))
151151
{
152152
// Discard result from right null. Left is visited first, so it's below right on the stack.
153153
_values.Pop();
154154

155155
HandleUnaryOperation(pvs => pvs.IsNull());
156156
}
157-
else if (expression.NodeType == ExpressionType.Equal && IsNullConstantExpression(expression.Left))
157+
else if (expression.NodeType == ExpressionType.Equal && VisitorUtil.IsNullConstant(expression.Left))
158158
{
159159
// Discard result from left null.
160160
var right = _values.Pop();
@@ -324,12 +324,6 @@ protected override Expression VisitMemberExpression(MemberExpression expression)
324324
return result;
325325
}
326326

327-
private static bool IsNullConstantExpression(Expression expression)
328-
{
329-
var constant = expression as ConstantExpression;
330-
return constant != null && constant.Value == null;
331-
}
332-
333327
private void SetResultValues(ExpressionValues values)
334328
{
335329
_handled.Pop();

src/NHibernate/NHibernate.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,7 @@
318318
<Compile Include="Linq\NestedSelects\NestedSelectRewriter.cs" />
319319
<Compile Include="Linq\Visitors\SelectJoinDetector.cs" />
320320
<Compile Include="Linq\Visitors\SelectClauseNominator.cs" />
321+
<Compile Include="Linq\Visitors\SimplifyConditionalVisitor.cs" />
321322
<Compile Include="Linq\Visitors\SubQueryFromClauseFlattener.cs" />
322323
<Compile Include="Linq\Visitors\VisitorUtil.cs" />
323324
<Compile Include="Linq\Visitors\WhereJoinDetector.cs" />

0 commit comments

Comments
 (0)