Skip to content

Commit 40b830d

Browse files
committed
Moved nullable check code into a separated class
1 parent c273c20 commit 40b830d

File tree

2 files changed

+316
-293
lines changed

2 files changed

+316
-293
lines changed

src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs

Lines changed: 7 additions & 293 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
using System;
2-
using System.Collections.Generic;
32
using System.Data;
43
using System.Dynamic;
54
using System.Linq;
@@ -8,16 +7,12 @@
87
using NHibernate.Engine.Query;
98
using NHibernate.Hql.Ast;
109
using NHibernate.Hql.Ast.ANTLR;
11-
using NHibernate.Linq.Clauses;
1210
using NHibernate.Linq.Expressions;
1311
using NHibernate.Linq.Functions;
14-
using NHibernate.Mapping.ByCode;
1512
using NHibernate.Param;
1613
using NHibernate.Type;
1714
using NHibernate.Util;
18-
using Remotion.Linq.Clauses;
1915
using Remotion.Linq.Clauses.Expressions;
20-
using Remotion.Linq.Clauses.ResultOperators;
2116

2217
namespace NHibernate.Linq.Visitors
2318
{
@@ -26,17 +21,7 @@ public class HqlGeneratorExpressionVisitor : IHqlExpressionVisitor
2621
private readonly HqlTreeBuilder _hqlTreeBuilder = new HqlTreeBuilder();
2722
private readonly VisitorParameters _parameters;
2823
private readonly ILinqToHqlGeneratorsRegistry _functionRegistry;
29-
private readonly Dictionary<BinaryExpression, List<MemberExpression>> _equalityNotNullMembers =
30-
new Dictionary<BinaryExpression, List<MemberExpression>>();
31-
32-
private static readonly HashSet<System.Type> NotNullOperators = new HashSet<System.Type>()
33-
{
34-
typeof(AllResultOperator),
35-
typeof(AnyResultOperator),
36-
typeof(ContainsResultOperator),
37-
typeof(CountResultOperator),
38-
typeof(LongCountResultOperator)
39-
};
24+
private readonly NullableExpressionDetector _nullableExpressionDetector;
4025

4126
public static HqlTreeNode Visit(Expression expression, VisitorParameters parameters)
4227
{
@@ -47,6 +32,7 @@ public HqlGeneratorExpressionVisitor(VisitorParameters parameters)
4732
{
4833
_functionRegistry = parameters.SessionFactory.Settings.LinqToHqlGeneratorsRegistry;
4934
_parameters = parameters;
35+
_nullableExpressionDetector = new NullableExpressionDetector(_parameters.SessionFactory, _functionRegistry);
5036
}
5137

5238
public ISessionFactory SessionFactory { get { return _parameters.SessionFactory; } }
@@ -308,94 +294,6 @@ private HqlTreeNode VisitVBStringComparisonExpression(VBStringComparisonExpressi
308294
return VisitExpression(expression.Comparison);
309295
}
310296

311-
private void SearchForNotNullMembersCheck(BinaryExpression expression)
312-
{
313-
// Check for a member not null check that has a not equals expression
314-
// Example: o.Status != null && o.Status != "New"
315-
// Example: (o.Status != null && o.OldStatus != null) && (o.Status != o.OldStatus)
316-
// Example: (o.Status != null && o.OldStatus != null) && (o.Status == o.OldStatus)
317-
if (expression.NodeType != ExpressionType.AndAlso ||
318-
expression.Right.NodeType != ExpressionType.NotEqual &&
319-
expression.Right.NodeType != ExpressionType.Equal ||
320-
expression.Left.NodeType != ExpressionType.AndAlso &&
321-
expression.Left.NodeType != ExpressionType.NotEqual)
322-
{
323-
return;
324-
}
325-
326-
// Skip if there are no member access expressions on the right side
327-
var notEqualExpression = (BinaryExpression) expression.Right;
328-
if (!IsMemberAccess(notEqualExpression.Left) && !IsMemberAccess(notEqualExpression.Right))
329-
{
330-
return;
331-
}
332-
333-
var notNullMembers = new List<MemberExpression>();
334-
// We may have multiple conditions
335-
// Example: o.Status != null && o.OldStatus != null
336-
if (expression.Left.NodeType == ExpressionType.AndAlso)
337-
{
338-
FindAllNotNullMembers((BinaryExpression) expression.Left, notNullMembers);
339-
}
340-
else
341-
{
342-
FindNotNullMember((BinaryExpression) expression.Left, notNullMembers);
343-
}
344-
345-
if (notNullMembers.Count > 0)
346-
{
347-
_equalityNotNullMembers[notEqualExpression] = notNullMembers;
348-
}
349-
}
350-
351-
private static bool IsMemberAccess(Expression expression)
352-
{
353-
if (expression.NodeType == ExpressionType.MemberAccess)
354-
{
355-
return true;
356-
}
357-
358-
// Nullable members can be wrapped in a convert expression
359-
return expression is UnaryExpression unaryExpression && unaryExpression.Operand.NodeType == ExpressionType.MemberAccess;
360-
}
361-
362-
private static void FindAllNotNullMembers(BinaryExpression andAlsoExpression, List<MemberExpression> notNullMembers)
363-
{
364-
if (andAlsoExpression.Right.NodeType == ExpressionType.NotEqual)
365-
{
366-
FindNotNullMember((BinaryExpression) andAlsoExpression.Right, notNullMembers);
367-
}
368-
else if (andAlsoExpression.Right.NodeType == ExpressionType.AndAlso)
369-
{
370-
FindAllNotNullMembers((BinaryExpression) andAlsoExpression.Right, notNullMembers);
371-
}
372-
else
373-
{
374-
return;
375-
}
376-
377-
if (andAlsoExpression.Left.NodeType == ExpressionType.NotEqual)
378-
{
379-
FindNotNullMember((BinaryExpression) andAlsoExpression.Left, notNullMembers);
380-
}
381-
else if (andAlsoExpression.Left.NodeType == ExpressionType.AndAlso)
382-
{
383-
FindAllNotNullMembers((BinaryExpression) andAlsoExpression.Left, notNullMembers);
384-
}
385-
}
386-
387-
private static void FindNotNullMember(BinaryExpression notEqualExpression, List<MemberExpression> notNullMembers)
388-
{
389-
if (notEqualExpression.Left.NodeType == ExpressionType.MemberAccess && VisitorUtil.IsNullConstant(notEqualExpression.Right))
390-
{
391-
notNullMembers.Add((MemberExpression) notEqualExpression.Left);
392-
}
393-
else if (VisitorUtil.IsNullConstant(notEqualExpression.Left) && notEqualExpression.Right.NodeType == ExpressionType.MemberAccess)
394-
{
395-
notNullMembers.Add((MemberExpression) notEqualExpression.Right);
396-
}
397-
}
398-
399297
protected HqlTreeNode VisitBinaryExpression(BinaryExpression expression)
400298
{
401299
if (expression.NodeType == ExpressionType.Equal)
@@ -407,7 +305,7 @@ protected HqlTreeNode VisitBinaryExpression(BinaryExpression expression)
407305
return TranslateInequalityComparison(expression);
408306
}
409307

410-
SearchForNotNullMembersCheck(expression);
308+
_nullableExpressionDetector.SearchForNotNullMemberChecks(expression);
411309

412310
var lhs = VisitExpression(expression.Left).AsExpression();
413311
var rhs = VisitExpression(expression.Right).AsExpression();
@@ -490,8 +388,8 @@ private HqlTreeNode TranslateInequalityComparison(BinaryExpression expression)
490388
return _hqlTreeBuilder.IsNotNull(lhs);
491389
}
492390

493-
var lhsNullable = IsNullable(expression.Left, expression);
494-
var rhsNullable = IsNullable(expression.Right, expression);
391+
var lhsNullable = _nullableExpressionDetector.IsNullable(expression.Left, expression);
392+
var rhsNullable = _nullableExpressionDetector.IsNullable(expression.Right, expression);
495393

496394
var inequality = _hqlTreeBuilder.Inequality(lhs, rhs);
497395

@@ -553,8 +451,8 @@ private HqlTreeNode TranslateEqualityComparison(BinaryExpression expression)
553451
return _hqlTreeBuilder.IsNull((lhs));
554452
}
555453

