Skip to content

Commit 5fcd85b

Browse files
committed
NH-3499 - Allow custom QueryModelVisitorBase to be provided through the session factory
Implements a new parameter on the session factory to provide a IQueryModelRewriterFactory which is used in QueryModelVisitor to allow the query model to be rewritten.
1 parent 612d1b6 commit 5fcd85b

File tree

8 files changed

+151
-2
lines changed

8 files changed

+151
-2
lines changed
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Linq.Expressions;
5+
using System.Text;
6+
using NHibernate.Linq.Visitors;
7+
using NUnit.Framework;
8+
using Remotion.Linq;
9+
using Remotion.Linq.Clauses;
10+
using Remotion.Linq.Parsing;
11+
using SharpTestsEx;
12+
13+
namespace NHibernate.Test.Linq
14+
{
15+
public class CustomQueryModelRewriterTests : LinqTestCase
16+
{
17+
protected override void Configure(Cfg.Configuration configuration)
18+
{
19+
configuration.Properties[Cfg.Environment.QueryModelRewriterFactory] = typeof(QueryModelRewriterFactory).AssemblyQualifiedName;
20+
}
21+
22+
[Test]
23+
public void RewriteNullComparison()
24+
{
25+
// This example shows how to use the query model rewriter to
26+
// make radical changes to the query. In this case, we rewrite
27+
// a null comparison (which would translate into a IS NULL)
28+
// into a comparison to "Thomas Hardy" (which translates to a = "Thomas Hardy").
29+
30+
var contacts = (from c in db.Customers where c.ContactName == null select c).ToList();
31+
contacts.Count.Should().Be.GreaterThan(0);
32+
contacts.Select(customer => customer.ContactName).All(c => c.Satisfy(customer => customer == "Thomas Hardy"));
33+
}
34+
35+
[Serializable]
36+
public class QueryModelRewriterFactory : IQueryModelRewriterFactory
37+
{
38+
public QueryModelVisitorBase CreateVisitor(VisitorParameters parameters)
39+
{
40+
return new CustomVisitor();
41+
}
42+
}
43+
44+
public class CustomVisitor : QueryModelVisitorBase
45+
{
46+
public override void VisitWhereClause(WhereClause whereClause, QueryModel queryModel, int index)
47+
{
48+
whereClause.TransformExpressions(new Visitor().VisitExpression);
49+
}
50+
51+
private class Visitor : ExpressionTreeVisitor
52+
{
53+
protected override Expression VisitBinaryExpression(BinaryExpression expression)
54+
{
55+
if (
56+
expression.NodeType == ExpressionType.Equal ||
57+
expression.NodeType == ExpressionType.NotEqual
58+
)
59+
{
60+
var left = expression.Left;
61+
var right = expression.Right;
62+
bool reverse = false;
63+
64+
if (!(left is ConstantExpression) && right is ConstantExpression)
65+
{
66+
var tmp = left;
67+
left = right;
68+
right = tmp;
69+
reverse = true;
70+
}
71+
72+
var constant = left as ConstantExpression;
73+
74+
if (constant != null && constant.Value == null)
75+
{
76+
left = Expression.Constant("Thomas Hardy");
77+
78+
expression = Expression.MakeBinary(
79+
expression.NodeType,
80+
reverse ? right : left,
81+
reverse ? left : right
82+
);
83+
}
84+
}
85+
86+
return base.VisitBinaryExpression(expression);
87+
}
88+
}
89+
}
90+
}
91+
}

src/NHibernate.Test/NHibernate.Test.csproj

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@
8080
<SpecificVersion>False</SpecificVersion>
8181
<HintPath>..\..\lib\net\nunit.framework.dll</HintPath>
8282
</Reference>
83+
<Reference Include="Remotion.Linq, Version=1.13.171.1, Culture=neutral, PublicKeyToken=fee00910d6e5f53b, processorArchitecture=MSIL">
84+
<SpecificVersion>False</SpecificVersion>
85+
<HintPath>..\..\lib\net\Remotion.Linq.dll</HintPath>
86+
</Reference>
8387
<Reference Include="SharpTestsEx.NUnit, Version=1.0.0.0, Culture=neutral, PublicKeyToken=8c60d8070630b1c1, processorArchitecture=MSIL">
8488
<SpecificVersion>False</SpecificVersion>
8589
<HintPath>..\..\lib\net\SharpTestsEx.NUnit.dll</HintPath>
@@ -507,6 +511,7 @@
507511
<Compile Include="Linq\ByMethod\GetValueOrDefaultTests.cs" />
508512
<Compile Include="Linq\CasingTest.cs" />
509513
<Compile Include="Linq\CharComparisonTests.cs" />
514+
<Compile Include="Linq\CustomQueryModelRewriterTests.cs" />
510515
<Compile Include="Linq\DateTimeTests.cs" />
511516
<Compile Include="Linq\LoggingTests.cs" />
512517
<Compile Include="Linq\QueryTimeoutTests.cs" />

