1
1
using System ;
2
+ using System . Data ;
2
3
using System . Dynamic ;
3
4
using System . Linq ;
4
5
using System . Linq . Expressions ;
@@ -240,10 +241,13 @@ constant.Value is CallSite site &&
240
241
protected HqlTreeNode VisitNhAverage ( NhAverageExpression expression )
241
242
{
242
243
var hqlExpression = VisitExpression ( expression . Expression ) . AsExpression ( ) ;
243
- if ( expression . Type != expression . Expression . Type )
244
- hqlExpression = _hqlTreeBuilder . Cast ( hqlExpression , expression . Type ) ;
244
+ hqlExpression = IsCastRequired ( expression . Expression , expression . Type )
245
+ ? ( HqlExpression ) _hqlTreeBuilder . Cast ( hqlExpression , expression . Type )
246
+ : _hqlTreeBuilder . TransparentCast ( hqlExpression , expression . Type ) ;
245
247
246
- return _hqlTreeBuilder . Cast ( _hqlTreeBuilder . Average ( hqlExpression ) , expression . Type ) ;
248
+ return IsCastRequired ( expression . Type , "avg" )
249
+ ? ( HqlTreeNode ) _hqlTreeBuilder . Cast ( _hqlTreeBuilder . Average ( hqlExpression ) , expression . Type )
250
+ : _hqlTreeBuilder . TransparentCast ( _hqlTreeBuilder . Average ( hqlExpression ) , expression . Type ) ;
247
251
}
248
252
249
253
protected HqlTreeNode VisitNhCount ( NhCountExpression expression )
@@ -263,17 +267,9 @@ protected HqlTreeNode VisitNhMax(NhMaxExpression expression)
263
267
264
268
protected HqlTreeNode VisitNhSum ( NhSumExpression expression )
265
269
{
266
- var type = expression . Type . UnwrapIfNullable ( ) ;
267
- var nhType = TypeFactory . GetDefaultTypeFor ( type ) ;
268
- if ( nhType != null && _parameters . SessionFactory . SQLFunctionRegistry . FindSQLFunction ( "sum" )
269
- ? . ReturnType ( nhType , _parameters . SessionFactory ) ? . ReturnedClass == type )
270
- {
271
- return _hqlTreeBuilder . TransparentCast (
272
- _hqlTreeBuilder . Sum ( VisitExpression ( expression . Expression ) . AsExpression ( ) ) ,
273
- expression . Type ) ;
274
- }
275
-
276
- return _hqlTreeBuilder . Cast ( _hqlTreeBuilder . Sum ( VisitExpression ( expression . Expression ) . AsExpression ( ) ) , expression . Type ) ;
270
+ return IsCastRequired ( expression . Type , "sum" )
271
+ ? ( HqlTreeNode ) _hqlTreeBuilder . Cast ( _hqlTreeBuilder . Sum ( VisitExpression ( expression . Expression ) . AsExpression ( ) ) , expression . Type )
272
+ : _hqlTreeBuilder . TransparentCast ( _hqlTreeBuilder . Sum ( VisitExpression ( expression . Expression ) . AsExpression ( ) ) , expression . Type ) ;
277
273
}
278
274
279
275
protected HqlTreeNode VisitNhDistinct ( NhDistinctExpression expression )
@@ -487,15 +483,9 @@ protected HqlTreeNode VisitUnaryExpression(UnaryExpression expression)
487
483
case ExpressionType . Convert :
488
484
case ExpressionType . ConvertChecked :
489
485
case ExpressionType . TypeAs :
490
- var operandType = expression . Operand . Type . UnwrapIfNullable ( ) ;
491
- if ( ( operandType . IsPrimitive || operandType == typeof ( decimal ) ) &&
492
- ( expression . Type . IsPrimitive || expression . Type == typeof ( decimal ) ) &&
493
- expression . Type != operandType )
494
- {
495
- return _hqlTreeBuilder . Cast ( VisitExpression ( expression . Operand ) . AsExpression ( ) , expression . Type ) ;
496
- }
497
-
498
- return VisitExpression ( expression . Operand ) ;
486
+ return IsCastRequired ( expression . Operand , expression . Type )
487
+ ? _hqlTreeBuilder . Cast ( VisitExpression ( expression . Operand ) . AsExpression ( ) , expression . Type )
488
+ : VisitExpression ( expression . Operand ) ;
499
489
}
500
490
501
491
throw new NotSupportedException ( expression . ToString ( ) ) ;
@@ -596,5 +586,96 @@ protected HqlTreeNode VisitNewArrayExpression(NewArrayExpression expression)
596
586
var expressionSubTree = expression . Expressions . ToArray ( exp => VisitExpression ( exp ) ) ;
597
587
return _hqlTreeBuilder . ExpressionSubTreeHolder ( expressionSubTree ) ;
598
588
}
589
+
590
+ private bool IsCastRequired ( Expression expression , System . Type toType )
591
+ {
592
+ return toType != typeof ( object ) && IsCastRequired ( GetType ( expression ) , TypeFactory . GetDefaultTypeFor ( toType ) ) ;
593
+ }
594
+
595
+ private bool IsCastRequired ( IType type , IType toType )
596
+ {
597
+ // A type can be null when casting an entity into a base class, in that case we should not cast
598
+ if ( type == null || toType == null || Equals ( type , toType ) )
599
+ {
600
+ return false ;
601
+ }
602
+
603
+ var sqlTypes = type . SqlTypes ( _parameters . SessionFactory ) ;
604
+ var toSqlTypes = toType . SqlTypes ( _parameters . SessionFactory ) ;
605
+ if ( sqlTypes . Length != 1 || toSqlTypes . Length != 1 )
606
+ {
607
+ return false ; // Casting a multi-column type is not possible
608
+ }
609
+
610
+ if ( type . ReturnedClass . IsEnum && sqlTypes [ 0 ] . DbType == DbType . String )
611
+ {
612
+ return false ; // Never cast an enum that is mapped as string, the type will provide a string for the parameter value
613
+ }
614
+
615
+ return sqlTypes [ 0 ] . DbType != toSqlTypes [ 0 ] . DbType ;
616
+ }
617
+
618
+ private bool IsCastRequired ( System . Type type , string sqlFunctionName )
619
+ {
620
+ if ( type == typeof ( object ) )
621
+ {
622
+ return false ;
623
+ }
624
+
625
+ var toType = TypeFactory . GetDefaultTypeFor ( type ) ;
626
+ if ( toType == null )
627
+ {
628
+ return true ; // Fallback to the old behavior
629
+ }
630
+
631
+ var sqlFunction = _parameters . SessionFactory . SQLFunctionRegistry . FindSQLFunction ( sqlFunctionName ) ;
632
+ if ( sqlFunction == null )
633
+ {
634
+ return true ; // Fallback to the old behavior
635
+ }
636
+
637
+ var fnReturnType = sqlFunction . ReturnType ( toType , _parameters . SessionFactory ) ;
638
+ return fnReturnType == null || IsCastRequired ( fnReturnType , toType ) ;
639
+ }
640
+
641
+ private IType GetType ( Expression expression )
642
+ {
643
+ if ( ! ( expression is MemberExpression memberExpression ) )
644
+ {
645
+ return expression . Type != typeof ( object )
646
+ ? TypeFactory . GetDefaultTypeFor ( expression . Type )
647
+ : null ;
648
+ }
649
+
650
+ // Try to get the mapped type for the member as it may be a non default one
651
+ var entityName = TryGetEntityName ( memberExpression ) ;
652
+ if ( entityName == null )
653
+ {
654
+ return TypeFactory . GetDefaultTypeFor ( expression . Type ) ; // Not mapped
655
+ }
656
+
657
+ var persister = _parameters . SessionFactory . GetEntityPersister ( entityName ) ;
658
+ var index = persister . EntityMetamodel . GetPropertyIndexOrNull ( memberExpression . Member . Name ) ;
659
+ return ! index . HasValue
660
+ ? TypeFactory . GetDefaultTypeFor ( expression . Type ) // Not mapped
661
+ : persister . EntityMetamodel . PropertyTypes [ index . Value ] ;
662
+ }
663
+
664
+ private string TryGetEntityName ( MemberExpression memberExpression )
665
+ {
666
+ System . Type entityType ;
667
+ // Try to get the actual entity type from the query source if possbile as member can be declared
668
+ // in a base type
669
+ if ( memberExpression . Expression is QuerySourceReferenceExpression querySourceReferenceExpression )
670
+ {
671
+ entityType = querySourceReferenceExpression . Type ;
672
+ }
673
+ else
674
+ {
675
+ entityType = memberExpression . Member . ReflectedType ;
676
+ }
677
+
678
+ return _parameters . SessionFactory . TryGetGuessEntityName ( entityType ) ;
679
+ }
599
680
}
600
681
}
0 commit comments