556-
var lhsNullable = IsNullable(expression.Left, expression);
557-
var rhsNullable = IsNullable(expression.Right, expression);
454+
var lhsNullable = _nullableExpressionDetector.IsNullable(expression.Left, expression);
455+
var rhsNullable = _nullableExpressionDetector.IsNullable(expression.Right, expression);
558456

559457
var equality = _hqlTreeBuilder.Equality(lhs, rhs);
560458

@@ -573,190 +471,6 @@ private HqlTreeNode TranslateEqualityComparison(BinaryExpression expression)
573471
_hqlTreeBuilder.IsNull(rhs2)));
574472
}
575473

576-
private bool IsNullable(Expression expression, BinaryExpression equalityExpression)
577-
{
578-
var currentExpression = expression;
579-
while (true)
580-
{
581-
switch (currentExpression.NodeType)
582-
{
583-
case ExpressionType.Convert:
584-
case ExpressionType.ConvertChecked:
585-
case ExpressionType.TypeAs:
586-
var unaryExpression = (UnaryExpression) currentExpression;
587-
return IsNullable(unaryExpression.Operand, equalityExpression); // a cast will not return null if the operand is not null
588-
case ExpressionType.Not:
589-
case ExpressionType.And:
590-
case ExpressionType.Or:
591-
case ExpressionType.ExclusiveOr:
592-
case ExpressionType.LeftShift:
593-
case ExpressionType.RightShift:
594-
case ExpressionType.AndAlso:
595-
case ExpressionType.OrElse:
596-
case ExpressionType.Equal:
597-
case ExpressionType.NotEqual:
598-
case ExpressionType.GreaterThanOrEqual:
599-
case ExpressionType.GreaterThan:
600-
case ExpressionType.LessThan:
601-
case ExpressionType.LessThanOrEqual:
602-
return false;
603-
case ExpressionType.Add:
604-
case ExpressionType.AddChecked:
605-
case ExpressionType.Divide:
606-
case ExpressionType.Modulo:
607-
case ExpressionType.Multiply:
608-
case ExpressionType.MultiplyChecked:
609-
case ExpressionType.Power:
610-
case ExpressionType.Subtract:
611-
case ExpressionType.SubtractChecked:
612-
var binaryExpression = (BinaryExpression) currentExpression;
613-
return IsNullable(binaryExpression.Left, equalityExpression) || IsNullable(binaryExpression.Right, equalityExpression);
614-
case ExpressionType.ArrayIndex:
615-
return true; // for indexed lists we cannot determine whether the item will be null or not
616-
case ExpressionType.Coalesce:
617-
return IsNullable(((BinaryExpression) currentExpression).Right, equalityExpression);
618-
case ExpressionType.Conditional:
619-
var conditionalExpression = (ConditionalExpression) currentExpression;
620-
return IsNullable(conditionalExpression.IfTrue, equalityExpression) ||
621-
IsNullable(conditionalExpression.IfFalse, equalityExpression);
622-
case ExpressionType.Call:
623-
var methodInfo = ((MethodCallExpression) currentExpression).Method;
624-
return !_functionRegistry.TryGetGenerator(methodInfo, out var method) || method.AllowsNullableReturnType(methodInfo);
625-
case ExpressionType.MemberAccess:
626-
var memberExpression = (MemberExpression) currentExpression;
627-
628-
if (_functionRegistry.TryGetGenerator(memberExpression.Member, out _))
629-
{
630-
// We have to skip the property as it will be converted to a function that can return null
631-
// if the argument is null
632-
currentExpression = memberExpression.Expression;
633-
continue;
634-
}
635-
636-
var memberType = ReflectHelper.GetPropertyOrFieldType(memberExpression.Member);
637-
if (memberType?.IsValueType == true && !memberType.IsNullable())
638-
{
639-
currentExpression = memberExpression.Expression;
640-
continue;
641-
}
642-
643-
// Check if there was a not null check prior the equality expression
644-
if ((
645-
equalityExpression.NodeType == ExpressionType.NotEqual ||
646-
equalityExpression.NodeType == ExpressionType.Equal
647-
) &&
648-
_equalityNotNullMembers.TryGetValue(equalityExpression, out var notNullMembers) &&
649-
notNullMembers.Any(o => AreEqual(o, memberExpression)))
650-
{
651-
return false;
652-
}
653-
654-
// We have to check the member mapping to determine if is nullable
655-
var entityName = TryGetEntityName(memberExpression);
656-
if (entityName == null)
657-
{
658-
return true; // not mapped
659-
}
660-
661-
var persister = _parameters.SessionFactory.GetEntityPersister(entityName);
662-
var index = persister.EntityMetamodel.GetPropertyIndexOrNull(memberExpression.Member.Name);
663-
if (!index.HasValue || persister.EntityMetamodel.PropertyNullability[index.Value])
664-
{
665-
return true; // not mapped or nullable
666-
}
667-
668-
currentExpression = memberExpression.Expression;
669-
continue;
670-
case ExpressionType.Extension:
671-
switch (currentExpression)
672-
{
673-
case QuerySourceReferenceExpression querySourceReferenceExpression:
674-
switch (querySourceReferenceExpression.ReferencedQuerySource)
675-
{
676-
case MainFromClause _:
677-
return false; // we reached to the root expression, there were no nullable expressions
678-
case NhJoinClause joinClause:
679-
return IsNullable(joinClause.FromExpression, equalityExpression);
680-
default:
681-
return true; // unknown query source
682-
}
683-
case SubQueryExpression subQuery:
684-
if (subQuery.QueryModel.SelectClause.Selector is NhAggregatedExpression subQueryAggregatedExpression)
685-
{
686-
return subQueryAggregatedExpression.AllowsNullableReturnType;
687-
}
688-
else if (subQuery.QueryModel.ResultOperators.Any(o => NotNullOperators.Contains(o.GetType())))
689-
{
690-
return false;
691-
}
692-
693-
return true;
694-
case NhAggregatedExpression aggregatedExpression:
695-
return aggregatedExpression.AllowsNullableReturnType;
696-
default:
697-
return true; // a query can return null and we cannot calculate it as it is not yet executed
698-
}
699-
case ExpressionType.TypeIs: // an equal or in operator will be generated and those cannot return null
700-
case ExpressionType.NewArrayInit:
701-
return false;
702-
case ExpressionType.Constant:
703-
return VisitorUtil.IsNullConstant(currentExpression);
704-
case ExpressionType.Parameter:
705-
return !currentExpression.Type.IsValueType;
706-
default:
707-
return true;
708-
}
709-
}
710-
}
711-
712-
private bool AreEqual(MemberExpression memberExpression, MemberExpression otherMemberExpression)
713-
{
714-
if (memberExpression.Member != otherMemberExpression.Member ||
715-
memberExpression.Expression.NodeType != otherMemberExpression.Expression.NodeType)
716-
{
717-
return false;
718-
}
719-
720-
switch (memberExpression.Expression)
721-
{
722-
case QuerySourceReferenceExpression querySourceReferenceExpression:
723-
if (otherMemberExpression.Expression is QuerySourceReferenceExpression otherQuerySourceReferenceExpression)
724-
{
725-
return querySourceReferenceExpression.ReferencedQuerySource ==
726-
otherQuerySourceReferenceExpression.ReferencedQuerySource;
727-
}
728-
729-
return false;
730-
// Components have a nested member expression
731-
case MemberExpression nestedMemberExpression:
732-
if (otherMemberExpression.Expression is MemberExpression otherNestedMemberExpression)
733-
{
734-
return AreEqual(nestedMemberExpression, otherNestedMemberExpression);
735-
}
736-
737-
return false;
738-
default:
739-
return memberExpression.Expression == otherMemberExpression.Expression;
740-
}
741-
}
742-
743-
private string TryGetEntityName(MemberExpression memberExpression)
744-
{
745-
System.Type entityType;
746-
// Try to get the actual entity type from the query source if possbile as member can be declared
747-
// in a base type
748-
if (memberExpression.Expression is QuerySourceReferenceExpression querySourceReferenceExpression)
749-
{
750-
entityType = querySourceReferenceExpression.Type;
751-
}
752-
else
753-
{
754-
entityType = memberExpression.Member.ReflectedType;
755-
}
756-
757-
return _parameters.SessionFactory.TryGetGuessEntityName(entityType);
758-
}
759-
760474
protected HqlTreeNode VisitUnaryExpression(UnaryExpression expression)
761475
{
762476
switch (expression.NodeType)

0 commit comments

Comments
 (0)