Skip to content

Commit 6f32861

Browse files
committed
Fix TryGetEntityName for custom entity names
1 parent 4fe3d7f commit 6f32861

File tree

4 files changed

+372
-80
lines changed

4 files changed

+372
-80
lines changed

src/NHibernate/Linq/Functions/ListIndexerGenerator.cs

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,30 @@ namespace NHibernate.Linq.Functions
1212
{
1313
internal class ListIndexerGenerator : BaseHqlGeneratorForMethod,IRuntimeMethodHqlGenerator
1414
{
15+
private static readonly HashSet<MethodInfo> _supportedMethods = new HashSet<MethodInfo>
16+
{
17+
ReflectHelper.GetMethodDefinition(() => Enumerable.ElementAt<object>(null, 0)),
18+
ReflectHelper.GetMethodDefinition(() => Queryable.ElementAt<object>(null, 0))
19+
};
20+
1521
public ListIndexerGenerator()
1622
{
17-
SupportedMethods = new[]
18-
{
19-
ReflectHelper.GetMethodDefinition(() => Enumerable.ElementAt<object>(null, 0)),
20-
ReflectHelper.GetMethodDefinition(() => Queryable.ElementAt<object>(null, 0))
21-
};
23+
SupportedMethods = _supportedMethods;
2224
}
2325

2426
public bool SupportsMethod(MethodInfo method)
2527
{
26-
return method != null &&
27-
method.Name == "get_Item" &&
28-
(method.IsMethodOf(typeof(IList)) || method.IsMethodOf(typeof(IList<>)));
28+
return IsRuntimeMethodSupported(method);
29+
}
30+
31+
public static bool IsMethodSupported(MethodInfo method)
32+
{
33+
if (method.IsGenericMethod)
34+
{
35+
method = method.GetGenericMethodDefinition();
36+
}
37+
38+
return _supportedMethods.Contains(method) || IsRuntimeMethodSupported(method);
2939
}
3040

3141
public IHqlGeneratorForMethod GetMethodGenerator(MethodInfo method)
@@ -40,5 +50,12 @@ public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject,
4050

4151
return treeBuilder.Index(collection, index);
4252
}
53+
54+
private static bool IsRuntimeMethodSupported(MethodInfo method)
55+
{
56+
return method != null &&
57+
method.Name == "get_Item" &&
58+
(method.IsMethodOf(typeof(IList)) || method.IsMethodOf(typeof(IList<>)));
59+
}
4360
}
44-
}
61+
}

src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs

Lines changed: 18 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,9 @@ constant.Value is CallSite site &&
240240

241241
protected HqlTreeNode VisitNhAverage(NhAverageExpression expression)
242242
{
243+
// We need to cast the argument when its type is different from Average method return type,
244+
// otherwise the result may be incorrect. In SQL Server avg always returns int
245+
// when the argument is int.
243246
var hqlExpression = VisitExpression(expression.Expression).AsExpression();
244247
hqlExpression = IsCastRequired(expression.Expression, expression.Type, out _)
245248
? (HqlExpression) _hqlTreeBuilder.Cast(hqlExpression, expression.Type)
@@ -267,7 +270,7 @@ protected HqlTreeNode VisitNhMax(NhMaxExpression expression)
267270

268271
protected HqlTreeNode VisitNhSum(NhSumExpression expression)
269272
{
270-
return IsCastRequired(expression.Type, "sum", out _)
273+
return IsCastRequired("sum", expression.Expression, expression.Type)
271274
? (HqlTreeNode) _hqlTreeBuilder.Cast(_hqlTreeBuilder.Sum(VisitExpression(expression.Expression).AsExpression()), expression.Type)
272275
: _hqlTreeBuilder.TransparentCast(_hqlTreeBuilder.Sum(VisitExpression(expression.Expression).AsExpression()), expression.Type);
273276
}
@@ -593,7 +596,8 @@ protected HqlTreeNode VisitNewArrayExpression(NewArrayExpression expression)
593596
private bool IsCastRequired(Expression expression, System.Type toType, out bool existType)
594597
{
595598
existType = false;
596-
return toType != typeof(object) && IsCastRequired(GetType(expression), TypeFactory.GetDefaultTypeFor(toType), out existType);
599+
return toType != typeof(object) &&
600+
IsCastRequired(GetType(expression), TypeFactory.GetDefaultTypeFor(toType), out existType);
597601
}
598602

599603
private bool IsCastRequired(IType type, IType toType, out bool existType)
@@ -635,59 +639,38 @@ private bool IsCastRequired(IType type, IType toType, out bool existType)
635639
return castTypeName != toCastTypeName;
636640
}
637641

