Skip to content

Commit b40fc9d

Browse files
committed
NH-3929 - Reimplement NH-3904
1 parent a37a839 commit b40fc9d

File tree

4 files changed

+81
-29
lines changed

4 files changed

+81
-29
lines changed

src/NHibernate.Test/NHSpecificTest/EntityWithUserTypeCanHaveLinqGenerators/Fixture.cs

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ protected override string MappingsAssembly
2222

2323
protected override void Configure(Configuration configuration)
2424
{
25-
base.Configure(configuration);
2625
configuration.LinqToHqlGeneratorsRegistry<EntityWithUserTypePropertyGeneratorsRegistry>();
2726
}
2827

@@ -81,11 +80,11 @@ public void EqualityWorksForUserType()
8180
.Where(x => x.Example == newItem)
8281
.ToList();
8382

84-
Assert.AreEqual(1, entities.Count);
83+
Assert.That(entities.Count, Is.EqualTo(1));
8584
}
8685
}
8786

88-
[Test, Ignore("Not implemented yet")]
87+
[Test]
8988
public void LinqMethodWorksForUserType()
9089
{
9190
using (var session = OpenSession())
@@ -96,7 +95,7 @@ public void LinqMethodWorksForUserType()
9695
.Where(x => x.Example.IsEquivalentTo(newItem))
9796
.ToList();
9897

99-
Assert.AreEqual(2, entities.Count);
98+
Assert.That(entities.Count, Is.EqualTo(2));
10099
}
101100
}
102101

@@ -111,7 +110,7 @@ public void EqualityWorksForExplicitUserType()
111110
.Where(x => x.Example == newItem.MappedAs(NHibernateUtil.Custom(typeof(ExampleUserType))))
112111
.ToList();
113112

114-
Assert.AreEqual(1, entities.Count);
113+
Assert.That(entities.Count, Is.EqualTo(1));
115114
}
116115
}
117116

@@ -126,7 +125,7 @@ public void LinqMethodWorksForExplicitUserType()
126125
.Where(x => x.Example.IsEquivalentTo(newItem.MappedAs(NHibernateUtil.Custom(typeof(ExampleUserType)))))
127126
.ToList();
128127

129-
Assert.AreEqual(2, entities.Count);
128+
Assert.That(entities.Count, Is.EqualTo(2));
130129
}
131130
}
132131

@@ -140,7 +139,7 @@ public void LinqMethodWorksForStandardStringProperty()
140139
.Where(x => x.Name == "Bob")
141140
.ToList();
142141

143-
Assert.AreEqual(1, entities.Count);
142+
Assert.That(entities.Count, Is.EqualTo(1));
144143
}
145144
}
146145

@@ -155,7 +154,7 @@ public void CanQueryWithHql()
155154
q.SetParameter("exampleItem", newItem);
156155
var entities = q.List<EntityWithUserTypeProperty>();
157156

158-
Assert.AreEqual(1, entities.Count);
157+
Assert.That(entities.Count, Is.EqualTo(1));
159158
}
160159
}
161160
}

src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
using System.Linq.Expressions;
55
using System.Reflection;
66
using NHibernate.Engine;
7+
using NHibernate.Metadata;
78
using NHibernate.Param;
89
using NHibernate.Type;
10+
using NHibernate.UserTypes;
911
using Remotion.Linq.Parsing;
1012

1113
namespace NHibernate.Linq.Visitors
@@ -26,9 +28,25 @@ public class ExpressionParameterVisitor : ExpressionTreeVisitor
2628
ReflectionHelper.GetMethodDefinition(() => Enumerable.Take<object>(null, 0)),
2729
};
2830

31+
readonly List<IType> _allMappedCustomTypes;
32+
2933
public ExpressionParameterVisitor(ISessionFactoryImplementor sessionFactory)
3034
{
3135
_sessionFactory = sessionFactory;
36+
_allMappedCustomTypes = GetAllPropertyTypes(sessionFactory)
37+
.Where(t => t is CustomType || t is CompositeCustomType)
38+
.Distinct()
39+
.ToList();
40+
}
41+
42+
static IEnumerable<IType> GetAllPropertyTypes(ISessionFactory sessionFactory)
43+
{
44+
foreach (var c in sessionFactory.GetAllClassMetadata().Values)
45+
{
46+
yield return c.IdentifierType;
47+
foreach (var propertyType in c.PropertyTypes)
48+
yield return propertyType;
49+
}
3250
}
3351

