Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/NHibernate.Test/Async/Linq/CustomExtensionsExample.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
using System.Reflection;
using System.Text.RegularExpressions;
using NHibernate.Cfg;
using NHibernate.DomainModel.Northwind.Entities;
using NHibernate.Hql.Ast;
using NHibernate.Linq.Functions;
using NHibernate.Linq.Visitors;
Expand All @@ -33,6 +34,14 @@ protected override void Configure(NHibernate.Cfg.Configuration configuration)
configuration.LinqToHqlGeneratorsRegistry<MyLinqToHqlGeneratorsRegistry>();
}

[Test]
public async Task CanUseObjectEqualsAsync()
{
var users = await (db.Users.Where(o => ((object) EnumStoredAsString.Medium).Equals(o.NullableEnum1)).ToListAsync());
Assert.That(users.Count, Is.EqualTo(2));
Assert.That(users.All(c => c.NullableEnum1 == EnumStoredAsString.Medium), Is.True);
}

[Test]
public async Task CanUseMyCustomExtensionAsync()
{
Expand Down
25 changes: 25 additions & 0 deletions src/NHibernate.Test/Linq/CustomExtensionsExample.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Reflection;
using System.Text.RegularExpressions;
using NHibernate.Cfg;
using NHibernate.DomainModel.Northwind.Entities;
using NHibernate.Hql.Ast;
using NHibernate.Linq.Functions;
using NHibernate.Linq.Visitors;
Expand All @@ -30,6 +31,7 @@ public MyLinqToHqlGeneratorsRegistry():base()
{
RegisterGenerator(ReflectHelper.GetMethodDefinition(() => MyLinqExtensions.IsLike(null, null)),
new IsLikeGenerator());
RegisterGenerator(ReflectHelper.GetMethodDefinition(() => new object().Equals(null)), new ObjectEqualsGenerator());
}
}

Expand All @@ -48,6 +50,21 @@ public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject,
}
}

public class ObjectEqualsGenerator : BaseHqlGeneratorForMethod
{
public ObjectEqualsGenerator()
{
SupportedMethods = new[] { ReflectHelper.GetMethodDefinition(() => new object().Equals(null)) };
}

public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject,
ReadOnlyCollection<Expression> arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor)
{
return treeBuilder.Equality(visitor.Visit(targetObject).AsExpression(),
visitor.Visit(arguments[0]).AsExpression());
}
}

[TestFixture]
public class CustomExtensionsExample : LinqTestCase
{
Expand All @@ -56,6 +73,14 @@ protected override void Configure(NHibernate.Cfg.Configuration configuration)
configuration.LinqToHqlGeneratorsRegistry<MyLinqToHqlGeneratorsRegistry>();
}

[Test]
public void CanUseObjectEquals()
{
var users = db.Users.Where(o => ((object) EnumStoredAsString.Medium).Equals(o.NullableEnum1)).ToList();
Assert.That(users.Count, Is.EqualTo(2));
Assert.That(users.All(c => c.NullableEnum1 == EnumStoredAsString.Medium), Is.True);
}

[Test]
public void CanUseMyCustomExtension()
{
Expand Down
24 changes: 24 additions & 0 deletions src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,11 @@ private static IType GetParameterType(
return candidateType;
}

if (visitor.NotGuessableConstants.Contains(constantExpression))
{
return null;
}

// No related MemberExpressions was found, guess the type by value or its type when null.
// When a numeric parameter is compared to different columns with different types (e.g. Where(o => o.Single >= singleParam || o.Double <= singleParam))
// do not change the parameter type, but instead cast the parameter when comparing with different column types.
Expand All @@ -166,10 +171,13 @@ private static IType GetParameterType(

private class ConstantTypeLocatorVisitor : RelinqExpressionVisitor
{
private bool _hqlGenerator;
private readonly bool _removeMappedAsCalls;
private readonly System.Type _targetType;
private readonly IDictionary<ConstantExpression, NamedParameter> _parameters;
private readonly ISessionFactoryImplementor _sessionFactory;
private readonly ILinqToHqlGeneratorsRegistry _functionRegistry;
public readonly HashSet<ConstantExpression> NotGuessableConstants = new HashSet<ConstantExpression>();
public readonly Dictionary<ConstantExpression, IType> ConstantExpressions =
new Dictionary<ConstantExpression, IType>();
public readonly Dictionary<NamedParameter, HashSet<ConstantExpression>> ParameterConstants =
Expand All @@ -187,6 +195,7 @@ public ConstantTypeLocatorVisitor(
_targetType = targetType;
_sessionFactory = sessionFactory;
_parameters = parameters;
_functionRegistry = sessionFactory.Settings.LinqToHqlGeneratorsRegistry;
}

protected override Expression VisitBinary(BinaryExpression node)
Expand Down Expand Up @@ -257,6 +266,16 @@ protected override Expression VisitMethodCall(MethodCallExpression node)
return node;
}

// For hql method generators we do not want to guess the parameter type here, let hql logic figure it out.
if (_functionRegistry.TryGetGenerator(node.Method, out _))
{
var origHqlGenerator = _hqlGenerator;
_hqlGenerator = true;
var expression = base.VisitMethodCall(node);
_hqlGenerator = origHqlGenerator;
return expression;
}

return base.VisitMethodCall(node);
}

Expand All @@ -267,6 +286,11 @@ protected override Expression VisitConstant(ConstantExpression node)
return node;
}

if (_hqlGenerator)
{
NotGuessableConstants.Add(node);
}

RelatedExpressions.Add(node, new HashSet<Expression>());
ConstantExpressions.Add(node, null);
if (!ParameterConstants.TryGetValue(param, out var set))
Expand Down