638-
private bool IsCastRequired(System.Type type, string sqlFunctionName, out bool existType)
642+
private bool IsCastRequired(string sqlFunctionName, Expression argumentExpression, System.Type returnType)
639643
{
640-
if (type == typeof(object))
644+
var argumentType = GetType(argumentExpression);
645+
if (argumentType == null || returnType == typeof(object))
641646
{
642-
existType = false;
643647
return false;
644648
}
645649

646-
var toType = TypeFactory.GetDefaultTypeFor(type);
647-
if (toType == null)
650+
var returnNhType = TypeFactory.GetDefaultTypeFor(returnType);
651+
if (returnNhType == null)
648652
{
649-
existType = false;
650653
return true; // Fallback to the old behavior
651654
}
652655

653-
existType = true;
654656
var sqlFunction = _parameters.SessionFactory.SQLFunctionRegistry.FindSQLFunction(sqlFunctionName);
655657
if (sqlFunction == null)
656658
{
657659
return true; // Fallback to the old behavior
658660
}
659661

660-
var fnReturnType = sqlFunction.ReturnType(toType, _parameters.SessionFactory);
661-
return fnReturnType == null || IsCastRequired(fnReturnType, toType, out existType);
662+
var fnReturnType = sqlFunction.ReturnType(argumentType, _parameters.SessionFactory);
663+
return fnReturnType == null || IsCastRequired(fnReturnType, returnNhType, out _);
662664
}
663665

664666
private IType GetType(Expression expression)
665667
{
666-
if (!(expression is MemberExpression memberExpression))
667-
{
668-
return expression.Type != typeof(object)
669-
? TypeFactory.GetDefaultTypeFor(expression.Type)
670-
: null;
671-
}
672-
673668
// Try to get the mapped type for the member as it may be a non default one
674-
var entityName = ExpressionsHelper.TryGetEntityName(_parameters.SessionFactory, memberExpression, out var memberPath);
675-
if (entityName == null)
676-
{
677-
return TypeFactory.GetDefaultTypeFor(expression.Type); // Not mapped
678-
}
679-
680-
var persister = _parameters.SessionFactory.GetEntityPersister(entityName);
681-
var type = persister.EntityMetamodel.GetIdentifierPropertyType(memberPath);
682-
if (type != null)
683-
{
684-
return type;
685-
}
686-
687-
var index = persister.EntityMetamodel.GetPropertyIndexOrNull(memberPath);
688-
return !index.HasValue
689-
? TypeFactory.GetDefaultTypeFor(expression.Type) // Not mapped
690-
: persister.EntityMetamodel.PropertyTypes[index.Value];
669+
ExpressionsHelper.TryGetEntityName(_parameters.SessionFactory, expression, out _, out var type);
670+
return type ??
671+
(expression.Type != typeof(object)
672+
? TypeFactory.GetDefaultTypeFor(expression.Type)
673+
: null);
691674
}
692675
}
693676
}

src/NHibernate/Tuple/Entity/EntityMetamodel.cs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ public class EntityMetamodel
5151

5252
private readonly Dictionary<string, int?> propertyIndexes = new Dictionary<string, int?>();
5353
private readonly IDictionary<string, IType> _identifierPropertyTypes = new Dictionary<string, IType>();
54+
private readonly IDictionary<string, IType> _propertyTypes = new Dictionary<string, IType>();
5455
private readonly bool hasCollections;
5556
private readonly bool hasMutableProperties;
5657
private readonly bool hasLazyProperties;
@@ -416,15 +417,17 @@ private void MapPropertyToIndex(Mapping.Property prop, int i)
416417

417418
private void MapPropertyToIndex(string path, Mapping.Property prop, int i)
418419
{
419-
propertyIndexes[!string.IsNullOrEmpty(path) ? $"{path}.{prop.Name}" : prop.Name] = i;
420+
var propPath = !string.IsNullOrEmpty(path) ? $"{path}.{prop.Name}" : prop.Name;
421+
propertyIndexes[propPath] = i;
422+
_propertyTypes[propPath] = prop.Type;
420423
if (!(prop.Value is Mapping.Component comp))
421424
{
422425
return;
423426
}
424427

425428
foreach (var subprop in comp.PropertyIterator)
426429
{
427-
MapPropertyToIndex(!string.IsNullOrEmpty(path) ? $"{path}.{prop.Name}" : prop.Name, subprop, i);
430+
MapPropertyToIndex(propPath, subprop, i);
428431
}
429432
}
430433

@@ -570,6 +573,13 @@ internal IType GetIdentifierPropertyType(string memberPath)
570573
return _identifierPropertyTypes.TryGetValue(memberPath, out var propertyType) ? propertyType : null;
571574
}
572575

576+
internal IType GetPropertyType(string memberPath)
577+
{
578+
return _propertyTypes.TryGetValue(memberPath, out var propertyType)
579+
? propertyType
580+
: GetIdentifierPropertyType(memberPath);
581+
}
582+
573583
public bool HasCollections
574584
{
575585
get { return hasCollections; }

0 commit comments

Comments
 (0)