diff --git a/src/NHibernate.Test/Async/Linq/ParameterTests.cs b/src/NHibernate.Test/Async/Linq/ParameterTests.cs index e5d239f7ed2..0507d20eed2 100644 --- a/src/NHibernate.Test/Async/Linq/ParameterTests.cs +++ b/src/NHibernate.Test/Async/Linq/ParameterTests.cs @@ -616,6 +616,64 @@ public async Task UsingParameterInEvaluatableExpressionAsync() await (db.Users.Where(x => names.Length == 0 || names.Contains(x.Name)).ToListAsync()); } + [Test] + public async Task UsingParameterWithImplicitOperatorAsync() + { + var id = new GuidImplicitWrapper(new Guid("{356E4A7E-B027-4321-BA40-E2677E6502CF}")); + Assert.That(await (db.Shippers.Where(o => o.Reference == id).ToListAsync()), Has.Count.EqualTo(1)); + + id = new GuidImplicitWrapper(new Guid("{356E4A7E-B027-4321-BA40-E2677E6502FF}")); + Assert.That(await (db.Shippers.Where(o => o.Reference == id).ToListAsync()), Is.Empty); + + await (AssertTotalParametersAsync( + db.Shippers.Where(o => o.Reference == id && id == o.Reference), + 1)); + } + + private struct GuidImplicitWrapper + { + public readonly Guid Id; + + public GuidImplicitWrapper(Guid id) + { + Id = id; + } + + public static implicit operator Guid(GuidImplicitWrapper idWrapper) + { + return idWrapper.Id; + } + } + + [Test] + public async Task UsingParameterWithExplicitOperatorAsync() + { + var id = new GuidExplicitWrapper(new Guid("{356E4A7E-B027-4321-BA40-E2677E6502CF}")); + Assert.That(await (db.Shippers.Where(o => o.Reference == (Guid) id).ToListAsync()), Has.Count.EqualTo(1)); + + id = new GuidExplicitWrapper(new Guid("{356E4A7E-B027-4321-BA40-E2677E6502FF}")); + Assert.That(await (db.Shippers.Where(o => o.Reference == (Guid) id).ToListAsync()), Is.Empty); + + await (AssertTotalParametersAsync( + db.Shippers.Where(o => o.Reference == (Guid) id && (Guid) id == o.Reference), + 1)); + } + + private struct GuidExplicitWrapper + { + public readonly Guid Id; + + public GuidExplicitWrapper(Guid id) + { + Id = id; + } + + public static explicit operator Guid(GuidExplicitWrapper idWrapper) + { + return idWrapper.Id; + } + } + [Test] public async Task UsingParameterOnSelectorsAsync() { diff --git a/src/NHibernate.Test/Linq/ParameterTests.cs b/src/NHibernate.Test/Linq/ParameterTests.cs index fdbb1a73275..e30e11ef12e 100644 --- a/src/NHibernate.Test/Linq/ParameterTests.cs +++ b/src/NHibernate.Test/Linq/ParameterTests.cs @@ -604,6 +604,64 @@ public void UsingParameterInEvaluatableExpression() db.Users.Where(x => names.Length == 0 || names.Contains(x.Name)).ToList(); } + [Test] + public void UsingParameterWithImplicitOperator() + { + var id = new GuidImplicitWrapper(new Guid("{356E4A7E-B027-4321-BA40-E2677E6502CF}")); + Assert.That(db.Shippers.Where(o => o.Reference == id).ToList(), Has.Count.EqualTo(1)); + + id = new GuidImplicitWrapper(new Guid("{356E4A7E-B027-4321-BA40-E2677E6502FF}")); + Assert.That(db.Shippers.Where(o => o.Reference == id).ToList(), Is.Empty); + + AssertTotalParameters( + db.Shippers.Where(o => o.Reference == id && id == o.Reference), + 1); + } + + private struct GuidImplicitWrapper + { + public readonly Guid Id; + + public GuidImplicitWrapper(Guid id) + { + Id = id; + } + + public static implicit operator Guid(GuidImplicitWrapper idWrapper) + { + return idWrapper.Id; + } + } + + [Test] + public void UsingParameterWithExplicitOperator() + { + var id = new GuidExplicitWrapper(new Guid("{356E4A7E-B027-4321-BA40-E2677E6502CF}")); + Assert.That(db.Shippers.Where(o => o.Reference == (Guid) id).ToList(), Has.Count.EqualTo(1)); + + id = new GuidExplicitWrapper(new Guid("{356E4A7E-B027-4321-BA40-E2677E6502FF}")); + Assert.That(db.Shippers.Where(o => o.Reference == (Guid) id).ToList(), Is.Empty); + + AssertTotalParameters( + db.Shippers.Where(o => o.Reference == (Guid) id && (Guid) id == o.Reference), + 1); + } + + private struct GuidExplicitWrapper + { + public readonly Guid Id; + + public GuidExplicitWrapper(Guid id) + { + Id = id; + } + + public static explicit operator Guid(GuidExplicitWrapper idWrapper) + { + return idWrapper.Id; + } + } + [Test] public void UsingParameterOnSelectors() { diff --git a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs index a70109bd59c..c0658d4312c 100644 --- a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs +++ b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs @@ -135,49 +135,72 @@ protected override Expression VisitConstant(ConstantExpression expression) { if (!_parameters.ContainsKey(expression) && !typeof(IQueryable).IsAssignableFrom(expression.Type) && !IsNullObject(expression)) { - // We use null for the type to indicate that the caller should let HQL figure it out. - object value = expression.Value; - IType type = null; - - // We have a bit more information about the null parameter value. - // Figure out a type so that HQL doesn't break on the null. (Related to NH-2430) - // In v5.3 types are calculated by ParameterTypeLocator, this logic is only for back compatibility. - // TODO 6.0: Remove - if (expression.Value == null) - type = NHibernateUtil.GuessType(expression.Type); - - // Constant characters should be sent as strings - // TODO 6.0: Remove - if (_queryVariables == null && expression.Type == typeof(char)) - { - value = value.ToString(); - } - - // There is more information available in the Linq expression than to HQL directly. - // In some cases it might be advantageous to use the extra info. Assuming this - // comes up, it would be nice to combine the HQL parameter type determination code - // and the Expression information. - - NamedParameter parameter = null; - if (_queryVariables != null && - _queryVariables.TryGetValue(expression, out var variable) && - !_variableParameters.TryGetValue(variable, out parameter)) - { - parameter = CreateParameter(expression, value, type); - _variableParameters.Add(variable, parameter); - } + AddConstantExpressionParameter(expression, null); + } - if (parameter == null) - { - parameter = CreateParameter(expression, value, type); - } + return base.VisitConstant(expression); + } - _parameters.Add(expression, parameter); + protected override Expression VisitUnary(UnaryExpression node) + { + // If we have an expression like "Convert()" we do not want to lose the conversion operation + // because it might be necessary if the types are incompatible with each other, which might happen if + // the expression uses an implicitly or explicitly defined cast operator. + if (node.NodeType == ExpressionType.Convert && + node.Method != null && // The implicit/explicit operator method + node.Operand is ConstantExpression constantExpression) + { + // Instead of getting constantExpression.Value, we override the value by compiling and executing this subtree, + // performing the cast. + var lambda = Expression.Lambda>(Expression.Convert(node, typeof(object))); + var compiledLambda = lambda.Compile(); - return base.VisitConstant(expression); + AddConstantExpressionParameter(constantExpression, compiledLambda()); } - return base.VisitConstant(expression); + return base.VisitUnary(node); + } + + private void AddConstantExpressionParameter(ConstantExpression expression, object overrideValue) + { + // We use null for the type to indicate that the caller should let HQL figure it out. + object value = overrideValue ?? expression.Value; + IType type = null; + + // We have a bit more information about the null parameter value. + // Figure out a type so that HQL doesn't break on the null. (Related to NH-2430) + // In v5.3 types are calculated by ParameterTypeLocator, this logic is only for back compatibility. + // TODO 6.0: Remove + if (value == null) + type = NHibernateUtil.GuessType(expression.Type); + + // Constant characters should be sent as strings + // TODO 6.0: Remove + if (_queryVariables == null && expression.Type == typeof(char)) + { + value = value.ToString(); + } + + // There is more information available in the Linq expression than to HQL directly. + // In some cases it might be advantageous to use the extra info. Assuming this + // comes up, it would be nice to combine the HQL parameter type determination code + // and the Expression information. + + NamedParameter parameter = null; + if (_queryVariables != null && + _queryVariables.TryGetValue(expression, out var variable) && + !_variableParameters.TryGetValue(variable, out parameter)) + { + parameter = CreateParameter(expression, value, type); + _variableParameters.Add(variable, parameter); + } + + if (parameter == null) + { + parameter = CreateParameter(expression, value, type); + } + + _parameters.Add(expression, parameter); } private NamedParameter CreateParameter(ConstantExpression expression, object value, IType type)