Skip to content

Commit 4228d69

Browse files
committed
Fix TryGetEntityName for custom entity names
1 parent 369bb31 commit 4228d69

File tree

4 files changed

+372
-80
lines changed

4 files changed

+372
-80
lines changed

src/NHibernate/Linq/Functions/ListIndexerGenerator.cs

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,30 @@ namespace NHibernate.Linq.Functions
1212
{
1313
internal class ListIndexerGenerator : BaseHqlGeneratorForMethod,IRuntimeMethodHqlGenerator
1414
{
15+
private static readonly HashSet<MethodInfo> _supportedMethods = new HashSet<MethodInfo>
16+
{
17+
ReflectHelper.GetMethodDefinition(() => Enumerable.ElementAt<object>(null, 0)),
18+
ReflectHelper.GetMethodDefinition(() => Queryable.ElementAt<object>(null, 0))
19+
};
20+
1521
public ListIndexerGenerator()
1622
{
17-
SupportedMethods = new[]
18-
{
19-
ReflectHelper.GetMethodDefinition(() => Enumerable.ElementAt<object>(null, 0)),
20-
ReflectHelper.GetMethodDefinition(() => Queryable.ElementAt<object>(null, 0))
21-
};
23+
SupportedMethods = _supportedMethods;
2224
}
2325

2426
public bool SupportsMethod(MethodInfo method)
2527
{
26-
return method != null &&
27-
method.Name == "get_Item" &&
28-
(method.IsMethodOf(typeof(IList)) || method.IsMethodOf(typeof(IList<>)));
28+
return IsRuntimeMethodSupported(method);
29+
}
30+
31+
public static bool IsMethodSupported(MethodInfo method)
32+
{
33+
if (method.IsGenericMethod)
34+
{
35+
method = method.GetGenericMethodDefinition();
36+
}
37+
38+
return _supportedMethods.Contains(method) || IsRuntimeMethodSupported(method);
2939
}
3040

3141
public IHqlGeneratorForMethod GetMethodGenerator(MethodInfo method)
@@ -40,5 +50,12 @@ public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject,
4050

4151
return treeBuilder.Index(collection, index);
4252
}
53+
54+
private static bool IsRuntimeMethodSupported(MethodInfo method)
55+
{
56+
return method != null &&
57+
method.Name == "get_Item" &&
58+
(method.IsMethodOf(typeof(IList)) || method.IsMethodOf(typeof(IList<>)));
59+
}
4360
}
44-
}
61+
}

src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs

Lines changed: 18 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,9 @@ constant.Value is CallSite site &&
242242

243243
protected HqlTreeNode VisitNhAverage(NhAverageExpression expression)
244244
{
245+
// We need to cast the argument when its type is different from Average method return type,
246+
// otherwise the result may be incorrect. In SQL Server avg always returns int
247+
// when the argument is int.
245248
var hqlExpression = VisitExpression(expression.Expression).AsExpression();
246249
hqlExpression = IsCastRequired(expression.Expression, expression.Type, out _)
247250
? (HqlExpression) _hqlTreeBuilder.Cast(hqlExpression, expression.Type)
@@ -269,7 +272,7 @@ protected HqlTreeNode VisitNhMax(NhMaxExpression expression)
269272

270273
protected HqlTreeNode VisitNhSum(NhSumExpression expression)
271274
{
272-
return IsCastRequired(expression.Type, "sum", out _)
275+
return IsCastRequired("sum", expression.Expression, expression.Type)
273276
? (HqlTreeNode) _hqlTreeBuilder.Cast(_hqlTreeBuilder.Sum(VisitExpression(expression.Expression).AsExpression()), expression.Type)
274277
: _hqlTreeBuilder.TransparentCast(_hqlTreeBuilder.Sum(VisitExpression(expression.Expression).AsExpression()), expression.Type);
275278
}
@@ -595,7 +598,8 @@ protected HqlTreeNode VisitNewArrayExpression(NewArrayExpression expression)
595598
private bool IsCastRequired(Expression expression, System.Type toType, out bool existType)
596599
{
597600
existType = false;
598-
return toType != typeof(object) && IsCastRequired(GetType(expression), TypeFactory.GetDefaultTypeFor(toType), out existType);
601+
return toType != typeof(object) &&
602+
IsCastRequired(GetType(expression), TypeFactory.GetDefaultTypeFor(toType), out existType);
599603
}
600604

601605
private bool IsCastRequired(IType type, IType toType, out bool existType)
@@ -637,59 +641,38 @@ private bool IsCastRequired(IType type, IType toType, out bool existType)
637641
return castTypeName != toCastTypeName;
638642
}
639643

