Skip to content

Commit 8fb1956

Browse files
committed
Extend the logic to be used for other aggregate functions
1 parent 5050805 commit 8fb1956

File tree

1 file changed

+104
-23
lines changed

1 file changed

+104
-23
lines changed

src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs

Lines changed: 104 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System;
2+
using System.Data;
23
using System.Dynamic;
34
using System.Linq;
45
using System.Linq.Expressions;
@@ -240,10 +241,13 @@ constant.Value is CallSite site &&
240241
protected HqlTreeNode VisitNhAverage(NhAverageExpression expression)
241242
{
242243
var hqlExpression = VisitExpression(expression.Expression).AsExpression();
243-
if (expression.Type != expression.Expression.Type)
244-
hqlExpression = _hqlTreeBuilder.Cast(hqlExpression, expression.Type);
244+
hqlExpression = IsCastRequired(expression.Expression, expression.Type)
245+
? (HqlExpression) _hqlTreeBuilder.Cast(hqlExpression, expression.Type)
246+
: _hqlTreeBuilder.TransparentCast(hqlExpression, expression.Type);
245247

246-
return _hqlTreeBuilder.Cast(_hqlTreeBuilder.Average(hqlExpression), expression.Type);
248+
return IsCastRequired(expression.Type, "avg")
249+
? (HqlTreeNode) _hqlTreeBuilder.Cast(_hqlTreeBuilder.Average(hqlExpression), expression.Type)
250+
: _hqlTreeBuilder.TransparentCast(_hqlTreeBuilder.Average(hqlExpression), expression.Type);
247251
}
248252

249253
protected HqlTreeNode VisitNhCount(NhCountExpression expression)
@@ -263,17 +267,9 @@ protected HqlTreeNode VisitNhMax(NhMaxExpression expression)
263267

264268
protected HqlTreeNode VisitNhSum(NhSumExpression expression)
265269
{
266-
var type = expression.Type.UnwrapIfNullable();
267-
var nhType = TypeFactory.GetDefaultTypeFor(type);
268-
if (nhType != null && _parameters.SessionFactory.SQLFunctionRegistry.FindSQLFunction("sum")
269-
?.ReturnType(nhType, _parameters.SessionFactory)?.ReturnedClass == type)
270-
{
271-
return _hqlTreeBuilder.TransparentCast(
272-
_hqlTreeBuilder.Sum(VisitExpression(expression.Expression).AsExpression()),
273-
expression.Type);
274-
}
275-
276-
return _hqlTreeBuilder.Cast(_hqlTreeBuilder.Sum(VisitExpression(expression.Expression).AsExpression()), expression.Type);
270+
return IsCastRequired(expression.Type, "sum")
271+
? (HqlTreeNode) _hqlTreeBuilder.Cast(_hqlTreeBuilder.Sum(VisitExpression(expression.Expression).AsExpression()), expression.Type)
272+
: _hqlTreeBuilder.TransparentCast(_hqlTreeBuilder.Sum(VisitExpression(expression.Expression).AsExpression()), expression.Type);
277273
}
278274

279275
protected HqlTreeNode VisitNhDistinct(NhDistinctExpression expression)
@@ -487,15 +483,9 @@ protected HqlTreeNode VisitUnaryExpression(UnaryExpression expression)
487483
case ExpressionType.Convert:
488484
case ExpressionType.ConvertChecked:
489485
case ExpressionType.TypeAs:
490-
var operandType = expression.Operand.Type.UnwrapIfNullable();
491-
if ((operandType.IsPrimitive || operandType == typeof(decimal)) &&
492-
(expression.Type.IsPrimitive || expression.Type == typeof(decimal)) &&
493-
expression.Type != operandType)
494-
{
495-
return _hqlTreeBuilder.Cast(VisitExpression(expression.Operand).AsExpression(), expression.Type);
496-
}
497-
498-
return VisitExpression(expression.Operand);
486+
return IsCastRequired(expression.Operand, expression.Type)
487+
? _hqlTreeBuilder.Cast(VisitExpression(expression.Operand).AsExpression(), expression.Type)
488+
: VisitExpression(expression.Operand);
499489
}
500490