3452
public static IDictionary<ConstantExpression, NamedParameter> Visit(Expression expression, ISessionFactoryImplementor sessionFactory)
@@ -87,12 +105,8 @@ protected override Expression VisitConstantExpression(ConstantExpression express
87105
{
88106
// We use null for the type to indicate that the caller should let HQL figure it out.
89107
object value = expression.Value;
90-
IType type = null;
91108

92-
// We have a bit more information about the null parameter value.
93-
// Figure out a type so that HQL doesn't break on the null. (Related to NH-2430)
94-
if (expression.Value == null)
95-
type = NHibernateUtil.GuessType(expression.Type);
109+
var type = GuessType(expression.Type);
96110

97111
// Constant characters should be sent as strings
98112
if (expression.Type == typeof(char))
@@ -111,6 +125,39 @@ protected override Expression VisitConstantExpression(ConstantExpression express
111125
return base.VisitConstantExpression(expression);
112126
}
113127

128+
/// <summary>
129+
/// Guesses the <see cref="IType"/> from the <see cref="System.Type"/>.
130+
/// </summary>
131+
/// <param name="clazz">The <see cref="System.Type"/> to guess the <see cref="IType"/> of.</param>
132+
/// <returns>An <see cref="IType"/> for the <see cref="System.Type"/>.</returns>
133+
/// <exception cref="ArgumentNullException">
134+
/// Thrown when the <c>clazz</c> is null because the <see cref="IType"/>
135+
/// can't be guess from a null type.
136+
/// </exception>
137+
IType GuessType(System.Type clazz)
138+
{
139+
if (clazz == null)
140+
throw new ArgumentNullException("clazz", "The IType can not be guessed for a null value.");
141+
142+
var typename = clazz.AssemblyQualifiedName;
143+
var heuristicType = TypeFactory.HeuristicType(typename);
144+
var serializable = heuristicType is SerializableType;
145+
if (heuristicType != null && !serializable)
146+
return heuristicType;
147+
148+
if (_sessionFactory.TryGetEntityPersister(clazz.FullName) != null)
149+
return NHibernateUtil.Entity(clazz);
150+
151+
var customType = _allMappedCustomTypes.Find(x => x.ReturnedClass.IsAssignableFrom(clazz));
152+
if (customType != null)
153+
return customType;
154+
155+
if (serializable)
156+
return heuristicType;
157+
158+
return null;
159+
}
160+
114161
private static bool IsNullObject(ConstantExpression expression)
115162
{
116163
return expression.Type == typeof(Object) && expression.Value == null;

src/NHibernate/Type/CompositeCustomType.cs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
namespace NHibernate.Type
1414
{
1515
[Serializable]
16-
public class CompositeCustomType : AbstractType, IAbstractComponentType
16+
public class CompositeCustomType : AbstractType, IAbstractComponentType, IEquatable<CompositeCustomType>
1717
{
1818
private readonly ICompositeUserType userType;
1919
private readonly string name;
@@ -206,14 +206,17 @@ public override string ToLoggableString(object value, ISessionFactoryImplementor
206206
return value == null ? "null" : value.ToString();
207207
}
208208

209-
public override bool Equals(object obj)
209+
public bool Equals(CompositeCustomType other)
210210
{
211-
if (!base.Equals(obj))
212-
{
213-
return false;
214-
}
211+
if (ReferenceEquals(null, other)) return false;
212+
if (ReferenceEquals(this, other)) return true;
215213

216-
return ((CompositeCustomType) obj).userType.GetType() == userType.GetType();
214+
return other.userType.GetType() == userType.GetType();
215+
}
216+
217+
public override bool Equals(object obj)
218+
{
219+
return Equals(obj as CompositeCustomType);
217220
}
218221

219222
public override int GetHashCode()
@@ -274,4 +277,4 @@ public override bool[] ToColumnNullness(object value, IMapping mapping)
274277
}
275278

276279
}
277-
}
280+
}

src/NHibernate/Type/CustomType.cs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ namespace NHibernate.Type
1717
/// <seealso cref="IUserType"/>
1818
/// </summary>
1919
[Serializable]
20-
public class CustomType : AbstractType, IDiscriminatorType, IVersionType
20+
public class CustomType : AbstractType, IDiscriminatorType, IVersionType, IEquatable<CustomType>
2121
{
2222
private readonly IUserType userType;
2323
private readonly string name;
@@ -152,14 +152,17 @@ public override bool IsMutable
152152
get { return userType.IsMutable; }
153153
}
154154

155-
public override bool Equals(object obj)
155+
public bool Equals(CustomType obj)
156156
{
157-
if (!base.Equals(obj))
158-
{
159-
return false;
160-
}
157+
if (ReferenceEquals(null, obj)) return false;
158+
if (ReferenceEquals(this, obj)) return true;
161159

162-
return ((CustomType) obj).userType.GetType() == userType.GetType();
160+
return obj.userType.GetType() == userType.GetType();
161+
}
162+
163+
public override bool Equals(object obj)
164+
{
165+
return Equals(obj as CustomType);
163166
}
164167

165168
public override int GetHashCode()
@@ -271,4 +274,4 @@ public virtual string ToXMLString(object value, ISessionFactoryImplementor facto
271274
}
272275
}
273276
}
274-
}
277+
}

0 commit comments

Comments
 (0)