src/NHibernate/Cfg/Environment.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,8 @@ public static string Version
171171
/// <summary> Enable ordering of insert statements for the purpose of more effecient batching.</summary>
172172
public const string OrderInserts = "order_inserts";
173173

174+
public const string QueryModelRewriterFactory = "query.query_model_rewriter_factory";
175+
174176
private static readonly Dictionary<string, string> GlobalProperties;
175177

176178
private static IBytecodeProvider BytecodeProviderInstance;

src/NHibernate/Cfg/Settings.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using NHibernate.Exceptions;
88
using NHibernate.Hql;
99
using NHibernate.Linq.Functions;
10+
using NHibernate.Linq.Visitors;
1011
using NHibernate.Transaction;
1112

1213
namespace NHibernate.Cfg
@@ -125,6 +126,8 @@ public Settings()
125126
/// </summary>
126127
public ILinqToHqlGeneratorsRegistry LinqToHqlGeneratorsRegistry { get; internal set; }
127128

129+
public IQueryModelRewriterFactory QueryModelRewriterFactory { get; internal set; }
130+
128131
#endregion
129132
}
130133
}

src/NHibernate/Cfg/SettingsFactory.cs

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
using NHibernate.Exceptions;
1212
using NHibernate.Hql;
1313
using NHibernate.Linq.Functions;
14+
using NHibernate.Linq.Visitors;
1415
using NHibernate.Transaction;
1516
using NHibernate.Util;
1617

@@ -280,8 +281,10 @@ public Settings BuildSettings(IDictionary<string, string> properties)
280281
settings.IsSecondLevelCacheEnabled = useSecondLevelCache;
281282
settings.CacheRegionPrefix = cacheRegionPrefix;
282283
settings.IsMinimalPutsEnabled = useMinimalPuts;
283-
// Not ported - JdbcBatchVersionedData
284-
284+
// Not ported - JdbcBatchVersionedData
285+
286+
settings.QueryModelRewriterFactory = CreateQueryModelRewriterFactory(properties);
287+
285288
// NHibernate-specific:
286289
settings.IsolationLevel = isolation;
287290

@@ -373,5 +376,26 @@ private static ITransactionFactory CreateTransactionFactory(IDictionary<string,
373376
throw new HibernateException("could not instantiate TransactionFactory: " + className, cnfe);
374377
}
375378
}
379+
380+
private static IQueryModelRewriterFactory CreateQueryModelRewriterFactory(IDictionary<string, string> properties)
381+
{
382+
string className = PropertiesHelper.GetString(Environment.QueryModelRewriterFactory, properties, null);
383+
384+
if (className == null)
385+
return null;
386+
387+
log.Info("Query model rewriter factory factory: " + className);
388+
389+
try
390+
{
391+
return
392+
(IQueryModelRewriterFactory)
393+
Environment.BytecodeProvider.ObjectsFactory.CreateInstance(ReflectHelper.ClassForName(className));
394+
}
395+
catch (Exception cnfe)
396+
{
397+
throw new HibernateException("could not instantiate IQueryModelRewriterFactory: " + className, cnfe);
398+
}
399+
}
376400
}
377401
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
using Remotion.Linq;
6+
7+
namespace NHibernate.Linq.Visitors
8+
{
9+
public interface IQueryModelRewriterFactory
10+
{
11+
QueryModelVisitorBase CreateVisitor(VisitorParameters parameters);
12+
}
13+
}

src/NHibernate/Linq/Visitors/QueryModelVisitor.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,16 @@ public static ExpressionToHqlTranslationResults GenerateHqlQuery(QueryModel quer
5757
// Move OrderBy clauses to end
5858
MoveOrderByToEndRewriter.ReWrite(queryModel);
5959

60+
// Give a rewriter provided by the session factory a chance to
61+
// rewrite the query.
62+
var rewriterFactory = parameters.SessionFactory.Settings.QueryModelRewriterFactory;
63+
if (rewriterFactory != null)
64+
{
65+
var customVisitor = rewriterFactory.CreateVisitor(parameters);
66+
if (customVisitor != null)
67+
customVisitor.VisitQueryModel(queryModel);
68+
}
69+
6070
// rewrite any operators that should be applied on the outer query
6171
// by flattening out the sub-queries that they are located in
6272
var result = ResultOperatorRewriter.Rewrite(queryModel);

src/NHibernate/NHibernate.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@
296296
<Compile Include="Linq\NestedSelects\Tuple.cs" />
297297
<Compile Include="Linq\NestedSelects\SelectClauseRewriter.cs" />
298298
<Compile Include="Linq\NestedSelects\ExpressionHolder.cs" />
299+
<Compile Include="Linq\Visitors\IQueryModelRewriterFactory.cs" />
299300
<Compile Include="Linq\Visitors\LeftJoinRewriter.cs" />
300301
<Compile Include="Linq\Functions\CompareGenerator.cs" />
301302
<Compile Include="Linq\ExpressionTransformers\SimplifyCompareTransformer.cs" />

0 commit comments

Comments
 (0)