Skip to content

Commit 86f6cd0

Browse files
committed
Reduce cast usage for COUNT aggregate and add support for Mssql count_big
1 parent f8dd4ee commit 86f6cd0

File tree

17 files changed

+211
-16
lines changed

17 files changed

+211
-16
lines changed

src/NHibernate.Test/Async/Linq/ByMethod/CountTests.cs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
using System;
1212
using System.Linq;
1313
using NHibernate.Cfg;
14+
using NHibernate.Dialect;
1415
using NUnit.Framework;
1516
using NHibernate.Linq;
1617

@@ -110,5 +111,50 @@ into temp
110111

111112
Assert.That(result.Count, Is.EqualTo(77));
112113
}
114+
115+
[Test]
116+
public async Task CheckSqlFunctionNameLongCountAsync()
117+
{
118+
var name = Dialect is MsSql2000Dialect ? "count_big" : "count";
119+
using (var sqlLog = new SqlLogSpy())
120+
{
121+
var result = await (db.Orders.LongCountAsync());
122+
Assert.That(result, Is.EqualTo(830));
123+
124+
var log = sqlLog.GetWholeLog();
125+
Assert.That(log, Does.Contain($"{name}("));
126+
}
127+
}
128+
129+
[Test]
130+
public async Task CheckSqlFunctionNameForCountAsync()
131+
{
132+
using (var sqlLog = new SqlLogSpy())
133+
{
134+
var result = await (db.Orders.CountAsync());
135+
Assert.That(result, Is.EqualTo(830));
136+
137+
var log = sqlLog.GetWholeLog();
138+
Assert.That(log, Does.Contain("count("));
139+
}
140+
}
141+
142+
[Test]
143+
public async Task CheckMssqlCountCastAsync()
144+
{
145+
if (!(Dialect is MsSql2000Dialect))
146+
{
147+
Assert.Ignore();
148+
}
149+
150+
using (var sqlLog = new SqlLogSpy())
151+
{
152+
var result = await (db.Orders.CountAsync());
153+
Assert.That(result, Is.EqualTo(830));
154+
155+
var log = sqlLog.GetWholeLog();
156+
Assert.That(log, Does.Not.Contain("cast("));
157+
}
158+
}
113159
}
114160
}

src/NHibernate.Test/Async/QueryTest/CountFixture.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
using NHibernate.Cfg;
1313
using NHibernate.Dialect.Function;
1414
using NHibernate.DomainModel;
15+
using NHibernate.Engine;
16+
using NHibernate.Type;
1517
using NUnit.Framework;
1618
using Environment=NHibernate.Cfg.Environment;
1719

@@ -55,4 +57,4 @@ public async Task OverriddenAsync()
5557
await (sf.CloseAsync());
5658
}
5759
}
58-
}
60+
}

src/NHibernate.Test/Hql/SimpleFunctionsTest.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,8 @@ public void ClassicSum()
167167
}
168168
}
169169

170-
[Test]
170+
// Since v5.3
171+
[Test, Obsolete]
171172
public void ClassicCount()
172173
{
173174
//ANSI-SQL92 definition

src/NHibernate.Test/Linq/ByMethod/CountTests.cs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System;
22
using System.Linq;
33
using NHibernate.Cfg;
4+
using NHibernate.Dialect;
45
using NUnit.Framework;
56

67
namespace NHibernate.Test.Linq.ByMethod
@@ -98,5 +99,50 @@ into temp
9899

99100
Assert.That(result.Count, Is.EqualTo(77));
100101
}
102+
103+
[Test]
104+
public void CheckSqlFunctionNameLongCount()
105+
{
106+
var name = Dialect is MsSql2000Dialect ? "count_big" : "count";
107+
using (var sqlLog = new SqlLogSpy())
108+
{
109+
var result = db.Orders.LongCount();
110+
Assert.That(result, Is.EqualTo(830));
111+
112+
var log = sqlLog.GetWholeLog();
113+
Assert.That(log, Does.Contain($"{name}("));
114+
}
115+
}
116+
117+
[Test]
118+
public void CheckSqlFunctionNameForCount()
119+
{
120+
using (var sqlLog = new SqlLogSpy())
121+
{
122+
var result = db.Orders.Count();
123+
Assert.That(result, Is.EqualTo(830));
124+
125+
var log = sqlLog.GetWholeLog();
126+
Assert.That(log, Does.Contain("count("));
127+
}
128+
}
129+
130+
[Test]
131+
public void CheckMssqlCountCast()
132+
{
133+
if (!(Dialect is MsSql2000Dialect))
134+
{
135+
Assert.Ignore();
136+
}
137+
138+
using (var sqlLog = new SqlLogSpy())
139+
{
140+
var result = db.Orders.Count();
141+
Assert.That(result, Is.EqualTo(830));
142+
143+
var log = sqlLog.GetWholeLog();
144+
Assert.That(log, Does.Not.Contain("cast("));
145+
}
146+
}
101147
}
102148
}

