@@ -240,6 +240,9 @@ constant.Value is CallSite site &&
240
240
241
241
protected HqlTreeNode VisitNhAverage ( NhAverageExpression expression )
242
242
{
243
+ // We need to cast the argument when its type is different from Average method return type,
244
+ // otherwise the result may be incorrect. In SQL Server avg always returns int
245
+ // when the argument is int.
243
246
var hqlExpression = VisitExpression ( expression . Expression ) . AsExpression ( ) ;
244
247
hqlExpression = IsCastRequired ( expression . Expression , expression . Type , out _ )
245
248
? ( HqlExpression ) _hqlTreeBuilder . Cast ( hqlExpression , expression . Type )
@@ -267,7 +270,7 @@ protected HqlTreeNode VisitNhMax(NhMaxExpression expression)
267
270
268
271
protected HqlTreeNode VisitNhSum ( NhSumExpression expression )
269
272
{
270
- return IsCastRequired ( expression . Type , "sum" , out _ )
273
+ return IsCastRequired ( "sum" , expression . Expression , expression . Type )
271
274
? ( HqlTreeNode ) _hqlTreeBuilder . Cast ( _hqlTreeBuilder . Sum ( VisitExpression ( expression . Expression ) . AsExpression ( ) ) , expression . Type )
272
275
: _hqlTreeBuilder . TransparentCast ( _hqlTreeBuilder . Sum ( VisitExpression ( expression . Expression ) . AsExpression ( ) ) , expression . Type ) ;
273
276
}
@@ -593,7 +596,8 @@ protected HqlTreeNode VisitNewArrayExpression(NewArrayExpression expression)
593
596
private bool IsCastRequired ( Expression expression , System . Type toType , out bool existType )
594
597
{
595
598
existType = false ;
596
- return toType != typeof ( object ) && IsCastRequired ( GetType ( expression ) , TypeFactory . GetDefaultTypeFor ( toType ) , out existType ) ;
599
+ return toType != typeof ( object ) &&
600
+ IsCastRequired ( GetType ( expression ) , TypeFactory . GetDefaultTypeFor ( toType ) , out existType ) ;
597
601
}
598
602
599
603
private bool IsCastRequired ( IType type , IType toType , out bool existType )
@@ -635,59 +639,38 @@ private bool IsCastRequired(IType type, IType toType, out bool existType)
635
639
return castTypeName != toCastTypeName ;
636
640
}
637
641
638
- private bool IsCastRequired ( System . Type type , string sqlFunctionName , out bool existType )
642
+ private bool IsCastRequired ( string sqlFunctionName , Expression argumentExpression , System . Type returnType )
639
643
{
640
- if ( type == typeof ( object ) )
644
+ var argumentType = GetType ( argumentExpression ) ;
645
+ if ( argumentType == null || returnType == typeof ( object ) )
641
646
{
642
- existType = false ;
643
647
return false ;
644
648
}
645
649
646
- var toType = TypeFactory . GetDefaultTypeFor ( type ) ;
647
- if ( toType == null )
650
+ var returnNhType = TypeFactory . GetDefaultTypeFor ( returnType ) ;
651
+ if ( returnNhType == null )
648
652
{
649
- existType = false ;
650
653
return true ; // Fallback to the old behavior
651
654
}
652
655
653
- existType = true ;
654
656
var sqlFunction = _parameters . SessionFactory . SQLFunctionRegistry . FindSQLFunction ( sqlFunctionName ) ;
655
657
if ( sqlFunction == null )
656
658
{
657
659
return true ; // Fallback to the old behavior
658
660
}
659
661
660
- var fnReturnType = sqlFunction . ReturnType ( toType , _parameters . SessionFactory ) ;
661
- return fnReturnType == null || IsCastRequired ( fnReturnType , toType , out existType ) ;
662
+ var fnReturnType = sqlFunction . ReturnType ( argumentType , _parameters . SessionFactory ) ;
663
+ return fnReturnType == null || IsCastRequired ( fnReturnType , returnNhType , out _ ) ;
662
664
}
663
665
664
666
private IType GetType ( Expression expression )
665
667
{
666
- if ( ! ( expression is MemberExpression memberExpression ) )
667
- {
668
- return expression . Type != typeof ( object )
669
- ? TypeFactory . GetDefaultTypeFor ( expression . Type )
670
- : null ;
671
- }
672
-
673
668
// Try to get the mapped type for the member as it may be a non default one
674
- var entityName = ExpressionsHelper . TryGetEntityName ( _parameters . SessionFactory , memberExpression , out var memberPath ) ;
675
- if ( entityName == null )
676
- {
677
- return TypeFactory . GetDefaultTypeFor ( expression . Type ) ; // Not mapped
678
- }
679
-
680
- var persister = _parameters . SessionFactory . GetEntityPersister ( entityName ) ;
681
- var type = persister . EntityMetamodel . GetIdentifierPropertyType ( memberPath ) ;
682
- if ( type != null )
683
- {
684
- return type ;
685
- }
686
-
687
- var index = persister . EntityMetamodel . GetPropertyIndexOrNull ( memberPath ) ;
688
- return ! index . HasValue
689
- ? TypeFactory . GetDefaultTypeFor ( expression . Type ) // Not mapped
690
- : persister . EntityMetamodel . PropertyTypes [ index . Value ] ;
669
+ ExpressionsHelper . TryGetEntityName ( _parameters . SessionFactory , expression , out _ , out var type ) ;
670
+ return type ??
671
+ ( expression . Type != typeof ( object )
672
+ ? TypeFactory . GetDefaultTypeFor ( expression . Type )
673
+ : null ) ;
691
674
}
692
675
}
693
676
}
0 commit comments