From 9822e6d1d143fd7f6ea676a1f4d5d633d3caffa2 Mon Sep 17 00:00:00 2001 From: maca88 Date: Wed, 2 Jun 2021 21:01:52 +0200 Subject: [PATCH 1/3] Fix decimal equality comparison for Sqlite --- src/NHibernate.Test/Async/Linq/OperatorTests.cs | 14 ++++++++++++++ src/NHibernate.Test/Linq/OperatorTests.cs | 14 ++++++++++++++ .../Linq/Visitors/HqlGeneratorExpressionVisitor.cs | 5 ++++- 3 files changed, 32 insertions(+), 1 deletion(-) diff --git a/src/NHibernate.Test/Async/Linq/OperatorTests.cs b/src/NHibernate.Test/Async/Linq/OperatorTests.cs index 2eb4ff3c74a..607f7f9ecb2 100644 --- a/src/NHibernate.Test/Async/Linq/OperatorTests.cs +++ b/src/NHibernate.Test/Async/Linq/OperatorTests.cs @@ -35,6 +35,20 @@ public async Task UnaryMinusAsync() Assert.AreEqual(1, await (session.Query().CountAsync(a => -a.NumberOfHours == -7))); } + [Test] + public async Task DecimalAddAsync() + { + decimal offset = 5.5m; + decimal test = 10248 + offset; + var result = await (session.Query().Where(e => offset + e.OrderId == test).ToListAsync()); + Assert.That(result, Has.Count.EqualTo(1)); + + offset = 5.5m; + test = 32.38m + offset; + result = await (session.Query().Where(e => offset + e.Freight == test).ToListAsync()); + Assert.That(result, Has.Count.EqualTo(1)); + } + [Test] public async Task UnaryPlusAsync() { diff --git a/src/NHibernate.Test/Linq/OperatorTests.cs b/src/NHibernate.Test/Linq/OperatorTests.cs index e9cb7458258..3f0ae38a568 100644 --- a/src/NHibernate.Test/Linq/OperatorTests.cs +++ b/src/NHibernate.Test/Linq/OperatorTests.cs @@ -24,6 +24,20 @@ public void UnaryMinus() Assert.AreEqual(1, session.Query().Count(a => -a.NumberOfHours == -7)); } + [Test] + public void DecimalAdd() + { + decimal offset = 5.5m; + decimal test = 10248 + offset; + var result = session.Query().Where(e => offset + e.OrderId == test).ToList(); + Assert.That(result, Has.Count.EqualTo(1)); + + offset = 5.5m; + test = 32.38m + offset; + result = session.Query().Where(e => offset + e.Freight == test).ToList(); + Assert.That(result, Has.Count.EqualTo(1)); + } + [Test] public void UnaryPlus() { diff --git a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs index 63d9d3e2a1e..22594b18768 100644 --- a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs @@ -579,8 +579,11 @@ protected HqlTreeNode VisitConstantExpression(ConstantExpression expression) if (_parameters.ConstantToParameterMap.TryGetValue(expression, out namedParameter)) { _parameters.RequiredHqlParameters.Add(new NamedParameterDescriptor(namedParameter.Name, null, false)); + var parameter = _hqlTreeBuilder.Parameter(namedParameter.Name).AsExpression(); - return _hqlTreeBuilder.Parameter(namedParameter.Name).AsExpression(); + return HqlIdent.SupportsType(expression.Type) + ? _hqlTreeBuilder.TransparentCast(parameter, expression.Type) + : parameter; } return _hqlTreeBuilder.Constant(expression.Value); From 8635d6dbb19e9aec83e579e188bc696d3846d601 Mon Sep 17 00:00:00 2001 From: maca88 Date: Mon, 14 Jun 2021 00:13:14 +0200 Subject: [PATCH 2/3] Fix tests --- src/NHibernate.Test/Async/Linq/ParameterTests.cs | 2 +- src/NHibernate.Test/Linq/ParameterTests.cs | 2 +- src/NHibernate.Test/TestCase.cs | 3 +++ src/NHibernate/Dialect/Dialect.cs | 5 +++++ src/NHibernate/Dialect/SQLiteDialect.cs | 3 +++ .../Linq/Visitors/HqlGeneratorExpressionVisitor.cs | 3 ++- 6 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/NHibernate.Test/Async/Linq/ParameterTests.cs b/src/NHibernate.Test/Async/Linq/ParameterTests.cs index 0507d20eed2..461292c2ea8 100644 --- a/src/NHibernate.Test/Async/Linq/ParameterTests.cs +++ b/src/NHibernate.Test/Async/Linq/ParameterTests.cs @@ -320,7 +320,7 @@ public async Task CompareFloatingPointParametersAndColumnsAsync() totalParameters, sql => { - Assert.That(sql, Does.Not.Contain("cast")); + Assert.That(sql, pair.Value == "Decimal" && Dialect is SQLiteDialect ? Does.Contain("cast") : Does.Not.Contain("cast")); Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(totalParameters)); })); } diff --git a/src/NHibernate.Test/Linq/ParameterTests.cs b/src/NHibernate.Test/Linq/ParameterTests.cs index e30e11ef12e..6233efe9730 100644 --- a/src/NHibernate.Test/Linq/ParameterTests.cs +++ b/src/NHibernate.Test/Linq/ParameterTests.cs @@ -308,7 +308,7 @@ public void CompareFloatingPointParametersAndColumns() totalParameters, sql => { - Assert.That(sql, Does.Not.Contain("cast")); + Assert.That(sql, pair.Value == "Decimal" && Dialect is SQLiteDialect ? Does.Contain("cast") : Does.Not.Contain("cast")); Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(totalParameters)); }); } diff --git a/src/NHibernate.Test/TestCase.cs b/src/NHibernate.Test/TestCase.cs index f1d04a96640..894f6d2405b 100644 --- a/src/NHibernate.Test/TestCase.cs +++ b/src/NHibernate.Test/TestCase.cs @@ -17,6 +17,7 @@ using NHibernate.Dialect; using NHibernate.Driver; using NHibernate.Engine.Query; +using NHibernate.SqlTypes; using NHibernate.Util; using NSubstitute; @@ -525,6 +526,8 @@ protected void ClearQueryPlanCache() var forPartsOfMethod = ReflectHelper.GetMethodDefinition(() => Substitute.ForPartsOf()); var substitute = (Dialect.Dialect) forPartsOfMethod.MakeGenericMethod(origDialect.GetType()) .Invoke(null, new object[] { new object[0] }); + substitute.GetCastTypeName(Arg.Any()) + .ReturnsForAnyArgs(x => origDialect.GetCastTypeName(x.ArgAt(0))); dialectProperty.SetValue(Sfi.Settings, substitute); diff --git a/src/NHibernate/Dialect/Dialect.cs b/src/NHibernate/Dialect/Dialect.cs index 6153309c7ac..44ec93d2089 100644 --- a/src/NHibernate/Dialect/Dialect.cs +++ b/src/NHibernate/Dialect/Dialect.cs @@ -2611,6 +2611,11 @@ public virtual bool SupportsSqlBatches get { return false; } } + /// + /// Whether is stored as a floating point number. + /// + internal virtual bool IsDecimalStoredAsFloatingPointNumber => false; + public virtual bool IsKnownToken(string currentToken, string nextToken) { return false; diff --git a/src/NHibernate/Dialect/SQLiteDialect.cs b/src/NHibernate/Dialect/SQLiteDialect.cs index e464d2f5e15..0084d8d6df9 100644 --- a/src/NHibernate/Dialect/SQLiteDialect.cs +++ b/src/NHibernate/Dialect/SQLiteDialect.cs @@ -362,6 +362,9 @@ public override bool GenerateTablePrimaryKeyConstraintForIdentityColumn get { return false; } } + /// + internal override bool IsDecimalStoredAsFloatingPointNumber => true; + public override string Qualify(string catalog, string schema, string table) { StringBuilder qualifiedName = new StringBuilder(); diff --git a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs index 22594b18768..2b31520bf96 100644 --- a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs @@ -581,7 +581,8 @@ protected HqlTreeNode VisitConstantExpression(ConstantExpression expression) _parameters.RequiredHqlParameters.Add(new NamedParameterDescriptor(namedParameter.Name, null, false)); var parameter = _hqlTreeBuilder.Parameter(namedParameter.Name).AsExpression(); - return HqlIdent.SupportsType(expression.Type) + // SQLite driver binds decimal parameters to text, which can cause unexpected results in arithmetic operations. + return expression.Type.UnwrapIfNullable() == typeof(decimal) && _parameters.SessionFactory.Dialect.IsDecimalStoredAsFloatingPointNumber ? _hqlTreeBuilder.TransparentCast(parameter, expression.Type) : parameter; } From bef484acda1db3240116d48b60f00fd4d7b267a9 Mon Sep 17 00:00:00 2001 From: maca88 Date: Mon, 14 Jun 2021 23:17:44 +0200 Subject: [PATCH 3/3] Code review changes --- src/NHibernate.Test/Async/Linq/ParameterTests.cs | 2 +- src/NHibernate.Test/Linq/ParameterTests.cs | 2 +- src/NHibernate/Dialect/Dialect.cs | 2 +- src/NHibernate/Dialect/SQLiteDialect.cs | 2 +- src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/NHibernate.Test/Async/Linq/ParameterTests.cs b/src/NHibernate.Test/Async/Linq/ParameterTests.cs index 461292c2ea8..6cae0a4d2bd 100644 --- a/src/NHibernate.Test/Async/Linq/ParameterTests.cs +++ b/src/NHibernate.Test/Async/Linq/ParameterTests.cs @@ -320,7 +320,7 @@ public async Task CompareFloatingPointParametersAndColumnsAsync() totalParameters, sql => { - Assert.That(sql, pair.Value == "Decimal" && Dialect is SQLiteDialect ? Does.Contain("cast") : Does.Not.Contain("cast")); + Assert.That(sql, pair.Value == "Decimal" && Dialect.IsDecimalStoredAsFloatingPointNumber ? Does.Contain("cast") : Does.Not.Contain("cast")); Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(totalParameters)); })); } diff --git a/src/NHibernate.Test/Linq/ParameterTests.cs b/src/NHibernate.Test/Linq/ParameterTests.cs index 6233efe9730..e86a596af80 100644 --- a/src/NHibernate.Test/Linq/ParameterTests.cs +++ b/src/NHibernate.Test/Linq/ParameterTests.cs @@ -308,7 +308,7 @@ public void CompareFloatingPointParametersAndColumns() totalParameters, sql => { - Assert.That(sql, pair.Value == "Decimal" && Dialect is SQLiteDialect ? Does.Contain("cast") : Does.Not.Contain("cast")); + Assert.That(sql, pair.Value == "Decimal" && Dialect.IsDecimalStoredAsFloatingPointNumber ? Does.Contain("cast") : Does.Not.Contain("cast")); Assert.That(GetTotalOccurrences(sql, $"Type: {pair.Value}"), Is.EqualTo(totalParameters)); }); } diff --git a/src/NHibernate/Dialect/Dialect.cs b/src/NHibernate/Dialect/Dialect.cs index 44ec93d2089..9107fb48fa1 100644 --- a/src/NHibernate/Dialect/Dialect.cs +++ b/src/NHibernate/Dialect/Dialect.cs @@ -2614,7 +2614,7 @@ public virtual bool SupportsSqlBatches /// /// Whether is stored as a floating point number. /// - internal virtual bool IsDecimalStoredAsFloatingPointNumber => false; + public virtual bool IsDecimalStoredAsFloatingPointNumber => false; public virtual bool IsKnownToken(string currentToken, string nextToken) { diff --git a/src/NHibernate/Dialect/SQLiteDialect.cs b/src/NHibernate/Dialect/SQLiteDialect.cs index 0084d8d6df9..0ec22942a8a 100644 --- a/src/NHibernate/Dialect/SQLiteDialect.cs +++ b/src/NHibernate/Dialect/SQLiteDialect.cs @@ -363,7 +363,7 @@ public override bool GenerateTablePrimaryKeyConstraintForIdentityColumn } /// - internal override bool IsDecimalStoredAsFloatingPointNumber => true; + public override bool IsDecimalStoredAsFloatingPointNumber => true; public override string Qualify(string catalog, string schema, string table) { diff --git a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs index 2b31520bf96..8366d42f6ff 100644 --- a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs @@ -582,7 +582,7 @@ protected HqlTreeNode VisitConstantExpression(ConstantExpression expression) var parameter = _hqlTreeBuilder.Parameter(namedParameter.Name).AsExpression(); // SQLite driver binds decimal parameters to text, which can cause unexpected results in arithmetic operations. - return expression.Type.UnwrapIfNullable() == typeof(decimal) && _parameters.SessionFactory.Dialect.IsDecimalStoredAsFloatingPointNumber + return _parameters.SessionFactory.Dialect.IsDecimalStoredAsFloatingPointNumber && expression.Type.UnwrapIfNullable() == typeof(decimal) ? _hqlTreeBuilder.TransparentCast(parameter, expression.Type) : parameter; }