501491
throw new NotSupportedException(expression.ToString());
@@ -596,5 +586,96 @@ protected HqlTreeNode VisitNewArrayExpression(NewArrayExpression expression)
596586
var expressionSubTree = expression.Expressions.ToArray(exp => VisitExpression(exp));
597587
return _hqlTreeBuilder.ExpressionSubTreeHolder(expressionSubTree);
598588
}
589+
590+
private bool IsCastRequired(Expression expression, System.Type toType)
591+
{
592+
return toType != typeof(object) && IsCastRequired(GetType(expression), TypeFactory.GetDefaultTypeFor(toType));
593+
}
594+
595+
private bool IsCastRequired(IType type, IType toType)
596+
{
597+
// A type can be null when casting an entity into a base class, in that case we should not cast
598+
if (type == null || toType == null || Equals(type, toType))
599+
{
600+
return false;
601+
}
602+
603+
var sqlTypes = type.SqlTypes(_parameters.SessionFactory);
604+
var toSqlTypes = toType.SqlTypes(_parameters.SessionFactory);
605+
if (sqlTypes.Length != 1 || toSqlTypes.Length != 1)
606+
{
607+
return false; // Casting a multi-column type is not possible
608+
}
609+
610+
if (type.ReturnedClass.IsEnum && sqlTypes[0].DbType == DbType.String)
611+
{
612+
return false; // Never cast an enum that is mapped as string, the type will provide a string for the parameter value
613+
}
614+
615+
return sqlTypes[0].DbType != toSqlTypes[0].DbType;
616+
}
617+
618+
private bool IsCastRequired(System.Type type, string sqlFunctionName)
619+
{
620+
if (type == typeof(object))
621+
{
622+
return false;
623+
}
624+
625+
var toType = TypeFactory.GetDefaultTypeFor(type);
626+
if (toType == null)
627+
{
628+
return true; // Fallback to the old behavior
629+
}
630+
631+
var sqlFunction = _parameters.SessionFactory.SQLFunctionRegistry.FindSQLFunction(sqlFunctionName);
632+
if (sqlFunction == null)
633+
{
634+
return true; // Fallback to the old behavior
635+
}
636+
637+
var fnReturnType = sqlFunction.ReturnType(toType, _parameters.SessionFactory);
638+
return fnReturnType == null || IsCastRequired(fnReturnType, toType);
639+
}
640+
641+
private IType GetType(Expression expression)
642+
{
643+
if (!(expression is MemberExpression memberExpression))
644+
{
645+
return expression.Type != typeof(object)
646+
? TypeFactory.GetDefaultTypeFor(expression.Type)
647+
: null;
648+
}
649+
650+
// Try to get the mapped type for the member as it may be a non default one
651+
var entityName = TryGetEntityName(memberExpression);
652+
if (entityName == null)
653+
{
654+
return TypeFactory.GetDefaultTypeFor(expression.Type); // Not mapped
655+
}
656+
657+
var persister = _parameters.SessionFactory.GetEntityPersister(entityName);
658+
var index = persister.EntityMetamodel.GetPropertyIndexOrNull(memberExpression.Member.Name);
659+
return !index.HasValue
660+
? TypeFactory.GetDefaultTypeFor(expression.Type) // Not mapped
661+
: persister.EntityMetamodel.PropertyTypes[index.Value];
662+
}
663+
664+
private string TryGetEntityName(MemberExpression memberExpression)
665+
{
666+
System.Type entityType;
667+
// Try to get the actual entity type from the query source if possbile as member can be declared
668+
// in a base type
669+
if (memberExpression.Expression is QuerySourceReferenceExpression querySourceReferenceExpression)
670+
{
671+
entityType = querySourceReferenceExpression.Type;
672+
}
673+
else
674+
{
675+
entityType = memberExpression.Member.ReflectedType;
676+
}
677+
678+
return _parameters.SessionFactory.TryGetGuessEntityName(entityType);
679+
}
599680
}
600681
}

0 commit comments

Comments
 (0)