640-
private bool IsCastRequired(System.Type type, string sqlFunctionName, out bool existType)
644+
private bool IsCastRequired(string sqlFunctionName, Expression argumentExpression, System.Type returnType)
641645
{
642-
if (type == typeof(object))
646+
var argumentType = GetType(argumentExpression);
647+
if (argumentType == null || returnType == typeof(object))
643648
{
644-
existType = false;
645649
return false;
646650
}
647651

648-
var toType = TypeFactory.GetDefaultTypeFor(type);
649-
if (toType == null)
652+
var returnNhType = TypeFactory.GetDefaultTypeFor(returnType);
653+
if (returnNhType == null)
650654
{
651-
existType = false;
652655
return true; // Fallback to the old behavior
653656
}
654657

655-
existType = true;
656658
var sqlFunction = _parameters.SessionFactory.SQLFunctionRegistry.FindSQLFunction(sqlFunctionName);
657659
if (sqlFunction == null)
658660
{
659661
return true; // Fallback to the old behavior
660662
}
661663

662-
var fnReturnType = sqlFunction.ReturnType(toType, _parameters.SessionFactory);
663-
return fnReturnType == null || IsCastRequired(fnReturnType, toType, out existType);
664+
var fnReturnType = sqlFunction.ReturnType(argumentType, _parameters.SessionFactory);
665+
return fnReturnType == null || IsCastRequired(fnReturnType, returnNhType, out _);
664666
}
665667

666668
private IType GetType(Expression expression)
667669
{
668-
if (!(expression is MemberExpression memberExpression))
669-
{
670-
return expression.Type != typeof(object)
671-
? TypeFactory.GetDefaultTypeFor(expression.Type)
672-
: null;
673-
}
674-
675670
// Try to get the mapped type for the member as it may be a non default one
676-
var entityName = ExpressionsHelper.TryGetEntityName(_parameters.SessionFactory, memberExpression, out var memberPath);
677-
if (entityName == null)
678-
{
679-
return TypeFactory.GetDefaultTypeFor(expression.Type); // Not mapped
680-
}
681-
682-
var persister = _parameters.SessionFactory.GetEntityPersister(entityName);
683-
var type = persister.EntityMetamodel.GetIdentifierPropertyType(memberPath);
684-
if (type != null)
685-
{
686-
return type;
687-
}
688-
689-
var index = persister.EntityMetamodel.GetPropertyIndexOrNull(memberPath);
690-
return !index.HasValue
691-
? TypeFactory.GetDefaultTypeFor(expression.Type) // Not mapped
692-
: persister.EntityMetamodel.PropertyTypes[index.Value];
671+
ExpressionsHelper.TryGetEntityName(_parameters.SessionFactory, expression, out _, out var type);
672+
return type ??
673+
(expression.Type != typeof(object)
674+
? TypeFactory.GetDefaultTypeFor(expression.Type)
675+
: null);
693676
}
694677
}
695678
}

src/NHibernate/Tuple/Entity/EntityMetamodel.cs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ public class EntityMetamodel
5151

5252
private readonly Dictionary<string, int?> propertyIndexes = new Dictionary<string, int?>();
5353
private readonly IDictionary<string, IType> _identifierPropertyTypes = new Dictionary<string, IType>();
54+
private readonly IDictionary<string, IType> _propertyTypes = new Dictionary<string, IType>();
5455
private readonly bool hasCollections;
5556
private readonly bool hasMutableProperties;
5657
private readonly bool hasLazyProperties;
@@ -418,15 +419,17 @@ private void MapPropertyToIndex(Mapping.Property prop, int i)
418419

419420
private void MapPropertyToIndex(string path, Mapping.Property prop, int i)
420421
{
421-
propertyIndexes[!string.IsNullOrEmpty(path) ? $"{path}.{prop.Name}" : prop.Name] = i;
422+
var propPath = !string.IsNullOrEmpty(path) ? $"{path}.{prop.Name}" : prop.Name;
423+
propertyIndexes[propPath] = i;
424+
_propertyTypes[propPath] = prop.Type;
422425
if (!(prop.Value is Mapping.Component comp))
423426
{
424427
return;
425428
}
426429

427430
foreach (var subprop in comp.PropertyIterator)
428431
{
429-
MapPropertyToIndex(!string.IsNullOrEmpty(path) ? $"{path}.{prop.Name}" : prop.Name, subprop, i);
432+
MapPropertyToIndex(propPath, subprop, i);
430433
}
431434
}
432435

@@ -572,6 +575,13 @@ internal IType GetIdentifierPropertyType(string memberPath)
572575
return _identifierPropertyTypes.TryGetValue(memberPath, out var propertyType) ? propertyType : null;
573576
}
574577

578+
internal IType GetPropertyType(string memberPath)
579+
{
580+
return _propertyTypes.TryGetValue(memberPath, out var propertyType)
581+
? propertyType
582+
: GetIdentifierPropertyType(memberPath);
583+
}
584+
575585
public bool HasCollections
576586
{
577587
get { return hasCollections; }

0 commit comments

Comments
 (0)