src/NHibernate.Test/QueryTest/CountFixture.cs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
using NHibernate.Cfg;
33
using NHibernate.Dialect.Function;
44
using NHibernate.DomainModel;
5+
using NHibernate.Engine;
6+
using NHibernate.Type;
57
using NUnit.Framework;
68
using Environment=NHibernate.Cfg.Environment;
79

@@ -44,4 +46,17 @@ public void Overridden()
4446
sf.Close();
4547
}
4648
}
47-
}
49+
50+
[Serializable]
51+
internal class ClassicCountFunction : ClassicAggregateFunction
52+
{
53+
public ClassicCountFunction() : base("count", true)
54+
{
55+
}
56+
57+
public override IType ReturnType(IType columnType, IMapping mapping)
58+
{
59+
return NHibernateUtil.Int32;
60+
}
61+
}
62+
}

src/NHibernate/Dialect/Dialect.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ public abstract partial class Dialect
5555
static Dialect()
5656
{
5757
StandardAggregateFunctions["count"] = new CountQueryFunctionInfo();
58+
StandardAggregateFunctions["count_big"] = new CountQueryFunctionInfo();
5859
StandardAggregateFunctions["avg"] = new AvgQueryFunctionInfo();
5960
StandardAggregateFunctions["max"] = new ClassicAggregateFunction("max", false);
6061
StandardAggregateFunctions["min"] = new ClassicAggregateFunction("min", false);

src/NHibernate/Dialect/Function/ClassicAggregateFunction.cs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
namespace NHibernate.Dialect.Function
1010
{
1111
[Serializable]
12-
public class ClassicAggregateFunction : ISQLFunction, IFunctionGrammar
12+
public class ClassicAggregateFunction : ISQLFunction, IFunctionGrammar, ISQLAggregateFunction
1313
{
1414
private IType returnType = null;
1515
private readonly string name;
@@ -111,5 +111,18 @@ bool IFunctionGrammar.IsKnownArgument(string token)
111111
}
112112

113113
#endregion
114+
115+
#region ISQLAggregateFunction Members
116+
117+
/// <inheritdoc />
118+
public string FunctionName => name;
119+
120+
/// <inheritdoc />
121+
public virtual IType GetActualReturnType(IType argumentType, IMapping mapping)
122+
{
123+
return ReturnType(argumentType, mapping);
124+
}
125+
126+
#endregion
114127
}
115128
}

src/NHibernate/Dialect/Function/ClassicCountFunction.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ namespace NHibernate.Dialect.Function
77
/// <summary>
88
/// Classic COUNT sqlfunction that return types as it was done in Hibernate 3.1
99
/// </summary>
10+
// Since v5.3
11+
[Obsolete("This class has no more usages in NHibernate and will be removed in a future version.")]
1012
[Serializable]
1113
public class ClassicCountFunction : ClassicAggregateFunction
1214
{
@@ -19,4 +21,4 @@ public override IType ReturnType(IType columnType, IMapping mapping)
1921
return NHibernateUtil.Int32;
2022
}
2123
}
22-
}
24+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
using NHibernate.Engine;
2+
using NHibernate.Type;
3+
4+
namespace NHibernate.Dialect.Function
5+
{
6+
/// <inheritdoc />
7+
internal interface ISQLAggregateFunction : ISQLFunction
8+
{
9+
/// <summary>
10+
/// The name of the aggregate function.
11+
/// </summary>
12+
string FunctionName { get; }
13+
14+
/// <summary>
15+
/// Get the type that will be effectively returned by the underlying database.
16+
/// </summary>
17+
/// <param name="argumentType">The type of the first argument</param>
18+
/// <param name="mapping">The mapping for retrieving the argument sql types.</param>
19+
/// <returns></returns>
20+
IType GetActualReturnType(IType argumentType, IMapping mapping);
21+
}
22+
}

src/NHibernate/Dialect/MsSql2000Dialect.cs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,8 @@ protected virtual void RegisterKeywords()
286286

287287
protected virtual void RegisterFunctions()
288288
{
289-
RegisterFunction("count", new CountBigQueryFunction());
289+
RegisterFunction("count", new CountQueryFunction());
290+
RegisterFunction("count_big", new CountBigQueryFunction());
290291

291292
RegisterFunction("abs", new StandardSQLFunction("abs"));
292293
RegisterFunction("absval", new StandardSQLFunction("absval"));
@@ -705,6 +706,16 @@ public override IType ReturnType(IType columnType, IMapping mapping)
705706
}
706707
}
707708

709+
[Serializable]
710+
private class CountQueryFunction : CountQueryFunctionInfo
711+
{
712+
/// <inheritdoc />
713+
public override IType GetActualReturnType(IType columnType, IMapping mapping)
714+
{
715+
return NHibernateUtil.Int32;
716+
}
717+
}
718+
708719
public override bool SupportsCircularCascadeDeleteConstraints
709720
{
710721
get

0 commit comments

Comments
 (0)