@@ -115,62 +115,39 @@ internal static void SetParameterTypes(
115
115
}
116
116
}
117
117
118
- private static HashSet < IType > GetCandidateTypes (
118
+ private static IType GetCandidateType (
119
119
ISessionFactoryImplementor sessionFactory ,
120
120
IEnumerable < ConstantExpression > constantExpressions ,
121
121
ConstantTypeLocatorVisitor visitor )
122
122
{
123
- var candidateTypes = new HashSet < IType > ( ) ;
123
+ IType candidateType = null ;
124
124
foreach ( var expression in constantExpressions )
125
125
{
126
126
// In order to get the actual type we have to check first the related member expressions, as
127
127
// an enum is translated in a numeric type when used in a BinaryExpression and also it can be mapped as string.
128
128
// By getting the type from a related member expression we also get the correct length in case of StringType
129
129
// or precision when having a DecimalType.
130
- if ( visitor . RelatedExpressions . TryGetValue ( expression , out var relatedExpressions ) )
130
+ if ( ! visitor . RelatedExpressions . TryGetValue ( expression , out var relatedExpressions ) )
131
+ continue ;
132
+ foreach ( var relatedExpression in relatedExpressions )
131
133
{
132
- foreach ( var relatedExpression in relatedExpressions )
133
- {
134
- if ( ExpressionsHelper . TryGetMappedType ( sessionFactory , relatedExpression , out var candidateType , out _ , out _ , out _ ) )
135
- {
136
- if ( candidateType . IsAssociationType && visitor . SequenceSelectorExpressions . Contains ( relatedExpression ) )
137
- {
138
- var collection = ( IQueryableCollection ) ( ( IAssociationType ) candidateType ) . GetAssociatedJoinable ( sessionFactory ) ;
139
- candidateType = collection . ElementType ;
140
- }
134
+ if ( ! ExpressionsHelper . TryGetMappedType ( sessionFactory , relatedExpression , out var mappedType , out _ , out _ , out _ ) )
135
+ continue ;
141
136
142
- candidateTypes . Add ( candidateType ) ;
143
- }
137
+ if ( mappedType . IsAssociationType && visitor . SequenceSelectorExpressions . Contains ( relatedExpression ) )
138
+ {
139
+ var collection = ( IQueryableCollection ) ( ( IAssociationType ) mappedType ) . GetAssociatedJoinable ( sessionFactory ) ;
140
+ mappedType = collection . ElementType ;
144
141
}
145
- }
146
- }
147
-
148
- return candidateTypes ;
149
- }
150
-
151
- private static bool GetCandidateType (
152
- ISessionFactoryImplementor sessionFactory ,
153
- IEnumerable < ConstantExpression > constantExpressions ,
154
- ConstantTypeLocatorVisitor visitor ,
155
- System . Type constantType ,
156
- out IType candidateType )
157
- {
158
- var candidateTypes = GetCandidateTypes ( sessionFactory , constantExpressions , visitor ) ;
159
- if ( candidateTypes . Count == 1 )
160
- {
161
- candidateType = candidateTypes . First ( ) ;
162
142
163
- // When comparing an integral column with a floating-point parameter, the parameter type must remain floating-point type
164
- // and the column needs to be casted in order to prevent invalid results (e.g. Where(o => o.Integer >= 2.2d)).
165
- if ( ! IntegralNumericTypes . Contains ( candidateType . ReturnedClass ) ||
166
- ! FloatingPointNumericTypes . Contains ( constantType ) )
167
- {
168
- return true ;
143
+ if ( candidateType == null )
144
+ candidateType = mappedType ;
145
+ else if ( ! candidateType . Equals ( mappedType ) )
146
+ return null ;
169
147
}
170
148
}
171
149
172
- candidateType = null ;
173
- return false ;
150
+ return candidateType ;
174
151
}
175
152
176
153
private static IType GetParameterType (
@@ -183,7 +160,11 @@ private static IType GetParameterType(
183
160
// All constant expressions have the same type/value
184
161
var constantExpression = constantExpressions . First ( ) ;
185
162
var constantType = constantExpression . Type . UnwrapIfNullable ( ) ;
186
- if ( GetCandidateType ( sessionFactory , constantExpressions , visitor , constantType , out var candidateType ) )
163
+ var candidateType = GetCandidateType ( sessionFactory , constantExpressions , visitor ) ;
164
+ if ( candidateType != null &&
165
+ // When comparing an integral column with a floating-point parameter, the parameter type must remain floating-point type
166
+ // and the column needs to be casted in order to prevent invalid results (e.g. Where(o => o.Integer >= 2.2d)).
167
+ ! ( FloatingPointNumericTypes . Contains ( constantType ) && IntegralNumericTypes . Contains ( candidateType . ReturnedClass ) ) )
187
168
{
188
169
return candidateType ;
189
170
}
0 commit comments