diff --git a/src/NHibernate.Test/Async/Linq/OperatorTests.cs b/src/NHibernate.Test/Async/Linq/OperatorTests.cs index 2eb4ff3c74a..607f7f9ecb2 100644 --- a/src/NHibernate.Test/Async/Linq/OperatorTests.cs +++ b/src/NHibernate.Test/Async/Linq/OperatorTests.cs @@ -35,6 +35,20 @@ public async Task UnaryMinusAsync() Assert.AreEqual(1, await (session.Query().CountAsync(a => -a.NumberOfHours == -7))); } + [Test] + public async Task DecimalAddAsync() + { + decimal offset = 5.5m; + decimal test = 10248 + offset; + var result = await (session.Query().Where(e => offset + e.OrderId == test).ToListAsync()); + Assert.That(result, Has.Count.EqualTo(1)); + + offset = 5.5m; + test = 32.38m + offset; + result = await (session.Query().Where(e => offset + e.Freight == test).ToListAsync()); + Assert.That(result, Has.Count.EqualTo(1)); + } + [Test] public async Task UnaryPlusAsync() { diff --git a/src/NHibernate.Test/Async/Linq/ParameterTests.cs b/src/NHibernate.Test/Async/Linq/ParameterTests.cs index 0507d20eed2..6cae0a4d2bd 100644 --- a/src/NHibernate.Test/Async/Linq/ParameterTests.cs +++ b/src/NHibernate.Test/Async/Linq/ParameterTests.cs @@ -320,7 +320,7 @@ public async Task CompareFloatingPointParametersAndColumnsAsync() totalParameters, sql => { - Assert.That(sql, Does.Not.Contain("cast")); + Assert.That(sql, pair.Value == "Decimal" && Dialect.IsDecimalStoredAsFloatingPointNumber ? Does.Contain("cast") : Does.Not.Contain("cast")); Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(totalParameters)); })); } diff --git a/src/NHibernate.Test/Linq/OperatorTests.cs b/src/NHibernate.Test/Linq/OperatorTests.cs index e9cb7458258..3f0ae38a568 100644 --- a/src/NHibernate.Test/Linq/OperatorTests.cs +++ b/src/NHibernate.Test/Linq/OperatorTests.cs @@ -24,6 +24,20 @@ public void UnaryMinus() Assert.AreEqual(1, session.Query().Count(a => -a.NumberOfHours == -7)); } + [Test] + public void DecimalAdd() + { + decimal offset = 5.5m; + decimal test = 10248 + offset; + var result = session.Query().Where(e => offset + e.OrderId == test).ToList(); + Assert.That(result, Has.Count.EqualTo(1)); + + offset = 5.5m; + test = 32.38m + offset; + result = session.Query().Where(e => offset + e.Freight == test).ToList(); + Assert.That(result, Has.Count.EqualTo(1)); + } + [Test] public void UnaryPlus() { diff --git a/src/NHibernate.Test/Linq/ParameterTests.cs b/src/NHibernate.Test/Linq/ParameterTests.cs index e30e11ef12e..e86a596af80 100644 --- a/src/NHibernate.Test/Linq/ParameterTests.cs +++ b/src/NHibernate.Test/Linq/ParameterTests.cs @@ -308,7 +308,7 @@ public void CompareFloatingPointParametersAndColumns() totalParameters, sql => { - Assert.That(sql, Does.Not.Contain("cast")); + Assert.That(sql, pair.Value == "Decimal" && Dialect.IsDecimalStoredAsFloatingPointNumber ? Does.Contain("cast") : Does.Not.Contain("cast")); Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(totalParameters)); }); } diff --git a/src/NHibernate.Test/TestCase.cs b/src/NHibernate.Test/TestCase.cs index f1d04a96640..894f6d2405b 100644 --- a/src/NHibernate.Test/TestCase.cs +++ b/src/NHibernate.Test/TestCase.cs @@ -17,6 +17,7 @@ using NHibernate.Dialect; using NHibernate.Driver; using NHibernate.Engine.Query; +using NHibernate.SqlTypes; using NHibernate.Util; using NSubstitute; @@ -525,6 +526,8 @@ protected void ClearQueryPlanCache() var forPartsOfMethod = ReflectHelper.GetMethodDefinition(() => Substitute.ForPartsOf()); var substitute = (Dialect.Dialect) forPartsOfMethod.MakeGenericMethod(origDialect.GetType()) .Invoke(null, new object[] { new object[0] }); + substitute.GetCastTypeName(Arg.Any()) + .ReturnsForAnyArgs(x => origDialect.GetCastTypeName(x.ArgAt(0))); dialectProperty.SetValue(Sfi.Settings, substitute); diff --git a/src/NHibernate/Dialect/Dialect.cs b/src/NHibernate/Dialect/Dialect.cs index 6153309c7ac..9107fb48fa1 100644 --- a/src/NHibernate/Dialect/Dialect.cs +++ b/src/NHibernate/Dialect/Dialect.cs @@ -2611,6 +2611,11 @@ public virtual bool SupportsSqlBatches get { return false; } } + /// + /// Whether is stored as a floating point number. + /// + public virtual bool IsDecimalStoredAsFloatingPointNumber => false; + public virtual bool IsKnownToken(string currentToken, string nextToken) { return false; diff --git a/src/NHibernate/Dialect/SQLiteDialect.cs b/src/NHibernate/Dialect/SQLiteDialect.cs index e464d2f5e15..0ec22942a8a 100644 --- a/src/NHibernate/Dialect/SQLiteDialect.cs +++ b/src/NHibernate/Dialect/SQLiteDialect.cs @@ -362,6 +362,9 @@ public override bool GenerateTablePrimaryKeyConstraintForIdentityColumn get { return false; } } + /// + public override bool IsDecimalStoredAsFloatingPointNumber => true; + public override string Qualify(string catalog, string schema, string table) { StringBuilder qualifiedName = new StringBuilder(); diff --git a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs index 63d9d3e2a1e..8366d42f6ff 100644 --- a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs @@ -579,8 +579,12 @@ protected HqlTreeNode VisitConstantExpression(ConstantExpression expression) if (_parameters.ConstantToParameterMap.TryGetValue(expression, out namedParameter)) { _parameters.RequiredHqlParameters.Add(new NamedParameterDescriptor(namedParameter.Name, null, false)); + var parameter = _hqlTreeBuilder.Parameter(namedParameter.Name).AsExpression(); - return _hqlTreeBuilder.Parameter(namedParameter.Name).AsExpression(); + // SQLite driver binds decimal parameters to text, which can cause unexpected results in arithmetic operations. + return _parameters.SessionFactory.Dialect.IsDecimalStoredAsFloatingPointNumber && expression.Type.UnwrapIfNullable() == typeof(decimal) + ? _hqlTreeBuilder.TransparentCast(parameter, expression.Type) + : parameter; } return _hqlTreeBuilder.Constant(expression.Value);