Skip to content

Commit 6d76720

Browse files
committed
Extract GetParameterType and GetCandidateTypes methods
1 parent 0858dce commit 6d76720

File tree

1 file changed

+64
-43
lines changed

1 file changed

+64
-43
lines changed

src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs

Lines changed: 64 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -112,60 +112,81 @@ internal static void SetParameterTypes(
112112
continue;
113113
}
114114

115-
var parameterRelatedExpressions = new List<Expression>();
116-
foreach (var expression in constantExpressions)
115+
namedParameter.Type = GetParameterType(sessionFactory, constantExpressions, visitor, namedParameter);
116+
}
117+
}
118+
119+
private static HashSet<IType> GetCandidateTypes(
120+
ISessionFactoryImplementor sessionFactory,
121+
IEnumerable<ConstantExpression> constantExpressions,
122+
ConstantTypeLocatorVisitor visitor)
123+
{
124+
var parameterRelatedExpressions = new List<Expression>();
125+
foreach (var expression in constantExpressions)
126+
{
127+
if (visitor.RelatedExpressions.TryGetValue(expression, out var relatedExpressions))
117128
{
118-
if (visitor.RelatedExpressions.TryGetValue(expression, out var relatedExpressions))
119-
{
120-
parameterRelatedExpressions.AddRange(relatedExpressions);
121-
}
129+
parameterRelatedExpressions.AddRange(relatedExpressions);
122130
}
131+
}
123132

124-
var candidateTypes = new HashSet<IType>();
125-
// In order to get the actual type we have to check first the related member expressions, as
126-
// an enum is translated in a numeric type when used in a BinaryExpression and also it can be mapped as string.
127-
// By getting the type from a related member expression we also get the correct length in case of StringType
128-
// or precision when having a DecimalType.
129-
foreach (var relatedExpression in parameterRelatedExpressions)
133+
var candidateTypes = new HashSet<IType>();
134+
// In order to get the actual type we have to check first the related member expressions, as
135+
// an enum is translated in a numeric type when used in a BinaryExpression and also it can be mapped as string.
136+
// By getting the type from a related member expression we also get the correct length in case of StringType
137+
// or precision when having a DecimalType.
138+
foreach (var relatedExpression in parameterRelatedExpressions)
139+
{
140+
if (TryGetMappedType(sessionFactory, relatedExpression, out var candidateType, out _, out _, out _))
130141
{
131-
if (TryGetMappedType(sessionFactory, relatedExpression, out var candidateType, out _, out _, out _))
142+
if (candidateType.IsAssociationType && visitor.SequenceSelectorExpressions.Contains(relatedExpression))
132143
{
133-
if (candidateType.IsAssociationType && visitor.SequenceSelectorExpressions.Contains(relatedExpression))
134-
{
135-
var collection = (IQueryableCollection) ((IAssociationType) candidateType).GetAssociatedJoinable(sessionFactory);
136-
candidateType = collection.ElementType;
137-
}
138-
139-
candidateTypes.Add(candidateType);
144+
var collection =
145+
(IQueryableCollection) ((IAssociationType) candidateType).GetAssociatedJoinable(sessionFactory);
146+
candidateType = collection.ElementType;
140147
}
141-
}
142148

143-
// All constant expressions have the same type/value
144-
var constantExpression = constantExpressions.First();
145-
var constantType = constantExpression.Type.UnwrapIfNullable();
146-
IType type = null;
147-
if (
148-
candidateTypes.Count == 1 &&
149-
// When comparing an integral column with a floating-point parameter, the parameter type must remain floating-point type
150-
// and the column needs to be casted in order to prevent invalid results (e.g. Where(o => o.Integer >= 2.2d)).
151-
!(candidateTypes.Any(t => IntegralNumericTypes.Contains(t.ReturnedClass)) && FloatingPointNumericTypes.Contains(constantType))
152-
)
153-
{
154-
type = candidateTypes.FirstOrDefault();
149+
candidateTypes.Add(candidateType);
155150
}
151+
}
156152

157-
// No related MemberExpressions was found, guess the type by value or its type when null.
158-
// When a numeric parameter is compared to different columns with different types (e.g. Where(o => o.Single >= singleParam || o.Double <= singleParam))
159-
// do not change the parameter type, but instead cast the parameter when comparing with different column types.
160-
if (type == null)
161-
{
162-
type = constantExpression.Value != null
163-
? ParameterHelper.TryGuessType(constantExpression.Value, sessionFactory, namedParameter.IsCollection)
164-
: ParameterHelper.TryGuessType(constantType, sessionFactory, namedParameter.IsCollection);
165-
}
153+
return candidateTypes;
154+
}
166155

167-
namedParameter.Type = type;
156+
private static IType GetParameterType(
157+
ISessionFactoryImplementor sessionFactory,
158+
HashSet<ConstantExpression> constantExpressions,
159+
ConstantTypeLocatorVisitor visitor,
160+
NamedParameter namedParameter)
161+
{
162+
var candidateTypes = GetCandidateTypes(sessionFactory, constantExpressions, visitor);
163+
164+
// All constant expressions have the same type/value
165+
var constantExpression = constantExpressions.First();
166+
var constantType = constantExpression.Type.UnwrapIfNullable();
167+
IType type = null;
168+
if (
169+
candidateTypes.Count == 1 &&
170+
// When comparing an integral column with a floating-point parameter, the parameter type must remain floating-point type
171+
// and the column needs to be casted in order to prevent invalid results (e.g. Where(o => o.Integer >= 2.2d)).
172+
!(candidateTypes.Any(t => IntegralNumericTypes.Contains(t.ReturnedClass)) &&
173+
FloatingPointNumericTypes.Contains(constantType))
174+
)
175+
{
176+
type = candidateTypes.FirstOrDefault();
177+
}
178+
179+
// No related MemberExpressions was found, guess the type by value or its type when null.
180+
// When a numeric parameter is compared to different columns with different types (e.g. Where(o => o.Single >= singleParam || o.Double <= singleParam))
181+
// do not change the parameter type, but instead cast the parameter when comparing with different column types.
182+
if (type == null)
183+
{
184+
type = constantExpression.Value != null
185+
? ParameterHelper.TryGuessType(constantExpression.Value, sessionFactory, namedParameter.IsCollection)
186+
: ParameterHelper.TryGuessType(constantType, sessionFactory, namedParameter.IsCollection);
168187
}
188+
189+
return type;
169190
}
170191

171192
private class ConstantTypeLocatorVisitor : RelinqExpressionVisitor

0 commit comments

Comments
 (0)