diff --git a/src/NHibernate.Test/Linq/LinqQuerySamples.cs b/src/NHibernate.Test/Linq/LinqQuerySamples.cs index bcfcf21a2ea..257b2b66469 100755 --- a/src/NHibernate.Test/Linq/LinqQuerySamples.cs +++ b/src/NHibernate.Test/Linq/LinqQuerySamples.cs @@ -2,6 +2,8 @@ using System.Collections.Generic; using System.Linq; using NHibernate.DomainModel.Northwind.Entities; +using NHibernate.Exceptions; +using NHibernate.Linq; using NUnit.Framework; namespace NHibernate.Test.Linq @@ -1620,6 +1622,82 @@ public void DLinq2C() Assert.That(!q.Any(orderid => withNullShippingDate.Contains(orderid))); } + + [Test] + public void CanSpecifyParameterTypeOnRestriction() + { + //NH-2401 + using (this.session.BeginTransaction()) + { + //this should work perfectly, even if doesn't return anything (sanity check) + (from o in this.db.Orders where o.ShippingDate == DateTime.Now select o).FirstOrDefault(); + + //this one should throw an exception because of invalid case + //note: cannot use Assert.Throws because we need to look at the actual exception + try + { + (from o in this.db.Orders where o.ShippingDate.MappedAs(NHibernateUtil.String) == DateTime.Now select o).FirstOrDefault(); + } + catch (GenericADOException ex) + { + Assert.IsInstanceOf(ex.InnerException); + Assert.AreEqual("Unable to cast object of type 'System.DateTime' to type 'System.String'.", ex.InnerException.Message); + } + } + } + + [Test] + public void CanSpecifyParameterTypeAndToStringOnRestriction() + { + //NH-2401 + using (this.session.BeginTransaction()) + { + var firstDate = (from o in this.db.Orders where o.ShippingDate != null orderby o.ShippingDate select o.ShippingDate.ToString()).First(); + var firstDateConverted = (from o in this.db.Orders where o.ShippingDate.ToString() == firstDate.MappedAs(NHibernateUtil.String) select o.ShippingDate).FirstOrDefault(); + + Assert.AreEqual(DateTime.Parse(firstDate), firstDateConverted); + } + } + + + [Test] + public void CanSpecifyParameterTypeAndConvertToIntOnRestriction() + { + //NH-2401 + using (this.session.BeginTransaction()) + { + var firstIdAsInt = (from o in this.db.Orders where o.OrderDate != null orderby o.OrderDate select Convert.ToInt32(o.OrderDate)).First(); + var firstIdConverted = (from o in this.db.Orders where Convert.ToInt32(o.OrderDate.MappedAs(NHibernateUtil.Int32)) == firstIdAsInt orderby o.OrderDate select Convert.ToInt32(o.OrderDate)).Single(); + + Assert.AreEqual(firstIdAsInt, firstIdConverted); + } + } + + [Test] + public void CanSpecifyParameterTypeAndConvertToDecimalOnRestriction() + { + //NH-2401 + using (this.session.BeginTransaction()) + { + var firstIdAsDecimal = Convert.ToDecimal((from o in this.db.Orders orderby o.OrderId select o.OrderId).First()); + var firstIdConverted = (from o in this.db.Orders where o.OrderId.ToString() == firstIdAsDecimal.MappedAs(NHibernateUtil.Decimal).ToString() orderby o.OrderId select Convert.ToDecimal(o.OrderId)).Single(); + + Assert.AreEqual(firstIdAsDecimal, firstIdConverted); + } + } + + [Test] + public void CanSpecifyParameterTypeAndConvertToDoubleOnRestriction() + { + //NH-2401 + using (this.session.BeginTransaction()) + { + var firstIdAsDouble = Convert.ToDouble((from o in this.db.Orders orderby o.OrderId select o.OrderId).First()); + var firstIdConverted = (from o in this.db.Orders where Convert.ToDouble(o.OrderId.MappedAs(NHibernateUtil.Double)) == firstIdAsDouble orderby o.OrderId select Convert.ToDouble(o.OrderId)).Single(); + + Assert.AreEqual(firstIdAsDouble, firstIdConverted); + } + } } public class ParentChildBatch diff --git a/src/NHibernate/Linq/Functions/ConvertGenerator.cs b/src/NHibernate/Linq/Functions/ConvertGenerator.cs index c9e8e501f98..e82c56c9bf1 100644 --- a/src/NHibernate/Linq/Functions/ConvertGenerator.cs +++ b/src/NHibernate/Linq/Functions/ConvertGenerator.cs @@ -1,12 +1,9 @@ using System; -using System.Collections.Generic; -using System.Linq; +using System.Collections.ObjectModel; using System.Linq.Expressions; using System.Reflection; -using System.Text; using NHibernate.Hql.Ast; using NHibernate.Linq.Visitors; -using System.Collections.ObjectModel; namespace NHibernate.Linq.Functions { @@ -14,6 +11,12 @@ public abstract class ConvertToGenerator : BaseHqlGeneratorForMethod { public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) { + var mce = targetObject as MethodCallExpression; + if (mce != null) + { + return treeBuilder.Cast(visitor.Visit(mce.Arguments[0]).AsExpression(), typeof(T)); + } + return treeBuilder.Cast(visitor.Visit(arguments[0]).AsExpression(), typeof(T)); } } diff --git a/src/NHibernate/Linq/Functions/DefaultLinqToHqlGeneratorsRegistry.cs b/src/NHibernate/Linq/Functions/DefaultLinqToHqlGeneratorsRegistry.cs index 7baf31dc4a4..9d66248af94 100644 --- a/src/NHibernate/Linq/Functions/DefaultLinqToHqlGeneratorsRegistry.cs +++ b/src/NHibernate/Linq/Functions/DefaultLinqToHqlGeneratorsRegistry.cs @@ -52,6 +52,9 @@ public DefaultLinqToHqlGeneratorsRegistry() this.Merge(new CollectionContainsGenerator()); this.Merge(new DateTimePropertiesHqlGenerator()); + + //NH-2401 + this.Merge(new MappedAsGenerator()); } protected bool GetRuntimeMethodGenerator(MethodInfo method, out IHqlGeneratorForMethod methodGenerator) diff --git a/src/NHibernate/Linq/Functions/MappedAsGenerator.cs b/src/NHibernate/Linq/Functions/MappedAsGenerator.cs new file mode 100644 index 00000000000..dd324caaa40 --- /dev/null +++ b/src/NHibernate/Linq/Functions/MappedAsGenerator.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Linq.Expressions; +using System.Reflection; +using NHibernate.Hql.Ast; +using NHibernate.Linq; +using NHibernate.Linq.Visitors; +using NHibernate.Type; + +namespace NHibernate.Linq.Functions +{ + public class MappedAsGenerator : BaseHqlGeneratorForMethod + { + public MappedAsGenerator() + { + SupportedMethods = new[] { ReflectionHelper.GetMethodDefinition(x => x.MappedAs(null)) }; + } + + public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) + { + var result = visitor.Visit(arguments[0]).AsExpression(); + return result; + } + } +} diff --git a/src/NHibernate/Linq/Functions/StringGenerator.cs b/src/NHibernate/Linq/Functions/StringGenerator.cs index 7ce62e5fc9b..89bdc77a689 100644 --- a/src/NHibernate/Linq/Functions/StringGenerator.cs +++ b/src/NHibernate/Linq/Functions/StringGenerator.cs @@ -281,6 +281,12 @@ public IEnumerable SupportedMethods public HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) { + var mce = targetObject as MethodCallExpression; + if (mce != null) + { + return treeBuilder.MethodCall("str", visitor.Visit(mce.Arguments[0]).AsExpression()); + } + return treeBuilder.MethodCall("str", visitor.Visit(targetObject).AsExpression()); } } diff --git a/src/NHibernate/Linq/LinqExtensionMethods.cs b/src/NHibernate/Linq/LinqExtensionMethods.cs index e9ee33324f0..79c151d07d4 100755 --- a/src/NHibernate/Linq/LinqExtensionMethods.cs +++ b/src/NHibernate/Linq/LinqExtensionMethods.cs @@ -10,6 +10,12 @@ namespace NHibernate.Linq { public static class LinqExtensionMethods { + public static T MappedAs(this T parameter, NHibernate.Type.IType type) + { + //NH-2401 + return parameter; + } + public static IQueryable Query(this ISession session) { return new NhQueryable(session.GetSessionImplementation()); diff --git a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs index cb4c82c90f4..7bc611f029a 100644 --- a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs +++ b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs @@ -17,6 +17,8 @@ namespace NHibernate.Linq.Visitors /// public class ExpressionParameterVisitor : ExpressionTreeVisitor { + //NH-2401 + private readonly Dictionary parameterTypeOverrides = new Dictionary(); private readonly Dictionary _parameters = new Dictionary(); private readonly ISessionFactoryImplementor _sessionFactory; @@ -60,6 +62,25 @@ protected override Expression VisitMethodCallExpression(MethodCallExpression exp return Expression.Call(null, expression.Method, query, arg); } + /*if (expression.Method == ReflectionHelper.GetMethod(x => x.ToString())) + { + //NH-2401: detect ToString after MappedAs + //we just leave thos to StringGenerator + return base.VisitMethodCallExpression(expression); + }*/ + + if ((expression.Method.DeclaringType == typeof(LinqExtensionMethods)) && (expression.Method.Name == "MappedAs")) + { + //NH-2401: detect MappedAs + //we cannot do this in a *Generator class because there we don't have access to the parameters collection (_parameters) + var typeExpression = Expression.Lambda>(Expression.Convert(expression.Arguments[1], typeof(IType))); + var type = typeExpression.Compile()(); + + this.parameterTypeOverrides[_parameters.Count] = type; + + return expression; + } + if (VisitorUtil.IsDynamicComponentDictionaryGetter(expression, _sessionFactory)) { return expression; @@ -75,10 +96,14 @@ protected override Expression VisitConstantExpression(ConstantExpression express // We use null for the type to indicate that the caller should let HQL figure it out. IType type = null; - // We have a bit more information about the null parameter value. - // Figure out a type so that HQL doesn't break on the null. (Related to NH-2430) - if (expression.Value == null) - type = NHibernateUtil.GuessType(expression.Type); + //NH-2401 + if (this.parameterTypeOverrides.TryGetValue(this._parameters.Count, out type) == false) + { + // We have a bit more information about the null parameter value. + // Figure out a type so that HQL doesn't break on the null. (Related to NH-2430) + if (expression.Value == null) + type = NHibernateUtil.GuessType(expression.Type); + } // There is more information available in the Linq expression than to HQL directly. // In some cases it might be advantageous to use the extra info. Assuming this diff --git a/src/NHibernate/NHibernate.csproj b/src/NHibernate/NHibernate.csproj index cbb25cad16d..3ec3a280bb6 100644 --- a/src/NHibernate/NHibernate.csproj +++ b/src/NHibernate/NHibernate.csproj @@ -293,6 +293,7 @@ +