@@ -242,6 +242,9 @@ constant.Value is CallSite site &&
242
242
243
243
protected HqlTreeNode VisitNhAverage ( NhAverageExpression expression )
244
244
{
245
+ // We need to cast the argument when its type is different from Average method return type,
246
+ // otherwise the result may be incorrect. In SQL Server avg always returns int
247
+ // when the argument is int.
245
248
var hqlExpression = VisitExpression ( expression . Expression ) . AsExpression ( ) ;
246
249
hqlExpression = IsCastRequired ( expression . Expression , expression . Type , out _ )
247
250
? ( HqlExpression ) _hqlTreeBuilder . Cast ( hqlExpression , expression . Type )
@@ -269,7 +272,7 @@ protected HqlTreeNode VisitNhMax(NhMaxExpression expression)
269
272
270
273
protected HqlTreeNode VisitNhSum ( NhSumExpression expression )
271
274
{
272
- return IsCastRequired ( expression . Type , "sum" , out _ )
275
+ return IsCastRequired ( "sum" , expression . Expression , expression . Type )
273
276
? ( HqlTreeNode ) _hqlTreeBuilder . Cast ( _hqlTreeBuilder . Sum ( VisitExpression ( expression . Expression ) . AsExpression ( ) ) , expression . Type )
274
277
: _hqlTreeBuilder . TransparentCast ( _hqlTreeBuilder . Sum ( VisitExpression ( expression . Expression ) . AsExpression ( ) ) , expression . Type ) ;
275
278
}
@@ -595,7 +598,8 @@ protected HqlTreeNode VisitNewArrayExpression(NewArrayExpression expression)
595
598
private bool IsCastRequired ( Expression expression , System . Type toType , out bool existType )
596
599
{
597
600
existType = false ;
598
- return toType != typeof ( object ) && IsCastRequired ( GetType ( expression ) , TypeFactory . GetDefaultTypeFor ( toType ) , out existType ) ;
601
+ return toType != typeof ( object ) &&
602
+ IsCastRequired ( GetType ( expression ) , TypeFactory . GetDefaultTypeFor ( toType ) , out existType ) ;
599
603
}
600
604
601
605
private bool IsCastRequired ( IType type , IType toType , out bool existType )
@@ -637,59 +641,38 @@ private bool IsCastRequired(IType type, IType toType, out bool existType)
637
641
return castTypeName != toCastTypeName ;
638
642
}
639
643
640
- private bool IsCastRequired ( System . Type type , string sqlFunctionName , out bool existType )
644
+ private bool IsCastRequired ( string sqlFunctionName , Expression argumentExpression , System . Type returnType )
641
645
{
642
- if ( type == typeof ( object ) )
646
+ var argumentType = GetType ( argumentExpression ) ;
647
+ if ( argumentType == null || returnType == typeof ( object ) )
643
648
{
644
- existType = false ;
645
649
return false ;
646
650
}
647
651
648
- var toType = TypeFactory . GetDefaultTypeFor ( type ) ;
649
- if ( toType == null )
652
+ var returnNhType = TypeFactory . GetDefaultTypeFor ( returnType ) ;
653
+ if ( returnNhType == null )
650
654
{
651
- existType = false ;
652
655
return true ; // Fallback to the old behavior
653
656
}
654
657
655
- existType = true ;
656
658
var sqlFunction = _parameters . SessionFactory . SQLFunctionRegistry . FindSQLFunction ( sqlFunctionName ) ;
657
659
if ( sqlFunction == null )
658
660
{
659
661
return true ; // Fallback to the old behavior
660
662
}
661
663
662
- var fnReturnType = sqlFunction . ReturnType ( toType , _parameters . SessionFactory ) ;
663
- return fnReturnType == null || IsCastRequired ( fnReturnType , toType , out existType ) ;
664
+ var fnReturnType = sqlFunction . ReturnType ( argumentType , _parameters . SessionFactory ) ;
665
+ return fnReturnType == null || IsCastRequired ( fnReturnType , returnNhType , out _ ) ;
664
666
}
665
667
666
668
private IType GetType ( Expression expression )
667
669
{
668
- if ( ! ( expression is MemberExpression memberExpression ) )
669
- {
670
- return expression . Type != typeof ( object )
671
- ? TypeFactory . GetDefaultTypeFor ( expression . Type )
672
- : null ;
673
- }
674
-
675
670
// Try to get the mapped type for the member as it may be a non default one
676
- var entityName = ExpressionsHelper . TryGetEntityName ( _parameters . SessionFactory , memberExpression , out var memberPath ) ;
677
- if ( entityName == null )
678
- {
679
- return TypeFactory . GetDefaultTypeFor ( expression . Type ) ; // Not mapped
680
- }
681
-
682
- var persister = _parameters . SessionFactory . GetEntityPersister ( entityName ) ;
683
- var type = persister . EntityMetamodel . GetIdentifierPropertyType ( memberPath ) ;
684
- if ( type != null )
685
- {
686
- return type ;
687
- }
688
-
689
- var index = persister . EntityMetamodel . GetPropertyIndexOrNull ( memberPath ) ;
690
- return ! index . HasValue
691
- ? TypeFactory . GetDefaultTypeFor ( expression . Type ) // Not mapped
692
- : persister . EntityMetamodel . PropertyTypes [ index . Value ] ;
671
+ ExpressionsHelper . TryGetEntityName ( _parameters . SessionFactory , expression , out _ , out var type ) ;
672
+ return type ??
673
+ ( expression . Type != typeof ( object )
674
+ ? TypeFactory . GetDefaultTypeFor ( expression . Type )
675
+ : null ) ;
693
676
}
694
677
}
695
678
}
0 commit comments