Skip to content

Commit 9d0bbb5

Browse files
committed
Avoid HashSet creation for candidate type calculation
1 parent 371b8df commit 9d0bbb5

File tree

1 file changed

+21
-40
lines changed

1 file changed

+21
-40
lines changed

src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs

Lines changed: 21 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -115,62 +115,39 @@ internal static void SetParameterTypes(
115115
}
116116
}
117117

118-
private static HashSet<IType> GetCandidateTypes(
118+
private static IType GetCandidateType(
119119
ISessionFactoryImplementor sessionFactory,
120120
IEnumerable<ConstantExpression> constantExpressions,
121121
ConstantTypeLocatorVisitor visitor)
122122
{
123-
var candidateTypes = new HashSet<IType>();
123+
IType candidateType = null;
124124
foreach (var expression in constantExpressions)
125125
{
126126
// In order to get the actual type we have to check first the related member expressions, as
127127
// an enum is translated in a numeric type when used in a BinaryExpression and also it can be mapped as string.
128128
// By getting the type from a related member expression we also get the correct length in case of StringType
129129
// or precision when having a DecimalType.
130-
if (visitor.RelatedExpressions.TryGetValue(expression, out var relatedExpressions))
130+
if (!visitor.RelatedExpressions.TryGetValue(expression, out var relatedExpressions))
131+
continue;
132+
foreach (var relatedExpression in relatedExpressions)
131133
{
132-
foreach (var relatedExpression in relatedExpressions)
133-
{
134-
if (ExpressionsHelper.TryGetMappedType(sessionFactory, relatedExpression, out var candidateType, out _, out _, out _))
135-
{
136-
if (candidateType.IsAssociationType && visitor.SequenceSelectorExpressions.Contains(relatedExpression))
137-
{
138-
var collection = (IQueryableCollection) ((IAssociationType) candidateType).GetAssociatedJoinable(sessionFactory);
139-
candidateType = collection.ElementType;
140-
}
134+
if (!ExpressionsHelper.TryGetMappedType(sessionFactory, relatedExpression, out var mappedType, out _, out _, out _))
135+
continue;
141136

142-
candidateTypes.Add(candidateType);
143-
}
137+
if (mappedType.IsAssociationType && visitor.SequenceSelectorExpressions.Contains(relatedExpression))
138+
{
139+
var collection = (IQueryableCollection) ((IAssociationType) mappedType).GetAssociatedJoinable(sessionFactory);
140+
mappedType = collection.ElementType;
144141
}
145-
}
146-
}
147-
148-
return candidateTypes;
149-
}
150-
151-
private static bool GetCandidateType(
152-
ISessionFactoryImplementor sessionFactory,
153-
IEnumerable<ConstantExpression> constantExpressions,
154-
ConstantTypeLocatorVisitor visitor,
155-
System.Type constantType,
156-
out IType candidateType)
157-
{
158-
var candidateTypes = GetCandidateTypes(sessionFactory, constantExpressions, visitor);
159-
if (candidateTypes.Count == 1)
160-
{
161-
candidateType = candidateTypes.First();
162142

163-
// When comparing an integral column with a floating-point parameter, the parameter type must remain floating-point type
164-
// and the column needs to be casted in order to prevent invalid results (e.g. Where(o => o.Integer >= 2.2d)).
165-
if (!IntegralNumericTypes.Contains(candidateType.ReturnedClass) ||
166-
!FloatingPointNumericTypes.Contains(constantType))
167-
{
168-
return true;
143+
if (candidateType == null)
144+
candidateType = mappedType;
145+
else if (!candidateType.Equals(mappedType))
146+
return null;
169147
}
170148
}
171149

172-
candidateType = null;
173-
return false;
150+
return candidateType;
174151
}
175152

176153
private static IType GetParameterType(
@@ -183,7 +160,11 @@ private static IType GetParameterType(
183160
// All constant expressions have the same type/value
184161
var constantExpression = constantExpressions.First();
185162
var constantType = constantExpression.Type.UnwrapIfNullable();
186-
if (GetCandidateType(sessionFactory, constantExpressions, visitor, constantType, out var candidateType))
163+
var candidateType = GetCandidateType(sessionFactory, constantExpressions, visitor);
164+
if (candidateType != null &&
165+
// When comparing an integral column with a floating-point parameter, the parameter type must remain floating-point type
166+
// and the column needs to be casted in order to prevent invalid results (e.g. Where(o => o.Integer >= 2.2d)).
167+
!(FloatingPointNumericTypes.Contains(constantType) && IntegralNumericTypes.Contains(candidateType.ReturnedClass)))
187168
{
188169
return candidateType;
189170
}

0 commit comments

Comments
 (0)