diff --git a/src/NHibernate.Test/Linq/CustomQueryModelRewriterTests.cs b/src/NHibernate.Test/Linq/CustomQueryModelRewriterTests.cs new file mode 100644 index 00000000000..a22305a9966 --- /dev/null +++ b/src/NHibernate.Test/Linq/CustomQueryModelRewriterTests.cs @@ -0,0 +1,90 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Text; +using NHibernate.Linq.Visitors; +using NUnit.Framework; +using Remotion.Linq; +using Remotion.Linq.Clauses; +using Remotion.Linq.Parsing; + +namespace NHibernate.Test.Linq +{ + public class CustomQueryModelRewriterTests : LinqTestCase + { + protected override void Configure(Cfg.Configuration configuration) + { + configuration.Properties[Cfg.Environment.QueryModelRewriterFactory] = typeof(QueryModelRewriterFactory).AssemblyQualifiedName; + } + + [Test] + public void RewriteNullComparison() + { + // This example shows how to use the query model rewriter to + // make radical changes to the query. In this case, we rewrite + // a null comparison (which would translate into a IS NULL) + // into a comparison to "Thomas Hardy" (which translates to a = "Thomas Hardy"). + + var contacts = (from c in db.Customers where c.ContactName == null select c).ToList(); + Assert.Greater(contacts.Count, 0); + Assert.IsTrue(contacts.Select(customer => customer.ContactName).All(c => c == "Thomas Hardy")); + } + + [Serializable] + public class QueryModelRewriterFactory : IQueryModelRewriterFactory + { + public QueryModelVisitorBase CreateVisitor(VisitorParameters parameters) + { + return new CustomVisitor(); + } + } + + public class CustomVisitor : QueryModelVisitorBase + { + public override void VisitWhereClause(WhereClause whereClause, QueryModel queryModel, int index) + { + whereClause.TransformExpressions(new Visitor().VisitExpression); + } + + private class Visitor : ExpressionTreeVisitor + { + protected override Expression VisitBinaryExpression(BinaryExpression expression) + { + if ( + expression.NodeType == ExpressionType.Equal || + expression.NodeType == ExpressionType.NotEqual + ) + { + var left = expression.Left; + var right = expression.Right; + bool reverse = false; + + if (!(left is ConstantExpression) && right is ConstantExpression) + { + var tmp = left; + left = right; + right = tmp; + reverse = true; + } + + var constant = left as ConstantExpression; + + if (constant != null && constant.Value == null) + { + left = Expression.Constant("Thomas Hardy"); + + expression = Expression.MakeBinary( + expression.NodeType, + reverse ? right : left, + reverse ? left : right + ); + } + } + + return base.VisitBinaryExpression(expression); + } + } + } + } +} diff --git a/src/NHibernate.Test/NHibernate.Test.csproj b/src/NHibernate.Test/NHibernate.Test.csproj index c60b89ab659..e9e36bd7329 100644 --- a/src/NHibernate.Test/NHibernate.Test.csproj +++ b/src/NHibernate.Test/NHibernate.Test.csproj @@ -80,6 +80,10 @@ False ..\..\lib\net\nunit.framework.dll + + False + ..\..\lib\net\Remotion.Linq.dll + @@ -512,6 +516,7 @@ + diff --git a/src/NHibernate/Cfg/Environment.cs b/src/NHibernate/Cfg/Environment.cs index f3cb29901f8..273651148ef 100644 --- a/src/NHibernate/Cfg/Environment.cs +++ b/src/NHibernate/Cfg/Environment.cs @@ -171,6 +171,8 @@ public static string Version /// Enable ordering of insert statements for the purpose of more effecient batching. public const string OrderInserts = "order_inserts"; + public const string QueryModelRewriterFactory = "query.query_model_rewriter_factory"; + /// /// If this setting is set to false, exceptions in IInterceptor.BeforeTransactionCompletion bubble to the caller of ITransaction.Commit and abort the commit. /// If this setting is set to true, exceptions in IInterceptor.BeforeTransactionCompletion are ignored and the commit is performed. diff --git a/src/NHibernate/Cfg/Loquacious/DbIntegrationConfigurationProperties.cs b/src/NHibernate/Cfg/Loquacious/DbIntegrationConfigurationProperties.cs index 13d022940bb..55131818ab4 100644 --- a/src/NHibernate/Cfg/Loquacious/DbIntegrationConfigurationProperties.cs +++ b/src/NHibernate/Cfg/Loquacious/DbIntegrationConfigurationProperties.cs @@ -3,6 +3,7 @@ using NHibernate.Connection; using NHibernate.Driver; using NHibernate.Exceptions; +using NHibernate.Linq.Visitors; using NHibernate.Transaction; namespace NHibernate.Cfg.Loquacious @@ -123,6 +124,11 @@ public SchemaAutoAction SchemaAction set { configuration.SetProperty(Environment.Hbm2ddlAuto, value.ToString()); } } + public void QueryModelRewriterFactory() where TFactory : IQueryModelRewriterFactory + { + configuration.SetProperty(Environment.QueryModelRewriterFactory, typeof(TFactory).AssemblyQualifiedName); + } + #endregion } } \ No newline at end of file diff --git a/src/NHibernate/Cfg/Loquacious/IDbIntegrationConfigurationProperties.cs b/src/NHibernate/Cfg/Loquacious/IDbIntegrationConfigurationProperties.cs index ec39020f1f0..c39891a30bb 100644 --- a/src/NHibernate/Cfg/Loquacious/IDbIntegrationConfigurationProperties.cs +++ b/src/NHibernate/Cfg/Loquacious/IDbIntegrationConfigurationProperties.cs @@ -3,6 +3,7 @@ using NHibernate.Connection; using NHibernate.Driver; using NHibernate.Exceptions; +using NHibernate.Linq.Visitors; using NHibernate.Transaction; namespace NHibernate.Cfg.Loquacious @@ -35,5 +36,7 @@ public interface IDbIntegrationConfigurationProperties byte MaximumDepthOfOuterJoinFetching { set; } SchemaAutoAction SchemaAction { set; } + + void QueryModelRewriterFactory() where TFactory : IQueryModelRewriterFactory; } } \ No newline at end of file diff --git a/src/NHibernate/Cfg/Settings.cs b/src/NHibernate/Cfg/Settings.cs index 1f3ec992233..2f8ac968b69 100644 --- a/src/NHibernate/Cfg/Settings.cs +++ b/src/NHibernate/Cfg/Settings.cs @@ -8,6 +8,7 @@ using NHibernate.Exceptions; using NHibernate.Hql; using NHibernate.Linq.Functions; +using NHibernate.Linq.Visitors; using NHibernate.Transaction; namespace NHibernate.Cfg @@ -129,6 +130,8 @@ public Settings() [Obsolete("This setting is likely to be removed in a future version of NHibernate. The workaround is to catch all exceptions in the IInterceptor implementation.")] public bool IsInterceptorsBeforeTransactionCompletionIgnoreExceptionsEnabled { get; internal set; } + public IQueryModelRewriterFactory QueryModelRewriterFactory { get; internal set; } + #endregion } } \ No newline at end of file diff --git a/src/NHibernate/Cfg/SettingsFactory.cs b/src/NHibernate/Cfg/SettingsFactory.cs index e6ee2f9693e..310b8d874af 100644 --- a/src/NHibernate/Cfg/SettingsFactory.cs +++ b/src/NHibernate/Cfg/SettingsFactory.cs @@ -11,6 +11,7 @@ using NHibernate.Exceptions; using NHibernate.Hql; using NHibernate.Linq.Functions; +using NHibernate.Linq.Visitors; using NHibernate.Transaction; using NHibernate.Util; @@ -288,6 +289,8 @@ public Settings BuildSettings(IDictionary properties) settings.IsMinimalPutsEnabled = useMinimalPuts; // Not ported - JdbcBatchVersionedData + settings.QueryModelRewriterFactory = CreateQueryModelRewriterFactory(properties); + // NHibernate-specific: settings.IsolationLevel = isolation; @@ -379,5 +382,26 @@ private static ITransactionFactory CreateTransactionFactory(IDictionary properties) + { + string className = PropertiesHelper.GetString(Environment.QueryModelRewriterFactory, properties, null); + + if (className == null) + return null; + + log.Info("Query model rewriter factory factory: " + className); + + try + { + return + (IQueryModelRewriterFactory) + Environment.BytecodeProvider.ObjectsFactory.CreateInstance(ReflectHelper.ClassForName(className)); + } + catch (Exception cnfe) + { + throw new HibernateException("could not instantiate IQueryModelRewriterFactory: " + className, cnfe); + } + } } } diff --git a/src/NHibernate/Linq/Visitors/IQueryModelRewriterFactory.cs b/src/NHibernate/Linq/Visitors/IQueryModelRewriterFactory.cs new file mode 100644 index 00000000000..73cd2d2e04b --- /dev/null +++ b/src/NHibernate/Linq/Visitors/IQueryModelRewriterFactory.cs @@ -0,0 +1,13 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Remotion.Linq; + +namespace NHibernate.Linq.Visitors +{ + public interface IQueryModelRewriterFactory + { + QueryModelVisitorBase CreateVisitor(VisitorParameters parameters); + } +} diff --git a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs index 49f2bd15c3e..9ac7c237acc 100644 --- a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs +++ b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs @@ -57,6 +57,16 @@ public static ExpressionToHqlTranslationResults GenerateHqlQuery(QueryModel quer // Move OrderBy clauses to end MoveOrderByToEndRewriter.ReWrite(queryModel); + // Give a rewriter provided by the session factory a chance to + // rewrite the query. + var rewriterFactory = parameters.SessionFactory.Settings.QueryModelRewriterFactory; + if (rewriterFactory != null) + { + var customVisitor = rewriterFactory.CreateVisitor(parameters); + if (customVisitor != null) + customVisitor.VisitQueryModel(queryModel); + } + // rewrite any operators that should be applied on the outer query // by flattening out the sub-queries that they are located in var result = ResultOperatorRewriter.Rewrite(queryModel); diff --git a/src/NHibernate/NHibernate.csproj b/src/NHibernate/NHibernate.csproj index 54fa348f4e2..36111c77a25 100644 --- a/src/NHibernate/NHibernate.csproj +++ b/src/NHibernate/NHibernate.csproj @@ -296,6 +296,7 @@ +